Skip to content

Commit 9b66d6d

Browse files
authored
feat: Instantiate inferred extensions (#461)
After running extension inference on a hugr, add the inferred extensions to the `op_type` field of the hugr. Resolves #446
1 parent 1a16865 commit 9b66d6d

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/hugr.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ pub mod serialize;
77
pub mod validate;
88
pub mod views;
99

10-
use std::collections::{HashMap, VecDeque};
10+
use std::collections::VecDeque;
1111
use std::iter;
1212

1313
pub(crate) use self::hugrmut::HugrMut;
@@ -197,9 +197,9 @@ impl Hugr {
197197
}
198198

199199
/// Infer extension requirements and add new information to `op_types` field
200-
pub fn infer_extensions(
201-
&mut self,
202-
) -> Result<HashMap<(Node, Direction), ExtensionSet>, InferExtensionError> {
200+
///
201+
/// See [`infer_extensions`] for details on the "closure" value
202+
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
203203
let (solution, extension_closure) = infer_extensions(self)?;
204204
self.instantiate_extensions(solution);
205205
Ok(extension_closure)
@@ -214,10 +214,10 @@ impl Hugr {
214214
.filter(|((_, dir), _)| *dir == Direction::Incoming)
215215
{
216216
let nodetype = self.op_types.try_get_mut(node.index).unwrap();
217-
match nodetype.signature() {
217+
match &nodetype.input_extensions {
218218
None => nodetype.input_extensions = Some(input_extensions.clone()),
219219
Some(existing_ext_reqs) => {
220-
debug_assert_eq!(existing_ext_reqs.input_extensions, *input_extensions)
220+
debug_assert_eq!(existing_ext_reqs, input_extensions)
221221
}
222222
}
223223
}

0 commit comments

Comments
 (0)