Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Put extension inference behind a feature gate #786

Merged
merged 8 commits into from
Jan 8, 2024
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Continuous integration
on:
push:
branches:
- main
- main
pull_request:
branches:
- main
Expand Down Expand Up @@ -33,7 +33,7 @@ jobs:
- name: Check formatting
run: cargo fmt -- --check
- name: Run clippy
run: cargo clippy --all-targets -- -D warnings
run: cargo clippy --all-targets --all-features -- -D warnings
- name: Build docs
run: cargo doc --no-deps --all-features
env:
Expand Down Expand Up @@ -102,9 +102,9 @@ jobs:
- name: Run tests with coverage instrumentation
run: |
cargo llvm-cov clean --workspace
cargo llvm-cov --doctests
cargo llvm-cov --all-features --doctests
- name: Generate coverage report
run: cargo llvm-cov report --codecov --output-path coverage.json
run: cargo llvm-cov --all-features report --codecov --output-path coverage.json
- name: Upload coverage to codecov.io
uses: codecov/codecov-action@v3
with:
Expand Down
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ name = "hugr"
bench = false
path = "src/lib.rs"

[features]
extension_inference = []

[dependencies]
thiserror = "1.0.28"
portgraph = { version = "0.11.0", features = ["serde", "petgraph"] }
Expand Down
5 changes: 4 additions & 1 deletion src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ use crate::types::type_param::{check_type_args, TypeArgError};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{check_typevar_decl, CustomType, PolyFuncType, Substitution, TypeBound};

#[allow(dead_code)]
mod infer;
pub use infer::{infer_extensions, ExtensionSolution, InferExtensionError};
#[cfg(feature = "extension_inference")]
pub use infer::infer_extensions;
pub use infer::{ExtensionSolution, InferExtensionError};

mod op_def;
pub use op_def::{
Expand Down
17 changes: 13 additions & 4 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
use std::error::Error;

use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::QB_T;
use crate::extension::ExtensionId;
use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
use crate::hugr::{Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
use crate::ops::custom::{ExternalOp, OpaqueOp};
use crate::ops::dataflow::DataflowParent;
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle};
use crate::ops::{self, dataflow::IOTrait};
use crate::ops::{LeafOp, OpType};
#[cfg(feature = "extension_inference")]
use crate::{
builder::test::closed_dfg_root_hugr,
hugr::validate::ValidationError,
ops::{dataflow::DataflowParent, handle::NodeHandle},
};

use crate::type_row;
use crate::types::{FunctionType, Type, TypeRow};
Expand Down Expand Up @@ -154,6 +158,7 @@ fn plus() -> Result<(), InferExtensionError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
// This generates a solution that causes validation to fail
// because of a missing lift node
Expand Down Expand Up @@ -215,6 +220,7 @@ fn open_variables() -> Result<(), InferExtensionError> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
// Infer the extensions on a child node with no inputs
fn dangling_src() -> Result<(), Box<dyn Error>> {
Expand Down Expand Up @@ -306,6 +312,7 @@ fn create_with_io(
Ok([node, input, output])
}

#[cfg(feature = "extension_inference")]
#[test]
fn test_conditional_inference() -> Result<(), Box<dyn Error>> {
fn build_case(
Expand Down Expand Up @@ -968,6 +975,7 @@ fn simple_funcdefn() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
let mut builder = ModuleBuilder::new();
Expand Down Expand Up @@ -998,6 +1006,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[cfg(feature = "extension_inference")]
#[test]
// Test that the difference between a FuncDefn's input and output nodes is being
// constrained to be the same as the extension delta in the FuncDefn signature.
Expand Down
46 changes: 25 additions & 21 deletions src/extension/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,33 +120,37 @@ impl ExtensionValidator {

/// Check that a pair of input and output nodes declare the same extensions
/// as in the signature of their parents.
#[allow(unused_variables)]
pub fn validate_io_extensions(
&self,
parent: Node,
input: Node,
output: Node,
) -> Result<(), ExtensionError> {
let parent_input_extensions = self.query_extensions(parent, Direction::Incoming)?;
let parent_output_extensions = self.query_extensions(parent, Direction::Outgoing)?;
for dir in Direction::BOTH {
let input_extensions = self.query_extensions(input, dir)?;
let output_extensions = self.query_extensions(output, dir)?;
if parent_input_extensions != input_extensions {
return Err(ExtensionError::ParentIOExtensionMismatch {
parent,
parent_extensions: parent_input_extensions.clone(),
child: input,
child_extensions: input_extensions.clone(),
});
};
if parent_output_extensions != output_extensions {
return Err(ExtensionError::ParentIOExtensionMismatch {
parent,
parent_extensions: parent_output_extensions.clone(),
child: output,
child_extensions: output_extensions.clone(),
});
};
#[cfg(feature = "extension_inference")]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

coverage is not showing this block - do we need to update the CI coverage workflow to run all features too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated! I wasn't sure what stance we were going to take on whether coverage should represent default hugr or include inference by default

{
let parent_input_extensions = self.query_extensions(parent, Direction::Incoming)?;
let parent_output_extensions = self.query_extensions(parent, Direction::Outgoing)?;
for dir in Direction::BOTH {
let input_extensions = self.query_extensions(input, dir)?;
let output_extensions = self.query_extensions(output, dir)?;
if parent_input_extensions != input_extensions {
return Err(ExtensionError::ParentIOExtensionMismatch {
parent,
parent_extensions: parent_input_extensions.clone(),
child: input,
child_extensions: input_extensions.clone(),
});
};
if parent_output_extensions != output_extensions {
return Err(ExtensionError::ParentIOExtensionMismatch {
parent,
parent_extensions: parent_output_extensions.clone(),
child: output,
child_extensions: output_extensions.clone(),
});
};
}
}
Ok(())
}
Expand Down
33 changes: 22 additions & 11 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pub mod serialize;
pub mod validate;
pub mod views;

#[cfg(not(feature = "extension_inference"))]
use std::collections::HashMap;
use std::collections::VecDeque;
use std::iter;

Expand All @@ -23,9 +25,9 @@ use thiserror::Error;

pub use self::views::{HugrView, RootTagged};
use crate::core::NodeIndex;
use crate::extension::{
infer_extensions, ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError,
};
#[cfg(feature = "extension_inference")]
use crate::extension::infer_extensions;
use crate::extension::{ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError};
use crate::ops::custom::resolve_extension_ops;
use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE};
use crate::types::FunctionType;
Expand Down Expand Up @@ -197,12 +199,19 @@ impl Hugr {
/// Infer extension requirements and add new information to `op_types` field
///
/// See [`infer_extensions`] for details on the "closure" value
#[cfg(feature = "extension_inference")]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
let (solution, extension_closure) = infer_extensions(self)?;
self.instantiate_extensions(solution);
Ok(extension_closure)
}
/// Do nothing - this functionality is gated by the feature "extension_inference"
#[cfg(not(feature = "extension_inference"))]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
Ok(HashMap::new())
}

#[allow(dead_code)]
/// Add extension requirement information to the hugr in place.
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
Expand Down Expand Up @@ -345,13 +354,7 @@ pub enum HugrError {
#[cfg(test)]
mod test {
use super::{Hugr, HugrView};
use crate::builder::test::closed_dfg_root_hugr;
use crate::extension::ExtensionSet;
use crate::hugr::HugrMut;
use crate::ops;
use crate::type_row;
use crate::types::{FunctionType, Type};

#[cfg(feature = "extension_inference")]
use std::error::Error;

#[test]
Expand All @@ -371,8 +374,16 @@ mod test {
assert_matches!(hugr.get_io(hugr.root()), Some(_));
}

#[cfg(feature = "extension_inference")]
#[test]
fn extension_instantiation() -> Result<(), Box<dyn Error>> {
use crate::builder::test::closed_dfg_root_hugr;
use crate::extension::ExtensionSet;
use crate::hugr::HugrMut;
use crate::ops::LeafOp;
use crate::type_row;
use crate::types::{FunctionType, Type};

const BIT: Type = crate::extension::prelude::USIZE_T;
let r = ExtensionSet::singleton(&"R".try_into().unwrap());

Expand All @@ -382,7 +393,7 @@ mod test {
let [input, output] = hugr.get_io(hugr.root()).unwrap();
let lift = hugr.add_node_with_parent(
hugr.root(),
ops::LeafOp::Lift {
LeafOp::Lift {
type_row: type_row![BIT],
new_extension: "R".try_into().unwrap(),
},
Expand Down
10 changes: 8 additions & 2 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ use petgraph::visit::{Topo, Walker};
use portgraph::{LinkView, PortView};
use thiserror::Error;

use crate::extension::validate::ExtensionValidator;
use crate::extension::SignatureError;
use crate::extension::{
validate::{ExtensionError, ExtensionValidator},
ExtensionRegistry, ExtensionSolution, InferExtensionError,
validate::ExtensionError, ExtensionRegistry, ExtensionSolution, InferExtensionError,
};

use crate::ops::custom::CustomOpError;
Expand All @@ -36,6 +36,7 @@ struct ValidationContext<'a, 'b> {
/// Dominator tree for each CFG region, using the container node as index.
dominators: HashMap<Node, Dominators<Node>>,
/// Context for the extension validation.
#[allow(dead_code)]
extension_validator: ExtensionValidator,
/// Registry of available Extensions
extension_registry: &'b ExtensionRegistry,
Expand Down Expand Up @@ -64,6 +65,9 @@ impl Hugr {

impl<'a, 'b> ValidationContext<'a, 'b> {
/// Create a new validation context.
// Allow unused "extension_closure" variable for when
// the "extension_inference" feature is disabled.
#[allow(unused_variables)]
pub fn new(
hugr: &'a Hugr,
extension_closure: ExtensionSolution,
Expand Down Expand Up @@ -163,6 +167,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {

// FuncDefns have no resources since they're static nodes, but the
// functions they define can have any extension delta.
#[cfg(feature = "extension_inference")]
if node_type.tag() != OpTag::FuncDefn {
// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
Expand Down Expand Up @@ -240,6 +245,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
let other_node: Node = self.hugr.graph.port_node(link).unwrap().into();
let other_offset = self.hugr.graph.port_offset(link).unwrap().into();

#[cfg(feature = "extension_inference")]
self.extension_validator
.check_extensions_compatible(&(node, port), &(other_node, other_offset))?;

Expand Down
Loading