Skip to content

Commit 5565027

Browse files
authored
feat: Builder and HugrMut add_op_xxx default to open extensions (#622)
A lot of this is refactoring existing code that explicitly uses `add_node_xxx(NodeType::open_extensions(...))` to `add_op_xxx(...)`, and updating tests - in most cases just `validate` -> `update_validate`. Algorithmically, there was only one change - in `infer.rs` - where Functions, FuncDefs and Aliases are given the empty ExtensionSet if they are open (otherwise these end up unsolved). Also there are some gymnastics in a couple of the tests where we want the correct error out of validation but can't do inference because the Hugr is, well, invalid....so we use a couple of different techniques to carry solutions from earlier, valid, Hugrs over to the invalid one. closes #424
1 parent 92b936e commit 5565027

File tree

8 files changed

+126
-130
lines changed

8 files changed

+126
-130
lines changed

src/algorithm/nest_cfgs.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ pub(crate) mod test {
646646
])
647647
);
648648
transform_cfg_to_nested(&mut IdentityCfgMap::new(rc));
649-
h.validate(&PRELUDE_REGISTRY).unwrap();
649+
h.update_validate(&PRELUDE_REGISTRY).unwrap();
650650
assert_eq!(1, depth(&h, entry));
651651
assert_eq!(1, depth(&h, exit));
652652
for n in [split, left, right, merge, head, tail] {
@@ -753,7 +753,7 @@ pub(crate) mod test {
753753
let root = h.root();
754754
let m = SiblingMut::<CfgID>::try_new(&mut h, root).unwrap();
755755
transform_cfg_to_nested(&mut IdentityCfgMap::new(m));
756-
h.validate(&PRELUDE_REGISTRY).unwrap();
756+
h.update_validate(&PRELUDE_REGISTRY).unwrap();
757757
assert_eq!(1, depth(&h, entry));
758758
assert_eq!(3, depth(&h, head));
759759
for n in [split, left, right, merge] {

src/builder.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -149,18 +149,18 @@ pub(crate) mod test {
149149
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
150150
signature: signature.clone(),
151151
}));
152-
hugr.add_node_with_parent(
152+
hugr.add_op_with_parent(
153153
hugr.root(),
154-
NodeType::open_extensions(ops::Input {
154+
ops::Input {
155155
types: signature.input,
156-
}),
156+
},
157157
)
158158
.unwrap();
159-
hugr.add_node_with_parent(
159+
hugr.add_op_with_parent(
160160
hugr.root(),
161-
NodeType::open_extensions(ops::Output {
161+
ops::Output {
162162
types: signature.output,
163-
}),
163+
},
164164
)
165165
.unwrap();
166166
hugr

src/builder/conditional.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,9 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
126126
let case_node =
127127
// add case before any existing subsequent cases
128128
if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() {
129-
// TODO: Allow this to be non-pure
130-
self.hugr_mut().add_node_before(sibling_node, NodeType::open_extensions(case_op))?
129+
self.hugr_mut().add_op_before(sibling_node, case_op)?
131130
} else {
132-
self.add_child_node(NodeType::open_extensions(case_op))?
131+
self.add_child_op(case_op)?
133132
};
134133

135134
self.case_nodes[case] = Some(case_node);

src/extension/infer.rs

+75-81
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,12 @@ impl UnificationContext {
316316
m_output,
317317
node_type.op_signature().extension_reqs,
318318
);
319+
if matches!(
320+
node_type.tag(),
321+
OpTag::Alias | OpTag::Function | OpTag::FuncDefn
322+
) {
323+
self.add_solution(m_input, ExtensionSet::new());
324+
}
319325
}
320326
// We have a solution for everything!
321327
Some(sig) => {
@@ -337,16 +343,16 @@ impl UnificationContext {
337343
| Some(EdgeKind::ControlFlow)
338344
)
339345
}) {
346+
let m_tgt = *self
347+
.extensions
348+
.get(&(tgt_node, Direction::Incoming))
349+
.unwrap();
340350
for (src_node, _) in hugr.linked_ports(tgt_node, port) {
341351
let m_src = self
342352
.extensions
343353
.get(&(src_node, Direction::Outgoing))
344354
.unwrap();
345-
let m_tgt = self
346-
.extensions
347-
.get(&(tgt_node, Direction::Incoming))
348-
.unwrap();
349-
self.add_constraint(*m_src, Constraint::Equal(*m_tgt));
355+
self.add_constraint(*m_src, Constraint::Equal(m_tgt));
350356
}
351357
}
352358
}
@@ -720,11 +726,11 @@ mod test {
720726
let root_node = NodeType::open_extensions(op);
721727
let mut hugr = Hugr::new(root_node);
722728

723-
let input = NodeType::open_extensions(ops::Input::new(type_row![NAT, NAT]));
724-
let output = NodeType::open_extensions(ops::Output::new(type_row![NAT]));
729+
let input = ops::Input::new(type_row![NAT, NAT]);
730+
let output = ops::Output::new(type_row![NAT]);
725731

726-
let input = hugr.add_node_with_parent(hugr.root(), input)?;
727-
let output = hugr.add_node_with_parent(hugr.root(), output)?;
732+
let input = hugr.add_op_with_parent(hugr.root(), input)?;
733+
let output = hugr.add_op_with_parent(hugr.root(), output)?;
728734

729735
assert_matches!(hugr.get_io(hugr.root()), Some(_));
730736

@@ -740,29 +746,29 @@ mod test {
740746
let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT])
741747
.with_extension_delta(&ExtensionSet::singleton(&C));
742748

743-
let add_a = hugr.add_node_with_parent(
749+
let add_a = hugr.add_op_with_parent(
744750
hugr.root(),
745-
NodeType::open_extensions(ops::DFG {
751+
ops::DFG {
746752
signature: add_a_sig,
747-
}),
753+
},
748754
)?;
749-
let add_b = hugr.add_node_with_parent(
755+
let add_b = hugr.add_op_with_parent(
750756
hugr.root(),
751-
NodeType::open_extensions(ops::DFG {
757+
ops::DFG {
752758
signature: add_b_sig,
753-
}),
759+
},
754760
)?;
755-
let add_ab = hugr.add_node_with_parent(
761+
let add_ab = hugr.add_op_with_parent(
756762
hugr.root(),
757-
NodeType::open_extensions(ops::DFG {
763+
ops::DFG {
758764
signature: add_ab_sig,
759-
}),
765+
},
760766
)?;
761-
let mult_c = hugr.add_node_with_parent(
767+
let mult_c = hugr.add_op_with_parent(
762768
hugr.root(),
763-
NodeType::open_extensions(ops::DFG {
769+
ops::DFG {
764770
signature: mult_c_sig,
765-
}),
771+
},
766772
)?;
767773

768774
hugr.connect(input, 0, add_a, 0)?;
@@ -896,29 +902,26 @@ mod test {
896902
let [input, output] = hugr.get_io(hugr.root()).unwrap();
897903
let add_r_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs);
898904

899-
let add_r = hugr.add_node_with_parent(
905+
let add_r = hugr.add_op_with_parent(
900906
hugr.root(),
901-
NodeType::open_extensions(ops::DFG {
907+
ops::DFG {
902908
signature: add_r_sig,
903-
}),
909+
},
904910
)?;
905911

906912
// Dangling thingy
907913
let src_sig = FunctionType::new(type_row![], type_row![NAT])
908914
.with_extension_delta(&ExtensionSet::new());
909915

910-
let src = hugr.add_node_with_parent(
911-
hugr.root(),
912-
NodeType::open_extensions(ops::DFG { signature: src_sig }),
913-
)?;
916+
let src = hugr.add_op_with_parent(hugr.root(), ops::DFG { signature: src_sig })?;
914917

915918
let mult_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]);
916919
// Mult has open extension requirements, which we should solve to be "R"
917-
let mult = hugr.add_node_with_parent(
920+
let mult = hugr.add_op_with_parent(
918921
hugr.root(),
919-
NodeType::open_extensions(ops::DFG {
922+
ops::DFG {
920923
signature: mult_sig,
921-
}),
924+
},
922925
)?;
923926

924927
hugr.connect(input, 0, add_r, 0)?;
@@ -978,18 +981,18 @@ mod test {
978981
) -> Result<[Node; 3], Box<dyn Error>> {
979982
let op: OpType = op.into();
980983

981-
let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?;
982-
let input = hugr.add_node_with_parent(
984+
let node = hugr.add_op_with_parent(parent, op)?;
985+
let input = hugr.add_op_with_parent(
983986
node,
984-
NodeType::open_extensions(ops::Input {
987+
ops::Input {
985988
types: op_sig.input,
986-
}),
989+
},
987990
)?;
988-
let output = hugr.add_node_with_parent(
991+
let output = hugr.add_op_with_parent(
989992
node,
990-
NodeType::open_extensions(ops::Output {
993+
ops::Output {
991994
types: op_sig.output,
992-
}),
995+
},
993996
)?;
994997
Ok([node, input, output])
995998
}
@@ -1010,20 +1013,20 @@ mod test {
10101013
Into::<OpType>::into(op).signature(),
10111014
)?;
10121015

1013-
let lift1 = hugr.add_node_with_parent(
1016+
let lift1 = hugr.add_op_with_parent(
10141017
case,
1015-
NodeType::open_extensions(ops::LeafOp::Lift {
1018+
ops::LeafOp::Lift {
10161019
type_row: type_row![NAT],
10171020
new_extension: first_ext,
1018-
}),
1021+
},
10191022
)?;
10201023

1021-
let lift2 = hugr.add_node_with_parent(
1024+
let lift2 = hugr.add_op_with_parent(
10221025
case,
1023-
NodeType::open_extensions(ops::LeafOp::Lift {
1026+
ops::LeafOp::Lift {
10241027
type_row: type_row![NAT],
10251028
new_extension: second_ext,
1026-
}),
1029+
},
10271030
)?;
10281031

10291032
hugr.connect(case_in, 0, lift1, 0)?;
@@ -1088,17 +1091,17 @@ mod test {
10881091
}));
10891092

10901093
let root = hugr.root();
1091-
let input = hugr.add_node_with_parent(
1094+
let input = hugr.add_op_with_parent(
10921095
root,
1093-
NodeType::open_extensions(ops::Input {
1096+
ops::Input {
10941097
types: type_row![NAT],
1095-
}),
1098+
},
10961099
)?;
1097-
let output = hugr.add_node_with_parent(
1100+
let output = hugr.add_op_with_parent(
10981101
root,
1099-
NodeType::open_extensions(ops::Output {
1102+
ops::Output {
11001103
types: type_row![NAT],
1101-
}),
1104+
},
11021105
)?;
11031106

11041107
// Make identical dataflow nodes which add extension requirement "A" or "B"
@@ -1119,12 +1122,12 @@ mod test {
11191122
.unwrap();
11201123

11211124
let lift = hugr
1122-
.add_node_with_parent(
1125+
.add_op_with_parent(
11231126
node,
1124-
NodeType::open_extensions(ops::LeafOp::Lift {
1127+
ops::LeafOp::Lift {
11251128
type_row: type_row![NAT],
11261129
new_extension: ext,
1127-
}),
1130+
},
11281131
)
11291132
.unwrap();
11301133

@@ -1171,7 +1174,7 @@ mod test {
11711174

11721175
let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb, dfb_sig)?;
11731176

1174-
let dfg = hugr.add_node_with_parent(bb, NodeType::open_extensions(op))?;
1177+
let dfg = hugr.add_op_with_parent(bb, op)?;
11751178

11761179
hugr.connect(bb_in, 0, dfg, 0)?;
11771180
hugr.connect(dfg, 0, bb_out, 0)?;
@@ -1203,23 +1206,20 @@ mod test {
12031206
extension_delta: entry_extensions,
12041207
};
12051208

1206-
let exit = hugr.add_node_with_parent(
1209+
let exit = hugr.add_op_with_parent(
12071210
root,
1208-
NodeType::open_extensions(ops::BasicBlock::Exit {
1211+
ops::BasicBlock::Exit {
12091212
cfg_outputs: exit_types.into(),
1210-
}),
1213+
},
12111214
)?;
12121215

1213-
let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?;
1214-
let entry_in = hugr.add_node_with_parent(
1216+
let entry = hugr.add_op_before(exit, dfb)?;
1217+
let entry_in = hugr.add_op_with_parent(entry, ops::Input { types: inputs })?;
1218+
let entry_out = hugr.add_op_with_parent(
12151219
entry,
1216-
NodeType::open_extensions(ops::Input { types: inputs }),
1217-
)?;
1218-
let entry_out = hugr.add_node_with_parent(
1219-
entry,
1220-
NodeType::open_extensions(ops::Output {
1220+
ops::Output {
12211221
types: vec![entry_tuple_sum].into(),
1222-
}),
1222+
},
12231223
)?;
12241224

12251225
Ok(([entry, entry_in, entry_out], exit))
@@ -1270,12 +1270,12 @@ mod test {
12701270
type_row![NAT],
12711271
)?;
12721272

1273-
let mkpred = hugr.add_node_with_parent(
1273+
let mkpred = hugr.add_op_with_parent(
12741274
entry,
1275-
NodeType::open_extensions(make_opaque(
1275+
make_opaque(
12761276
A,
12771277
FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&a),
1278-
)),
1278+
),
12791279
)?;
12801280

12811281
// Internal wiring for DFGs
@@ -1366,12 +1366,9 @@ mod test {
13661366
type_row![NAT],
13671367
)?;
13681368

1369-
let entry_mid = hugr.add_node_with_parent(
1369+
let entry_mid = hugr.add_op_with_parent(
13701370
entry,
1371-
NodeType::open_extensions(make_opaque(
1372-
UNKNOWN_EXTENSION,
1373-
FunctionType::new(vec![NAT], twoway(NAT)),
1374-
)),
1371+
make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], twoway(NAT))),
13751372
)?;
13761373

13771374
hugr.connect(entry_in, 0, entry_mid, 0)?;
@@ -1455,12 +1452,12 @@ mod test {
14551452
type_row![NAT],
14561453
)?;
14571454

1458-
let entry_dfg = hugr.add_node_with_parent(
1455+
let entry_dfg = hugr.add_op_with_parent(
14591456
entry,
1460-
NodeType::open_extensions(make_opaque(
1457+
make_opaque(
14611458
UNKNOWN_EXTENSION,
14621459
FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext),
1463-
)),
1460+
),
14641461
)?;
14651462

14661463
hugr.connect(entry_in, 0, entry_dfg, 0)?;
@@ -1536,12 +1533,9 @@ mod test {
15361533
type_row![NAT],
15371534
)?;
15381535

1539-
let entry_mid = hugr.add_node_with_parent(
1536+
let entry_mid = hugr.add_op_with_parent(
15401537
entry,
1541-
NodeType::open_extensions(make_opaque(
1542-
UNKNOWN_EXTENSION,
1543-
FunctionType::new(vec![NAT], oneway(NAT)),
1544-
)),
1538+
make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], oneway(NAT))),
15451539
)?;
15461540

15471541
hugr.connect(entry_in, 0, entry_mid, 0)?;

0 commit comments

Comments
 (0)