Skip to content

Commit 4491bad

Browse files
committed
feat: Instantiate inferred extensions
1 parent 95c0b56 commit 4491bad

File tree

2 files changed

+91
-13
lines changed

2 files changed

+91
-13
lines changed

src/extension/infer.rs

+15-8
Original file line numberDiff line numberDiff line change
@@ -857,21 +857,28 @@ mod test {
857857
let [w] = mult.outputs_arr();
858858

859859
builder.set_outputs([w])?;
860-
let hugr = builder.base;
861-
// TODO: when we put new extensions onto the graph after inference, we
862-
// can call `finish_hugr` and just look at the graph
863-
let (solution, extra) = infer_extensions(&hugr)?;
864-
assert!(extra.is_empty());
860+
let mut hugr = builder.base;
861+
let closure = hugr.infer_extensions()?;
862+
assert!(closure.is_empty());
865863
assert_eq!(
866-
*solution.get(&(src.node(), Direction::Outgoing)).unwrap(),
864+
hugr.get_nodetype(src.node())
865+
.signature()
866+
.unwrap()
867+
.output_extensions(),
867868
rs
868869
);
869870
assert_eq!(
870-
*solution.get(&(mult.node(), Direction::Incoming)).unwrap(),
871+
hugr.get_nodetype(mult.node())
872+
.signature()
873+
.unwrap()
874+
.input_extensions,
871875
rs
872876
);
873877
assert_eq!(
874-
*solution.get(&(mult.node(), Direction::Outgoing)).unwrap(),
878+
hugr.get_nodetype(mult.node())
879+
.signature()
880+
.unwrap()
881+
.output_extensions(),
875882
rs
876883
);
877884
Ok(())

src/hugr.rs

+76-5
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ impl Hugr {
193193
rw.apply(self)
194194
}
195195

196-
/// Infer extension requirements
196+
/// Infer extension requirements and add new information to `op_types` field
197197
pub fn infer_extensions(
198198
&mut self,
199199
) -> Result<HashMap<(Node, Direction), ExtensionSet>, InferExtensionError> {
@@ -202,9 +202,22 @@ impl Hugr {
202202
Ok(extension_closure)
203203
}
204204

205-
/// TODO: Write this
206-
fn instantiate_extensions(&mut self, _solution: ExtensionSolution) {
207-
//todo!()
205+
/// Add extension requirement information to the hugr in place.
206+
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
207+
// We only care about inferred _input_ extensions, because `NodeType`
208+
// uses those to infer the output extensions
209+
for ((node, _), input_extensions) in solution
210+
.iter()
211+
.filter(|((_, dir), _)| *dir == Direction::Incoming)
212+
{
213+
let nodetype = self.op_types.try_get_mut(node.index).unwrap();
214+
match nodetype.signature() {
215+
None => nodetype.input_extensions = Some(input_extensions.clone()),
216+
Some(existing_ext_reqs) => {
217+
debug_assert_eq!(existing_ext_reqs.input_extensions, *input_extensions)
218+
}
219+
}
220+
}
208221
}
209222
}
210223

@@ -366,7 +379,14 @@ impl From<HugrError> for PyErr {
366379

367380
#[cfg(test)]
368381
mod test {
369-
use super::Hugr;
382+
use super::{Hugr, HugrView, NodeType};
383+
use crate::extension::ExtensionSet;
384+
use crate::hugr::HugrInternalsMut;
385+
use crate::ops;
386+
use crate::type_row;
387+
use crate::types::{FunctionType, Type};
388+
389+
use std::error::Error;
370390

371391
#[test]
372392
fn impls_send_and_sync() {
@@ -385,4 +405,55 @@ mod test {
385405
let hugr = simple_dfg_hugr();
386406
assert_matches!(hugr.get_io(hugr.root()), Some(_));
387407
}
408+
409+
#[test]
410+
fn extension_instantiation() -> Result<(), Box<dyn Error>> {
411+
const BIT: Type = crate::extension::prelude::USIZE_T;
412+
let r = ExtensionSet::singleton(&"R".into());
413+
414+
let root = NodeType::pure(ops::DFG {
415+
signature: FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r),
416+
});
417+
let mut hugr = Hugr::new(root);
418+
let input = hugr.add_node_with_parent(
419+
hugr.root(),
420+
NodeType::pure(ops::Input {
421+
types: type_row![BIT],
422+
}),
423+
)?;
424+
let output = hugr.add_node_with_parent(
425+
hugr.root(),
426+
NodeType::open_extensions(ops::Output {
427+
types: type_row![BIT],
428+
}),
429+
)?;
430+
let lift = hugr.add_node_with_parent(
431+
hugr.root(),
432+
NodeType::open_extensions(ops::LeafOp::Lift {
433+
type_row: type_row![BIT],
434+
new_extension: "R".into(),
435+
}),
436+
)?;
437+
hugr.connect(input, 0, lift, 0)?;
438+
hugr.connect(lift, 0, output, 0)?;
439+
hugr.infer_extensions()?;
440+
441+
assert_eq!(
442+
hugr.op_types
443+
.get(lift.index)
444+
.signature()
445+
.unwrap()
446+
.input_extensions,
447+
ExtensionSet::new()
448+
);
449+
assert_eq!(
450+
hugr.op_types
451+
.get(output.index)
452+
.signature()
453+
.unwrap()
454+
.input_extensions,
455+
r
456+
);
457+
Ok(())
458+
}
388459
}

0 commit comments

Comments
 (0)