@@ -316,6 +316,12 @@ impl UnificationContext {
316
316
m_output,
317
317
node_type. op_signature ( ) . extension_reqs ,
318
318
) ;
319
+ if matches ! (
320
+ node_type. tag( ) ,
321
+ OpTag :: Alias | OpTag :: Function | OpTag :: FuncDefn
322
+ ) {
323
+ self . add_solution ( m_input, ExtensionSet :: new ( ) ) ;
324
+ }
319
325
}
320
326
// We have a solution for everything!
321
327
Some ( sig) => {
@@ -337,16 +343,16 @@ impl UnificationContext {
337
343
| Some ( EdgeKind :: ControlFlow )
338
344
)
339
345
} ) {
346
+ let m_tgt = * self
347
+ . extensions
348
+ . get ( & ( tgt_node, Direction :: Incoming ) )
349
+ . unwrap ( ) ;
340
350
for ( src_node, _) in hugr. linked_ports ( tgt_node, port) {
341
351
let m_src = self
342
352
. extensions
343
353
. get ( & ( src_node, Direction :: Outgoing ) )
344
354
. unwrap ( ) ;
345
- let m_tgt = self
346
- . extensions
347
- . get ( & ( tgt_node, Direction :: Incoming ) )
348
- . unwrap ( ) ;
349
- self . add_constraint ( * m_src, Constraint :: Equal ( * m_tgt) ) ;
355
+ self . add_constraint ( * m_src, Constraint :: Equal ( m_tgt) ) ;
350
356
}
351
357
}
352
358
}
@@ -720,11 +726,11 @@ mod test {
720
726
let root_node = NodeType :: open_extensions ( op) ;
721
727
let mut hugr = Hugr :: new ( root_node) ;
722
728
723
- let input = NodeType :: open_extensions ( ops:: Input :: new ( type_row ! [ NAT , NAT ] ) ) ;
724
- let output = NodeType :: open_extensions ( ops:: Output :: new ( type_row ! [ NAT ] ) ) ;
729
+ let input = ops:: Input :: new ( type_row ! [ NAT , NAT ] ) ;
730
+ let output = ops:: Output :: new ( type_row ! [ NAT ] ) ;
725
731
726
- let input = hugr. add_node_with_parent ( hugr. root ( ) , input) ?;
727
- let output = hugr. add_node_with_parent ( hugr. root ( ) , output) ?;
732
+ let input = hugr. add_op_with_parent ( hugr. root ( ) , input) ?;
733
+ let output = hugr. add_op_with_parent ( hugr. root ( ) , output) ?;
728
734
729
735
assert_matches ! ( hugr. get_io( hugr. root( ) ) , Some ( _) ) ;
730
736
@@ -740,29 +746,29 @@ mod test {
740
746
let mult_c_sig = FunctionType :: new ( type_row ! [ NAT , NAT ] , type_row ! [ NAT ] )
741
747
. with_extension_delta ( & ExtensionSet :: singleton ( & C ) ) ;
742
748
743
- let add_a = hugr. add_node_with_parent (
749
+ let add_a = hugr. add_op_with_parent (
744
750
hugr. root ( ) ,
745
- NodeType :: open_extensions ( ops:: DFG {
751
+ ops:: DFG {
746
752
signature : add_a_sig,
747
- } ) ,
753
+ } ,
748
754
) ?;
749
- let add_b = hugr. add_node_with_parent (
755
+ let add_b = hugr. add_op_with_parent (
750
756
hugr. root ( ) ,
751
- NodeType :: open_extensions ( ops:: DFG {
757
+ ops:: DFG {
752
758
signature : add_b_sig,
753
- } ) ,
759
+ } ,
754
760
) ?;
755
- let add_ab = hugr. add_node_with_parent (
761
+ let add_ab = hugr. add_op_with_parent (
756
762
hugr. root ( ) ,
757
- NodeType :: open_extensions ( ops:: DFG {
763
+ ops:: DFG {
758
764
signature : add_ab_sig,
759
- } ) ,
765
+ } ,
760
766
) ?;
761
- let mult_c = hugr. add_node_with_parent (
767
+ let mult_c = hugr. add_op_with_parent (
762
768
hugr. root ( ) ,
763
- NodeType :: open_extensions ( ops:: DFG {
769
+ ops:: DFG {
764
770
signature : mult_c_sig,
765
- } ) ,
771
+ } ,
766
772
) ?;
767
773
768
774
hugr. connect ( input, 0 , add_a, 0 ) ?;
@@ -896,29 +902,26 @@ mod test {
896
902
let [ input, output] = hugr. get_io ( hugr. root ( ) ) . unwrap ( ) ;
897
903
let add_r_sig = FunctionType :: new ( type_row ! [ NAT ] , type_row ! [ NAT ] ) . with_extension_delta ( & rs) ;
898
904
899
- let add_r = hugr. add_node_with_parent (
905
+ let add_r = hugr. add_op_with_parent (
900
906
hugr. root ( ) ,
901
- NodeType :: open_extensions ( ops:: DFG {
907
+ ops:: DFG {
902
908
signature : add_r_sig,
903
- } ) ,
909
+ } ,
904
910
) ?;
905
911
906
912
// Dangling thingy
907
913
let src_sig = FunctionType :: new ( type_row ! [ ] , type_row ! [ NAT ] )
908
914
. with_extension_delta ( & ExtensionSet :: new ( ) ) ;
909
915
910
- let src = hugr. add_node_with_parent (
911
- hugr. root ( ) ,
912
- NodeType :: open_extensions ( ops:: DFG { signature : src_sig } ) ,
913
- ) ?;
916
+ let src = hugr. add_op_with_parent ( hugr. root ( ) , ops:: DFG { signature : src_sig } ) ?;
914
917
915
918
let mult_sig = FunctionType :: new ( type_row ! [ NAT , NAT ] , type_row ! [ NAT ] ) ;
916
919
// Mult has open extension requirements, which we should solve to be "R"
917
- let mult = hugr. add_node_with_parent (
920
+ let mult = hugr. add_op_with_parent (
918
921
hugr. root ( ) ,
919
- NodeType :: open_extensions ( ops:: DFG {
922
+ ops:: DFG {
920
923
signature : mult_sig,
921
- } ) ,
924
+ } ,
922
925
) ?;
923
926
924
927
hugr. connect ( input, 0 , add_r, 0 ) ?;
@@ -978,18 +981,18 @@ mod test {
978
981
) -> Result < [ Node ; 3 ] , Box < dyn Error > > {
979
982
let op: OpType = op. into ( ) ;
980
983
981
- let node = hugr. add_node_with_parent ( parent, NodeType :: open_extensions ( op ) ) ?;
982
- let input = hugr. add_node_with_parent (
984
+ let node = hugr. add_op_with_parent ( parent, op ) ?;
985
+ let input = hugr. add_op_with_parent (
983
986
node,
984
- NodeType :: open_extensions ( ops:: Input {
987
+ ops:: Input {
985
988
types : op_sig. input ,
986
- } ) ,
989
+ } ,
987
990
) ?;
988
- let output = hugr. add_node_with_parent (
991
+ let output = hugr. add_op_with_parent (
989
992
node,
990
- NodeType :: open_extensions ( ops:: Output {
993
+ ops:: Output {
991
994
types : op_sig. output ,
992
- } ) ,
995
+ } ,
993
996
) ?;
994
997
Ok ( [ node, input, output] )
995
998
}
@@ -1010,20 +1013,20 @@ mod test {
1010
1013
Into :: < OpType > :: into ( op) . signature ( ) ,
1011
1014
) ?;
1012
1015
1013
- let lift1 = hugr. add_node_with_parent (
1016
+ let lift1 = hugr. add_op_with_parent (
1014
1017
case,
1015
- NodeType :: open_extensions ( ops:: LeafOp :: Lift {
1018
+ ops:: LeafOp :: Lift {
1016
1019
type_row : type_row ! [ NAT ] ,
1017
1020
new_extension : first_ext,
1018
- } ) ,
1021
+ } ,
1019
1022
) ?;
1020
1023
1021
- let lift2 = hugr. add_node_with_parent (
1024
+ let lift2 = hugr. add_op_with_parent (
1022
1025
case,
1023
- NodeType :: open_extensions ( ops:: LeafOp :: Lift {
1026
+ ops:: LeafOp :: Lift {
1024
1027
type_row : type_row ! [ NAT ] ,
1025
1028
new_extension : second_ext,
1026
- } ) ,
1029
+ } ,
1027
1030
) ?;
1028
1031
1029
1032
hugr. connect ( case_in, 0 , lift1, 0 ) ?;
@@ -1088,17 +1091,17 @@ mod test {
1088
1091
} ) ) ;
1089
1092
1090
1093
let root = hugr. root ( ) ;
1091
- let input = hugr. add_node_with_parent (
1094
+ let input = hugr. add_op_with_parent (
1092
1095
root,
1093
- NodeType :: open_extensions ( ops:: Input {
1096
+ ops:: Input {
1094
1097
types : type_row ! [ NAT ] ,
1095
- } ) ,
1098
+ } ,
1096
1099
) ?;
1097
- let output = hugr. add_node_with_parent (
1100
+ let output = hugr. add_op_with_parent (
1098
1101
root,
1099
- NodeType :: open_extensions ( ops:: Output {
1102
+ ops:: Output {
1100
1103
types : type_row ! [ NAT ] ,
1101
- } ) ,
1104
+ } ,
1102
1105
) ?;
1103
1106
1104
1107
// Make identical dataflow nodes which add extension requirement "A" or "B"
@@ -1119,12 +1122,12 @@ mod test {
1119
1122
. unwrap ( ) ;
1120
1123
1121
1124
let lift = hugr
1122
- . add_node_with_parent (
1125
+ . add_op_with_parent (
1123
1126
node,
1124
- NodeType :: open_extensions ( ops:: LeafOp :: Lift {
1127
+ ops:: LeafOp :: Lift {
1125
1128
type_row : type_row ! [ NAT ] ,
1126
1129
new_extension : ext,
1127
- } ) ,
1130
+ } ,
1128
1131
)
1129
1132
. unwrap ( ) ;
1130
1133
@@ -1171,7 +1174,7 @@ mod test {
1171
1174
1172
1175
let [ bb, bb_in, bb_out] = create_with_io ( hugr, bb_parent, dfb, dfb_sig) ?;
1173
1176
1174
- let dfg = hugr. add_node_with_parent ( bb, NodeType :: open_extensions ( op ) ) ?;
1177
+ let dfg = hugr. add_op_with_parent ( bb, op ) ?;
1175
1178
1176
1179
hugr. connect ( bb_in, 0 , dfg, 0 ) ?;
1177
1180
hugr. connect ( dfg, 0 , bb_out, 0 ) ?;
@@ -1203,23 +1206,20 @@ mod test {
1203
1206
extension_delta : entry_extensions,
1204
1207
} ;
1205
1208
1206
- let exit = hugr. add_node_with_parent (
1209
+ let exit = hugr. add_op_with_parent (
1207
1210
root,
1208
- NodeType :: open_extensions ( ops:: BasicBlock :: Exit {
1211
+ ops:: BasicBlock :: Exit {
1209
1212
cfg_outputs : exit_types. into ( ) ,
1210
- } ) ,
1213
+ } ,
1211
1214
) ?;
1212
1215
1213
- let entry = hugr. add_node_before ( exit, NodeType :: open_extensions ( dfb) ) ?;
1214
- let entry_in = hugr. add_node_with_parent (
1216
+ let entry = hugr. add_op_before ( exit, dfb) ?;
1217
+ let entry_in = hugr. add_op_with_parent ( entry, ops:: Input { types : inputs } ) ?;
1218
+ let entry_out = hugr. add_op_with_parent (
1215
1219
entry,
1216
- NodeType :: open_extensions ( ops:: Input { types : inputs } ) ,
1217
- ) ?;
1218
- let entry_out = hugr. add_node_with_parent (
1219
- entry,
1220
- NodeType :: open_extensions ( ops:: Output {
1220
+ ops:: Output {
1221
1221
types : vec ! [ entry_tuple_sum] . into ( ) ,
1222
- } ) ,
1222
+ } ,
1223
1223
) ?;
1224
1224
1225
1225
Ok ( ( [ entry, entry_in, entry_out] , exit) )
@@ -1270,12 +1270,12 @@ mod test {
1270
1270
type_row ! [ NAT ] ,
1271
1271
) ?;
1272
1272
1273
- let mkpred = hugr. add_node_with_parent (
1273
+ let mkpred = hugr. add_op_with_parent (
1274
1274
entry,
1275
- NodeType :: open_extensions ( make_opaque (
1275
+ make_opaque (
1276
1276
A ,
1277
1277
FunctionType :: new ( vec ! [ NAT ] , twoway ( NAT ) ) . with_extension_delta ( & a) ,
1278
- ) ) ,
1278
+ ) ,
1279
1279
) ?;
1280
1280
1281
1281
// Internal wiring for DFGs
@@ -1366,12 +1366,9 @@ mod test {
1366
1366
type_row ! [ NAT ] ,
1367
1367
) ?;
1368
1368
1369
- let entry_mid = hugr. add_node_with_parent (
1369
+ let entry_mid = hugr. add_op_with_parent (
1370
1370
entry,
1371
- NodeType :: open_extensions ( make_opaque (
1372
- UNKNOWN_EXTENSION ,
1373
- FunctionType :: new ( vec ! [ NAT ] , twoway ( NAT ) ) ,
1374
- ) ) ,
1371
+ make_opaque ( UNKNOWN_EXTENSION , FunctionType :: new ( vec ! [ NAT ] , twoway ( NAT ) ) ) ,
1375
1372
) ?;
1376
1373
1377
1374
hugr. connect ( entry_in, 0 , entry_mid, 0 ) ?;
@@ -1455,12 +1452,12 @@ mod test {
1455
1452
type_row ! [ NAT ] ,
1456
1453
) ?;
1457
1454
1458
- let entry_dfg = hugr. add_node_with_parent (
1455
+ let entry_dfg = hugr. add_op_with_parent (
1459
1456
entry,
1460
- NodeType :: open_extensions ( make_opaque (
1457
+ make_opaque (
1461
1458
UNKNOWN_EXTENSION ,
1462
1459
FunctionType :: new ( vec ! [ NAT ] , oneway ( NAT ) ) . with_extension_delta ( & entry_ext) ,
1463
- ) ) ,
1460
+ ) ,
1464
1461
) ?;
1465
1462
1466
1463
hugr. connect ( entry_in, 0 , entry_dfg, 0 ) ?;
@@ -1536,12 +1533,9 @@ mod test {
1536
1533
type_row ! [ NAT ] ,
1537
1534
) ?;
1538
1535
1539
- let entry_mid = hugr. add_node_with_parent (
1536
+ let entry_mid = hugr. add_op_with_parent (
1540
1537
entry,
1541
- NodeType :: open_extensions ( make_opaque (
1542
- UNKNOWN_EXTENSION ,
1543
- FunctionType :: new ( vec ! [ NAT ] , oneway ( NAT ) ) ,
1544
- ) ) ,
1538
+ make_opaque ( UNKNOWN_EXTENSION , FunctionType :: new ( vec ! [ NAT ] , oneway ( NAT ) ) ) ,
1545
1539
) ?;
1546
1540
1547
1541
hugr. connect ( entry_in, 0 , entry_mid, 0 ) ?;
0 commit comments