@@ -193,7 +193,7 @@ impl Hugr {
193
193
rw. apply ( self )
194
194
}
195
195
196
- /// Infer extension requirements
196
+ /// Infer extension requirements and add new information to `op_types` field
197
197
pub fn infer_extensions (
198
198
& mut self ,
199
199
) -> Result < HashMap < ( Node , Direction ) , ExtensionSet > , InferExtensionError > {
@@ -202,9 +202,22 @@ impl Hugr {
202
202
Ok ( extension_closure)
203
203
}
204
204
205
- /// TODO: Write this
206
- fn instantiate_extensions ( & mut self , _solution : ExtensionSolution ) {
207
- //todo!()
205
+ /// Add extension requirement information to the hugr in place.
206
+ fn instantiate_extensions ( & mut self , solution : ExtensionSolution ) {
207
+ // We only care about inferred _input_ extensions, because `NodeType`
208
+ // uses those to infer the output extensions
209
+ for ( ( node, _) , input_extensions) in solution
210
+ . iter ( )
211
+ . filter ( |( ( _, dir) , _) | * dir == Direction :: Incoming )
212
+ {
213
+ let nodetype = self . op_types . try_get_mut ( node. index ) . unwrap ( ) ;
214
+ match nodetype. signature ( ) {
215
+ None => nodetype. input_extensions = Some ( input_extensions. clone ( ) ) ,
216
+ Some ( existing_ext_reqs) => {
217
+ debug_assert_eq ! ( existing_ext_reqs. input_extensions, * input_extensions)
218
+ }
219
+ }
220
+ }
208
221
}
209
222
}
210
223
@@ -366,7 +379,14 @@ impl From<HugrError> for PyErr {
366
379
367
380
#[ cfg( test) ]
368
381
mod test {
369
- use super :: Hugr ;
382
+ use super :: { Hugr , HugrView , NodeType } ;
383
+ use crate :: extension:: ExtensionSet ;
384
+ use crate :: hugr:: HugrInternalsMut ;
385
+ use crate :: ops;
386
+ use crate :: type_row;
387
+ use crate :: types:: { FunctionType , Type } ;
388
+
389
+ use std:: error:: Error ;
370
390
371
391
#[ test]
372
392
fn impls_send_and_sync ( ) {
@@ -385,4 +405,55 @@ mod test {
385
405
let hugr = simple_dfg_hugr ( ) ;
386
406
assert_matches ! ( hugr. get_io( hugr. root( ) ) , Some ( _) ) ;
387
407
}
408
+
409
+ #[ test]
410
+ fn extension_instantiation ( ) -> Result < ( ) , Box < dyn Error > > {
411
+ const BIT : Type = crate :: extension:: prelude:: USIZE_T ;
412
+ let r = ExtensionSet :: singleton ( & "R" . into ( ) ) ;
413
+
414
+ let root = NodeType :: pure ( ops:: DFG {
415
+ signature : FunctionType :: new ( type_row ! [ BIT ] , type_row ! [ BIT ] ) . with_extension_delta ( & r) ,
416
+ } ) ;
417
+ let mut hugr = Hugr :: new ( root) ;
418
+ let input = hugr. add_node_with_parent (
419
+ hugr. root ( ) ,
420
+ NodeType :: pure ( ops:: Input {
421
+ types : type_row ! [ BIT ] ,
422
+ } ) ,
423
+ ) ?;
424
+ let output = hugr. add_node_with_parent (
425
+ hugr. root ( ) ,
426
+ NodeType :: open_extensions ( ops:: Output {
427
+ types : type_row ! [ BIT ] ,
428
+ } ) ,
429
+ ) ?;
430
+ let lift = hugr. add_node_with_parent (
431
+ hugr. root ( ) ,
432
+ NodeType :: open_extensions ( ops:: LeafOp :: Lift {
433
+ type_row : type_row ! [ BIT ] ,
434
+ new_extension : "R" . into ( ) ,
435
+ } ) ,
436
+ ) ?;
437
+ hugr. connect ( input, 0 , lift, 0 ) ?;
438
+ hugr. connect ( lift, 0 , output, 0 ) ?;
439
+ hugr. infer_extensions ( ) ?;
440
+
441
+ assert_eq ! (
442
+ hugr. op_types
443
+ . get( lift. index)
444
+ . signature( )
445
+ . unwrap( )
446
+ . input_extensions,
447
+ ExtensionSet :: new( )
448
+ ) ;
449
+ assert_eq ! (
450
+ hugr. op_types
451
+ . get( output. index)
452
+ . signature( )
453
+ . unwrap( )
454
+ . input_extensions,
455
+ r
456
+ ) ;
457
+ Ok ( ( ) )
458
+ }
388
459
}
0 commit comments