diff --git a/specification/hugr.md b/specification/hugr.md index 748c69ebf..6a8e7772e 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -335,18 +335,18 @@ express control flow, i.e. conditional or repeated evaluation. ##### `Conditional` nodes These are parents to multiple `Case` nodes; the children have no edges. -The first input to the Conditional-node is of Predicate type (see below), whose +The first input to the Conditional-node is of TupleSum type (see below), whose arity matches the number of children of the Conditional-node. At runtime the constructor (tag) selects which child to execute; the unpacked -contents of the Predicate with all remaining inputs to Conditional +contents of the TupleSum with all remaining inputs to Conditional appended are sent to this child, and all outputs of the child are the outputs of the Conditional; that child is evaluated, but the others are not. That is, Conditional-nodes act as "if-then-else" followed by a control-flow merge. -A **Predicate(T0, T1…TN)** type is an algebraic “sum of products” type, -defined as `Sum(Tuple(#t0), Tuple(#t1), ...Tuple(#tn))` (see [type -system](#type-system)), where `#ti` is the *i*th Row defining it. +A **TupleSum(T0, T1…TN)** type is an algebraic “sum of products” type, +defined as `Sum(Tuple(#T0), Tuple(#T1), ...Tuple(#Tn))` (see [type +system](#type-system)), where `#Ti` is the *i*th Row defining it. ```mermaid flowchart @@ -362,7 +362,7 @@ flowchart end Case0 ~~~ Case1 end - Pred["case 0 inputs | case 1 inputs"] --> Conditional + TupleSum["case 0 inputs | case 1 inputs"] --> Conditional OI["other inputs"] --> Conditional Conditional --> outputs ``` @@ -371,13 +371,13 @@ flowchart These provide tail-controlled loops. The dataflow sibling graph within the TailLoop-node defines the loop body: this computes a row of outputs, whose -first element has type `Predicate(#I, #O)` and the remainder is a row `#X` +first element has type `TupleSum(#I, #O)` and the remainder is a row `#X` (perhaps empty). Inputs to the contained graph and to the TailLoop node itself are the row `#I:#X`, where `:` indicates row concatenation (with the tuple -inside the `Predicate` unpacked). +inside the `TupleSum` unpacked). Evaluation of the node begins by feeding the node inputs into the child graph -and evaluating it. The `Predicate` produced controls iteration of the loop: +and evaluating it. The `TupleSum` produced controls iteration of the loop: * The first variant (`#I`) means that these values, along with the other sibling-graph outputs `#X`, are fed back into the top of the loop, and the body is evaluated again (thus perhaps many times) @@ -405,7 +405,7 @@ The first child is the entry block and must be a `DFB`, with inputs the same as The remaining children are either `DFB`s or [scoped definitions](#scoped-definitions). The first output of the DSG contained in a `BasicBlock` has type -`Predicate(#t0,...#t(n-1))`, where the node has `n` successors, and the +`TupleSum(#t0,...#t(n-1))`, where the node has `n` successors, and the remaining outputs are a row `#x`. `#ti` with `#x` appended matches the inputs of successor `i`. @@ -431,7 +431,7 @@ output of each of these is a sum type, whose arity is the number of outgoing control edges; the remaining outputs are those that are passed to all succeeding nodes. -The three nodes labelled "Const" are simply generating a predicate with one empty +The three nodes labelled "Const" are simply generating a TupleSum with one empty value to pass to the Output node. ```mermaid @@ -1125,7 +1125,6 @@ run, which removes the `HigherOrder` extension requirement: ``` precompute :: Function[](Function[Quantum,HigherOrder](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit))), Function[Quantum](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit)))) ->>>>>>> c6abd39 ([doc] Tidy hugr specification) ``` Before we can run the circuit. @@ -1391,8 +1390,8 @@ use an empty node in the replacement and have B map this node to the old one. We can, for example, implement “turning a Conditional-node with known -predicate into a DFG-node” by a `Replace` where the Conditional (and its -preceding predicate) is replaced by an empty DFG and the map B specifies +TupleSum into a DFG-node” by a `Replace` where the Conditional (and its +preceding TupleSum) is replaced by an empty DFG and the map B specifies the “good” child of the Conditional as the surrogate parent of the new DFG’s children. (If the good child was just an Op, we could either remove it and include it in the replacement, or – to avoid this overhead diff --git a/src/algorithm/nest_cfgs.rs b/src/algorithm/nest_cfgs.rs index 62779d051..b7a9a1ec2 100644 --- a/src/algorithm/nest_cfgs.rs +++ b/src/algorithm/nest_cfgs.rs @@ -605,10 +605,8 @@ pub(crate) mod test { // \-> right -/ \-<--<-/ let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; - let pred_const = - cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch - let const_unit = - cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?; + let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, @@ -889,10 +887,8 @@ pub(crate) mod test { separate: bool, ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; - let pred_const = - cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch - let const_unit = - cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?; + let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?, @@ -933,10 +929,8 @@ pub(crate) mod test { cfg_builder: &mut CFGBuilder, separate_headers: bool, ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { - let pred_const = - cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch - let const_unit = - cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?; + let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index e4ec74191..a890dcd6c 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -421,8 +421,8 @@ pub trait Dataflow: Container { } /// Return a builder for a [`crate::ops::Conditional`] node. - /// `predicate_inputs` and `predicate_wire` define the type of the predicate - /// variants and the wire carrying the predicate respectively. + /// `tuple_sum_rows` and `tuple_sum_wire` define the type of the TupleSum + /// variants and the wire carrying the TupleSum respectively. /// /// The `other_inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. @@ -434,24 +434,24 @@ pub trait Dataflow: Container { /// the Conditional node. fn conditional_builder( &mut self, - (predicate_inputs, predicate_wire): (impl IntoIterator, Wire), + (tuple_sum_rows, tuple_sum_wire): (impl IntoIterator, Wire), other_inputs: impl IntoIterator, output_types: TypeRow, extension_delta: ExtensionSet, ) -> Result, BuildError> { - let mut input_wires = vec![predicate_wire]; + let mut input_wires = vec![tuple_sum_wire]; let (input_types, rest_input_wires): (Vec, Vec) = other_inputs.into_iter().unzip(); input_wires.extend(rest_input_wires); let inputs: TypeRow = input_types.into(); - let predicate_inputs: Vec<_> = predicate_inputs.into_iter().collect(); - let n_cases = predicate_inputs.len(); + let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect(); + let n_cases = tuple_sum_rows.len(); let n_out_wires = output_types.len(); let conditional_id = self.add_dataflow_op( ops::Conditional { - predicate_inputs, + tuple_sum_rows, other_inputs: inputs, outputs: output_types, extension_delta, @@ -534,15 +534,15 @@ pub trait Dataflow: Container { } /// Add [`LeafOp::MakeTuple`] and [`LeafOp::Tag`] nodes to construct the - /// `tag` variant of a predicate (sum-of-tuples) type. - fn make_predicate( + /// `tag` variant of a TupleSum type. + fn make_tuple_sum( &mut self, tag: usize, - predicate_variants: impl IntoIterator, + tuple_sum_rows: impl IntoIterator, values: impl IntoIterator, ) -> Result { let tuple = self.make_tuple(values)?; - let variants = crate::types::predicate_variants_row(predicate_variants); + let variants = crate::types::tuple_sum_row(tuple_sum_rows); let make_op = self.add_dataflow_op(LeafOp::Tag { tag, variants }, vec![tuple])?; Ok(make_op.out_wire(0)) } @@ -561,7 +561,7 @@ pub trait Dataflow: Container { tail_loop: ops::TailLoop, values: impl IntoIterator, ) -> Result { - self.make_predicate(0, [tail_loop.just_inputs, tail_loop.just_outputs], values) + self.make_tuple_sum(0, [tail_loop.just_inputs, tail_loop.just_outputs], values) } /// Use the wires in `values` to return a wire corresponding to the @@ -578,7 +578,7 @@ pub trait Dataflow: Container { loop_op: ops::TailLoop, values: impl IntoIterator, ) -> Result { - self.make_predicate(1, [loop_op.just_inputs, loop_op.just_outputs], values) + self.make_tuple_sum(1, [loop_op.just_inputs, loop_op.just_outputs], values) } /// Add a [`ops::Call`] node, calling `function`, with inputs diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index 18855abd0..eb168082e 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -103,8 +103,8 @@ impl + AsRef> CFGBuilder { } /// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs` - /// and `outputs` and the variants of the branching predicate Sum value - /// specified by `predicate_variants`. + /// and `outputs` and the variants of the branching TupleSum value + /// specified by `tuple_sum_rows`. /// /// # Errors /// @@ -112,13 +112,13 @@ impl + AsRef> CFGBuilder { pub fn block_builder( &mut self, inputs: TypeRow, - predicate_variants: Vec, + tuple_sum_rows: impl IntoIterator, extension_delta: ExtensionSet, other_outputs: TypeRow, ) -> Result, BuildError> { self.any_block_builder( inputs, - predicate_variants, + tuple_sum_rows, other_outputs, extension_delta, false, @@ -128,15 +128,16 @@ impl + AsRef> CFGBuilder { fn any_block_builder( &mut self, inputs: TypeRow, - predicate_variants: Vec, + tuple_sum_rows: impl IntoIterator, other_outputs: TypeRow, extension_delta: ExtensionSet, entry: bool, ) -> Result, BuildError> { + let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect(); let op = OpType::BasicBlock(BasicBlock::DFB { inputs: inputs.clone(), other_outputs: other_outputs.clone(), - predicate_variants: predicate_variants.clone(), + tuple_sum_rows: tuple_sum_rows.clone(), extension_delta, }); let parent = self.container_node(); @@ -152,14 +153,14 @@ impl + AsRef> CFGBuilder { BlockBuilder::create( self.hugr_mut(), block_n, - predicate_variants, + tuple_sum_rows, other_outputs, inputs, ) } /// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs` - /// and `outputs` and a simple predicate type: a Sum of `n_cases` unit types. + /// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types. /// /// # Errors /// @@ -178,15 +179,15 @@ impl + AsRef> CFGBuilder { } /// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs` - /// and `outputs` and the variants of the branching predicate Sum value - /// specified by `predicate_variants`. + /// and `outputs` and the variants of the branching TupleSum value + /// specified by `tuple_sum_rows`. /// /// # Errors /// /// This function will return an error if an entry block has already been built. pub fn entry_builder( &mut self, - predicate_variants: Vec, + tuple_sum_rows: impl IntoIterator, other_outputs: TypeRow, extension_delta: ExtensionSet, ) -> Result, BuildError> { @@ -194,17 +195,11 @@ impl + AsRef> CFGBuilder { .inputs .take() .ok_or(BuildError::EntryBuiltError(self.cfg_node))?; - self.any_block_builder( - inputs, - predicate_variants, - other_outputs, - extension_delta, - true, - ) + self.any_block_builder(inputs, tuple_sum_rows, other_outputs, extension_delta, true) } /// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs` - /// and `outputs` and a simple predicate type: a Sum of `n_cases` unit types. + /// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types. /// /// # Errors /// @@ -244,8 +239,8 @@ impl + AsRef> CFGBuilder { pub type BlockBuilder = DFGWrapper; impl + AsRef> BlockBuilder { - /// Set the outputs of the block, with `branch_wire` being the value of the - /// predicate. `outputs` are the remaining outputs. + /// Set the outputs of the block, with `branch_wire` carrying the value of the + /// branch controlling TupleSum value. `outputs` are the remaining outputs. pub fn set_outputs( &mut self, branch_wire: Wire, @@ -256,13 +251,13 @@ impl + AsRef> BlockBuilder { fn create( base: B, block_n: Node, - predicate_variants: Vec, + tuple_sum_rows: impl IntoIterator, other_outputs: TypeRow, inputs: TypeRow, ) -> Result { - // The node outputs a predicate before the data outputs of the block node - let predicate_type = Type::new_predicate(predicate_variants); - let mut node_outputs = vec![predicate_type]; + // The node outputs a TupleSum before the data outputs of the block node + let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows); + let mut node_outputs = vec![tuple_sum_type]; node_outputs.extend_from_slice(&other_outputs); let signature = FunctionType::new(inputs, TypeRow::from(node_outputs)); let inp_ex = base @@ -293,23 +288,23 @@ impl BlockBuilder { pub fn new( inputs: impl Into, input_extensions: impl Into>, - predicate_variants: impl IntoIterator, + tuple_sum_rows: impl IntoIterator, other_outputs: impl Into, extension_delta: ExtensionSet, ) -> Result { let inputs = inputs.into(); - let predicate_variants: Vec<_> = predicate_variants.into_iter().collect(); + let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect(); let other_outputs = other_outputs.into(); let op = BasicBlock::DFB { inputs: inputs.clone(), other_outputs: other_outputs.clone(), - predicate_variants: predicate_variants.clone(), + tuple_sum_rows: tuple_sum_rows.clone(), extension_delta, }; let base = Hugr::new(NodeType::new(op, input_extensions)); let root = base.root(); - Self::create(base, root, predicate_variants, other_outputs, inputs) + Self::create(base, root, tuple_sum_rows, other_outputs, inputs) } /// [Set outputs](BlockBuilder::set_outputs) and [finish_hugr](`BlockBuilder::finish_hugr`). @@ -382,14 +377,13 @@ mod test { let entry = { let [inw] = entry_b.input_wires_arr(); - let sum = entry_b.make_predicate(1, sum2_variants, [inw])?; + let sum = entry_b.make_tuple_sum(1, sum2_variants, [inw])?; entry_b.finish_with_outputs(sum, [])? }; let mut middle_b = cfg_builder .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?; let middle = { - let c = middle_b - .add_load_const(ops::Const::simple_unary_predicate(), ExtensionSet::new())?; + let c = middle_b.add_load_const(ops::Const::unary_unit_sum(), ExtensionSet::new())?; let [inw] = middle_b.input_wires_arr(); middle_b.finish_with_outputs(c, [inw])? }; diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index 54076b3ce..f1af1ad0d 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -158,20 +158,20 @@ impl HugrBuilder for ConditionalBuilder { impl ConditionalBuilder { /// Initialize a Conditional rooted HUGR builder pub fn new( - predicate_inputs: impl IntoIterator, + tuple_sum_rows: impl IntoIterator, other_inputs: impl Into, outputs: impl Into, extension_delta: ExtensionSet, ) -> Result { - let predicate_inputs: Vec<_> = predicate_inputs.into_iter().collect(); + let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect(); let other_inputs = other_inputs.into(); let outputs = outputs.into(); let n_out_wires = outputs.len(); - let n_cases = predicate_inputs.len(); + let n_cases = tuple_sum_rows.len(); let op = ops::Conditional { - predicate_inputs, + tuple_sum_rows, other_inputs, outputs, extension_delta, @@ -222,9 +222,8 @@ mod test { #[test] fn basic_conditional() -> Result<(), BuildError> { - let predicate_inputs = vec![type_row![]; 2]; let mut conditional_b = ConditionalBuilder::new( - predicate_inputs, + [type_row![], type_row![]], type_row![NAT], type_row![NAT], ExtensionSet::new(), @@ -248,11 +247,10 @@ mod test { let const_wire = fbuild.load_const(&tru_const)?; let [int] = fbuild.input_wires_arr(); let conditional_id = { - let predicate_inputs = vec![type_row![]; 2]; let other_inputs = vec![(NAT, int)]; let outputs = vec![NAT].into(); let mut conditional_b = fbuild.conditional_builder( - (predicate_inputs, const_wire), + ([type_row![], type_row![]], const_wire), other_inputs, outputs, ExtensionSet::new(), @@ -276,9 +274,8 @@ mod test { #[test] fn test_not_all_cases() -> Result<(), BuildError> { - let predicate_inputs = vec![type_row![]; 2]; let mut builder = ConditionalBuilder::new( - predicate_inputs, + [type_row![], type_row![]], type_row![], type_row![], ExtensionSet::new(), @@ -295,9 +292,8 @@ mod test { #[test] fn test_case_already_built() -> Result<(), BuildError> { - let predicate_inputs = vec![type_row![]; 2]; let mut builder = ConditionalBuilder::new( - predicate_inputs, + [type_row![], type_row![]], type_row![], type_row![], ExtensionSet::new(), diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index 865814417..a8a07eb4a 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -26,7 +26,7 @@ impl + AsRef> TailLoopBuilder { Ok(TailLoopBuilder::from_dfg_builder(dfg_build)) } /// Set the outputs of the [`ops::TailLoop`], with `out_variant` as the value of the - /// termination predicate, and `rest` being the remaining outputs + /// termination TupleSum, and `rest` being the remaining outputs pub fn set_outputs( &mut self, out_variant: Wire, @@ -48,7 +48,7 @@ impl + AsRef> TailLoopBuilder { } } - /// The output types of the child graph, including the predicate as the first. + /// The output types of the child graph, including the TupleSum as the first. pub fn internal_output_row(&self) -> Result { self.loop_signature().map(ops::TailLoop::body_output_row) } @@ -152,10 +152,9 @@ mod test { let [const_wire] = lift_node.outputs_arr(); let [b1] = loop_b.input_wires_arr(); let conditional_id = { - let predicate_inputs = vec![type_row![]; 2]; let output_row = loop_b.internal_output_row()?; let mut conditional_b = loop_b.conditional_builder( - (predicate_inputs, const_wire), + ([type_row![], type_row![]], const_wire), vec![(BIT, b1)], output_row, ExtensionSet::new(), diff --git a/src/extension/infer.rs b/src/extension/infer.rs index a1df4d4a1..6fbeaec05 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -1048,14 +1048,14 @@ mod test { Ok(case) } - let predicate_inputs = vec![type_row![]; 2]; + let tuple_sum_rows = vec![type_row![]; 2]; let rs = ExtensionSet::from_iter([A, B]); let inputs = type_row![NAT]; let outputs = type_row![NAT]; let op = ops::Conditional { - predicate_inputs, + tuple_sum_rows, other_inputs: inputs.clone(), outputs: outputs.clone(), extension_delta: rs.clone(), @@ -1169,16 +1169,17 @@ mod test { hugr: &mut Hugr, bb_parent: Node, inputs: TypeRow, - predicate_variants: Vec, + tuple_sum_rows: impl IntoIterator, extension_delta: ExtensionSet, ) -> Result> { - let predicate_type = Type::new_predicate(predicate_variants.clone()); - let dfb_sig = FunctionType::new(inputs.clone(), vec![predicate_type]) + let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect(); + let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows.clone()); + let dfb_sig = FunctionType::new(inputs.clone(), vec![tuple_sum_type]) .with_extension_delta(&extension_delta.clone()); let dfb = ops::BasicBlock::DFB { inputs, other_outputs: type_row![], - predicate_variants, + tuple_sum_rows, extension_delta, }; let op = make_opaque(UNKNOWN_EXTENSION, dfb_sig.clone()); @@ -1194,26 +1195,26 @@ mod test { } fn oneway(ty: Type) -> Vec { - vec![Type::new_predicate([vec![ty]])] + vec![Type::new_tuple_sum([vec![ty]])] } fn twoway(ty: Type) -> Vec { - vec![Type::new_predicate([vec![ty.clone()], vec![ty]])] + vec![Type::new_tuple_sum([vec![ty.clone()], vec![ty]])] } fn create_entry_exit( hugr: &mut Hugr, root: Node, inputs: TypeRow, - entry_predicates: Vec, + entry_variants: Vec, entry_extensions: ExtensionSet, exit_types: impl Into, ) -> Result<([Node; 3], Node), Box> { - let entry_predicate_type = Type::new_predicate(entry_predicates.clone()); + let entry_tuple_sum = Type::new_tuple_sum(entry_variants.clone()); let dfb = ops::BasicBlock::DFB { inputs: inputs.clone(), other_outputs: type_row![], - predicate_variants: entry_predicates, + tuple_sum_rows: entry_variants, extension_delta: entry_extensions, }; @@ -1232,7 +1233,7 @@ mod test { let entry_out = hugr.add_node_with_parent( entry, NodeType::open_extensions(ops::Output { - types: vec![entry_predicate_type].into(), + types: vec![entry_tuple_sum].into(), }), )?; diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index d01ebc149..0db268400 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -86,7 +86,7 @@ pub const QB_T: Type = Type::new_extension(QB_CUSTOM_T); /// Unsigned size type. pub const USIZE_T: Type = Type::new_extension(USIZE_CUSTOM_T); /// Boolean type - Sum of two units. -pub const BOOL_T: Type = Type::new_simple_predicate(2); +pub const BOOL_T: Type = Type::new_unit_sum(2); /// Initialize a new array of element type `element_ty` of length `size` pub fn array_type(element_ty: Type, size: u64) -> Type { diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index bb0546ae3..910ea4fcc 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -139,10 +139,10 @@ impl Rewrite for OutlineCfg { .cfg_builder(wires_in, input_extensions, outputs, extension_delta) .unwrap(); let cfg = cfg.finish_sub_container().unwrap(); - let predicate = new_block_bldr - .add_constant(ops::Const::simple_unary_predicate(), ExtensionSet::new()) + let unit_sum = new_block_bldr + .add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new()) .unwrap(); - let pred_wire = new_block_bldr.load_const(&predicate).unwrap(); + let pred_wire = new_block_bldr.load_const(&unit_sum).unwrap(); new_block_bldr .set_outputs(pred_wire, cfg.outputs()) .unwrap(); diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 3338c081a..57b9020c0 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -800,18 +800,18 @@ mod test { (input, copy, output) } - /// Adds an input{BOOL_T}, tag_constant(0, BOOL_T^pred_size), tag(BOOL_T^pred_size), and - /// output{Sum{unit^pred_size}, BOOL_T} operation to a dataflow container. + /// Adds an input{BOOL_T}, tag_constant(0, BOOL_T^tuple_sum_size), tag(BOOL_T^tuple_sum_size), and + /// output{Sum{unit^tuple_sum_size}, BOOL_T} operation to a dataflow container. /// Intended to be used to populate a BasicBlock node in a CFG. /// /// Returns the node indices of each of the operations. fn add_block_children( b: &mut Hugr, parent: Node, - predicate_size: usize, + tuple_sum_size: usize, ) -> (Node, Node, Node, Node) { - let const_op = ops::Const::simple_predicate(0, predicate_size as u8); - let tag_type = Type::new_simple_predicate(predicate_size as u8); + let const_op = ops::Const::unit_sum(0, tuple_sum_size as u8); + let tag_type = Type::new_unit_sum(tuple_sum_size as u8); let input = b .add_op_with_parent(parent, ops::Input::new(type_row![BOOL_T])) @@ -1017,7 +1017,7 @@ mod test { cfg, ops::BasicBlock::DFB { inputs: type_row![BOOL_T], - predicate_variants: vec![type_row![]], + tuple_sum_rows: vec![type_row![]], other_outputs: type_row![BOOL_T], extension_delta: ExtensionSet::new(), }, @@ -1058,7 +1058,7 @@ mod test { block, NodeType::pure(ops::BasicBlock::DFB { inputs: type_row![Q], - predicate_variants: vec![type_row![]], + tuple_sum_rows: vec![type_row![]], other_outputs: type_row![Q], extension_delta: ExtensionSet::new(), }), @@ -1071,10 +1071,7 @@ mod test { .unwrap(); b.replace_op( block_output, - NodeType::pure(ops::Output::new(type_row![ - Type::new_simple_predicate(1), - Q - ])), + NodeType::pure(ops::Output::new(type_row![Type::new_unit_sum(1), Q])), ) .unwrap(); assert_matches!( diff --git a/src/hugr/views/root_checked.rs b/src/hugr/views/root_checked.rs index 9ecdaa27a..26815e8ed 100644 --- a/src/hugr/views/root_checked.rs +++ b/src/hugr/views/root_checked.rs @@ -97,7 +97,7 @@ mod test { let bb = NodeType::pure(BasicBlock::DFB { inputs: type_row![], other_outputs: type_row![], - predicate_variants: vec![type_row![]], + tuple_sum_rows: vec![type_row![]], extension_delta: ExtensionSet::new(), }); let r = dfg_v.replace_op(root, bb.clone()); diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 81b21f7dd..308378e3a 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -34,41 +34,38 @@ impl Const { &self.typ } - /// Sum of Tuples, used as predicates in branching. + /// Sum of Tuples, used for branching. /// Tuple rows are defined in order by input rows. - pub fn predicate( + pub fn tuple_sum( tag: usize, value: Value, variant_rows: impl IntoIterator, ) -> Result { - let typ = Type::new_predicate(variant_rows); + let typ = Type::new_tuple_sum(variant_rows); Self::new(Value::sum(tag, value), typ) } - /// Constant Sum over units, used as predicates. - pub fn simple_predicate(tag: usize, size: u8) -> Self { + /// Constant Sum over units, used as branching values. + pub fn unit_sum(tag: usize, size: u8) -> Self { Self { - value: Value::simple_predicate(tag), - typ: Type::new_simple_predicate(size), + value: Value::unit_sum(tag), + typ: Type::new_unit_sum(size), } } /// Constant Sum over units, with only one variant. - pub fn simple_unary_predicate() -> Self { - Self { - value: Value::simple_unary_predicate(), - typ: Type::new_simple_predicate(1), - } + pub fn unary_unit_sum() -> Self { + Self::unit_sum(0, 1) } /// Constant "true" value, i.e. the second variant of Sum((), ()). pub fn true_val() -> Self { - Self::simple_predicate(1, 2) + Self::unit_sum(1, 2) } /// Constant "false" value, i.e. the first variant of Sum((), ()). pub fn false_val() -> Self { - Self::simple_predicate(0, 2) + Self::unit_sum(0, 2) } /// Tuple of values @@ -142,17 +139,17 @@ mod test { use super::*; #[test] - fn test_predicate() -> Result<(), BuildError> { + fn test_tuple_sum() -> Result<(), BuildError> { use crate::builder::Container; let pred_rows = vec![type_row![USIZE_T, FLOAT64_TYPE], type_row![]]; - let pred_ty = Type::new_predicate(pred_rows.clone()); + let pred_ty = Type::new_tuple_sum(pred_rows.clone()); let mut b = DFGBuilder::new(FunctionType::new( type_row![], TypeRow::from(vec![pred_ty.clone()]), ))?; let c = b.add_constant( - Const::predicate( + Const::tuple_sum( 0, Value::tuple([CustomTestValue(TypeBound::Eq).into(), serialized_float(5.1)]), pred_rows.clone(), @@ -164,7 +161,7 @@ mod test { let mut b = DFGBuilder::new(FunctionType::new(type_row![], TypeRow::from(vec![pred_ty])))?; let c = b.add_constant( - Const::predicate(1, Value::unit(), pred_rows)?, + Const::tuple_sum(1, Value::unit(), pred_rows)?, ExtensionSet::new(), )?; let w = b.load_const(&c)?; @@ -174,10 +171,10 @@ mod test { } #[test] - fn test_bad_predicate() { + fn test_bad_tuple_sum() { let pred_rows = [type_row![USIZE_T, FLOAT64_TYPE], type_row![]]; - let res = Const::predicate(0, Value::tuple([]), pred_rows); + let res = Const::tuple_sum(0, Value::tuple([]), pred_rows); assert_matches!(res, Err(ConstTypeError::TupleWrongLength)); } diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index 10c149570..54adcf803 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -32,7 +32,7 @@ impl DataflowOpTrait for TailLoop { fn signature(&self) -> FunctionType { let [inputs, outputs] = - [&self.just_inputs, &self.just_outputs].map(|row| predicate_first(row, &self.rest)); + [&self.just_inputs, &self.just_outputs].map(|row| tuple_sum_first(row, &self.rest)); FunctionType::new(inputs, outputs) } } @@ -40,23 +40,24 @@ impl DataflowOpTrait for TailLoop { impl TailLoop { /// Build the output TypeRow of the child graph of a TailLoop node. pub(crate) fn body_output_row(&self) -> TypeRow { - let predicate = Type::new_predicate([self.just_inputs.clone(), self.just_outputs.clone()]); - let mut outputs = vec![predicate]; + let tuple_sum_type = + Type::new_tuple_sum([self.just_inputs.clone(), self.just_outputs.clone()]); + let mut outputs = vec![tuple_sum_type]; outputs.extend_from_slice(&self.rest); outputs.into() } /// Build the input TypeRow of the child graph of a TailLoop node. pub(crate) fn body_input_row(&self) -> TypeRow { - predicate_first(&self.just_inputs, &self.rest) + tuple_sum_first(&self.just_inputs, &self.rest) } } /// Conditional operation, defined by child `Case` nodes for each branch. #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct Conditional { - /// The possible rows of the predicate input - pub predicate_inputs: Vec, + /// The possible rows of the TupleSum input + pub tuple_sum_rows: Vec, /// Remaining input types pub other_inputs: TypeRow, /// Output types @@ -77,7 +78,7 @@ impl DataflowOpTrait for Conditional { let mut inputs = self.other_inputs.clone(); inputs .to_mut() - .insert(0, Type::new_predicate(self.predicate_inputs.clone())); + .insert(0, Type::new_tuple_sum(self.tuple_sum_rows.clone())); FunctionType::new(inputs, self.outputs.clone()).with_extension_delta(&self.extension_delta) } } @@ -85,8 +86,8 @@ impl DataflowOpTrait for Conditional { impl Conditional { /// Build the input TypeRow of the nth child graph of a Conditional node. pub(crate) fn case_input_row(&self, case: usize) -> Option { - Some(predicate_first( - self.predicate_inputs.get(case)?, + Some(tuple_sum_first( + self.tuple_sum_rows.get(case)?, &self.other_inputs, )) } @@ -122,7 +123,7 @@ pub enum BasicBlock { DFB { inputs: TypeRow, other_outputs: TypeRow, - predicate_variants: Vec, + tuple_sum_rows: Vec, extension_delta: ExtensionSet, }, /// The single exit node of the CFG, has no children, @@ -192,10 +193,10 @@ impl BasicBlock { pub fn successor_input(&self, successor: usize) -> Option { match self { BasicBlock::DFB { - predicate_variants, + tuple_sum_rows, other_outputs: outputs, .. - } => Some(predicate_first(predicate_variants.get(successor)?, outputs)), + } => Some(tuple_sum_first(tuple_sum_rows.get(successor)?, outputs)), BasicBlock::Exit { .. } => panic!("Exit should have no successors"), } } @@ -240,9 +241,10 @@ impl Case { } } -fn predicate_first(pred: &TypeRow, rest: &TypeRow) -> TypeRow { +fn tuple_sum_first(tuple_sum_row: &TypeRow, rest: &TypeRow) -> TypeRow { TypeRow::from( - pred.iter() + tuple_sum_row + .iter() .cloned() .chain(rest.iter().cloned()) .collect::>(), diff --git a/src/ops/validate.rs b/src/ops/validate.rs index 87060e7c3..6c6ff65cd 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -146,17 +146,17 @@ impl ValidateOp for super::Conditional { children: impl DoubleEndedIterator, ) -> Result<(), ChildrenValidationError> { let children = children.collect_vec(); - // The first input to the ɣ-node is a predicate of Sum type, + // The first input to the ɣ-node is a value of Sum type, // whose arity matches the number of children of the ɣ-node. - if self.predicate_inputs.len() != children.len() { - return Err(ChildrenValidationError::InvalidConditionalPredicate { + if self.tuple_sum_rows.len() != children.len() { + return Err(ChildrenValidationError::InvalidConditionalTupleSum { child: children[0].0, // Pass an arbitrary child expected_count: children.len(), - actual_predicate_rows: self.predicate_inputs.clone(), + actual_sum_rows: self.tuple_sum_rows.clone(), }); } - // Each child must have its predicate variant's row and the rest of `inputs` as input, + // Each child must have its variant's row and the rest of `inputs` as input, // and matching output for (i, (child, optype)) in children.into_iter().enumerate() { let OpType::Case(case_op) = optype else { @@ -251,13 +251,13 @@ pub enum ChildrenValidationError { /// The signature of a child case in a conditional operation does not match the container's signature. #[error("A conditional case has optype {optype:?}, which differs from the signature of Conditional container")] ConditionalCaseSignature { child: NodeIndex, optype: OpType }, - /// The conditional container's branch predicate does not match the number of children. - #[error("The conditional container's branch predicate input should be a sum with {expected_count} elements, but it had {} elements. Predicate rows: {actual_predicate_rows:?}", - actual_predicate_rows.len())] - InvalidConditionalPredicate { + /// The conditional container's branching value does not match the number of children. + #[error("The conditional container's branch TupleSum input should be a sum with {expected_count} elements, but it had {} elements. TupleSum rows: {actual_sum_rows:?}", + actual_sum_rows.len())] + InvalidConditionalTupleSum { child: NodeIndex, expected_count: usize, - actual_predicate_rows: Vec, + actual_sum_rows: Vec, }, } @@ -269,7 +269,7 @@ impl ChildrenValidationError { ChildrenValidationError::InternalExitChildren { child, .. } => *child, ChildrenValidationError::ConditionalCaseSignature { child, .. } => *child, ChildrenValidationError::IOSignatureMismatch { child, .. } => *child, - ChildrenValidationError::InvalidConditionalPredicate { child, .. } => *child, + ChildrenValidationError::InvalidConditionalTupleSum { child, .. } => *child, } } } @@ -317,14 +317,15 @@ impl ValidateOp for BasicBlock { fn validity_flags(&self) -> OpValidityFlags { match self { BasicBlock::DFB { - predicate_variants, .. + tuple_sum_rows: tuple_sum_variants, + .. } => OpValidityFlags { allowed_children: OpTag::DataflowChild, allowed_first_child: OpTag::Input, allowed_second_child: OpTag::Output, requires_children: true, requires_dag: true, - non_df_ports: (None, Some(predicate_variants.len())), + non_df_ports: (None, Some(tuple_sum_variants.len())), ..Default::default() }, // Default flags are valid for non-container operations @@ -340,12 +341,12 @@ impl ValidateOp for BasicBlock { match self { BasicBlock::DFB { inputs, - predicate_variants, + tuple_sum_rows: tuple_sum_variants, other_outputs: outputs, extension_delta: _, } => { - let predicate_type = Type::new_predicate(predicate_variants.clone()); - let node_outputs: TypeRow = [&[predicate_type], outputs.as_ref()].concat().into(); + let tuple_sum_type = Type::new_tuple_sum(tuple_sum_variants.clone()); + let node_outputs: TypeRow = [&[tuple_sum_type], outputs.as_ref()].concat().into(); validate_io_nodes(inputs, &node_outputs, "basic block graph", children) } // Exit nodes do not have children diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index 351bf9670..978c520c8 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -93,10 +93,10 @@ fn extension() -> Extension { .unwrap(); extension - .add_value(FALSE_NAME, ops::Const::simple_predicate(0, 2)) + .add_value(FALSE_NAME, ops::Const::unit_sum(0, 2)) .unwrap(); extension - .add_value(TRUE_NAME, ops::Const::simple_predicate(1, 2)) + .add_value(TRUE_NAME, ops::Const::unit_sum(1, 2)) .unwrap(); extension } diff --git a/src/types.rs b/src/types.rs index fdc640427..cabdcf39d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -108,11 +108,11 @@ pub(crate) fn least_upper_bound(mut tags: impl Iterator) -> Ty #[serde(tag = "s")] /// Representation of a Sum type. /// Either store the types of the variants, or in the special (but common) case -/// of a "simple predicate" (sum over empty tuples), store only the size of the predicate. +/// of a UnitSum (sum over empty tuples), store only the size of the Sum. pub enum SumType { #[allow(missing_docs)] - #[display(fmt = "SimplePredicate({})", "size")] - Simple { size: u8 }, + #[display(fmt = "UnitSum({})", "size")] + Unit { size: u8 }, #[allow(missing_docs)] General { row: TypeRow }, } @@ -124,7 +124,7 @@ impl SumType { let len: usize = row.len(); if len <= (u8::MAX as usize) && row.iter().all(|t| *t == Type::UNIT) { - Self::Simple { size: len as u8 } + Self::Unit { size: len as u8 } } else { Self::General { row } } @@ -133,7 +133,7 @@ impl SumType { /// Report the tag'th variant, if it exists. pub fn get_variant(&self, tag: usize) -> Option<&Type> { match self { - SumType::Simple { size } if tag < (*size as usize) => Some(Type::UNIT_REF), + SumType::Unit { size } if tag < (*size as usize) => Some(Type::UNIT_REF), SumType::General { row } => row.get(tag), _ => None, } @@ -143,7 +143,7 @@ impl SumType { impl From for Type { fn from(sum: SumType) -> Type { match sum { - SumType::Simple { size } => Type::new_simple_predicate(size), + SumType::Unit { size } => Type::new_unit_sum(size), SumType::General { row } => Type::new_sum(row), } } @@ -166,7 +166,7 @@ impl TypeEnum { fn least_upper_bound(&self) -> TypeBound { match self { TypeEnum::Prim(p) => p.bound(), - TypeEnum::Sum(SumType::Simple { size: _ }) => TypeBound::Eq, + TypeEnum::Sum(SumType::Unit { size: _ }) => TypeBound::Eq, TypeEnum::Sum(SumType::General { row }) => { least_upper_bound(row.iter().map(Type::least_upper_bound)) } @@ -245,19 +245,19 @@ impl Type { Self(type_e, bound) } - /// New Sum of Tuple types, used as predicates in branching. + /// New Sum of Tuple types, used in branching control. /// Tuple rows are defined in order by input rows. - pub fn new_predicate(variant_rows: impl IntoIterator) -> Self + pub fn new_tuple_sum(variant_rows: impl IntoIterator) -> Self where V: Into, { - Self::new_sum(predicate_variants_row(variant_rows)) + Self::new_sum(tuple_sum_row(variant_rows)) } - /// New simple predicate with empty Tuple variants - pub const fn new_simple_predicate(size: u8) -> Self { + /// New UnitSum with empty Tuple variants + pub const fn new_unit_sum(size: u8) -> Self { // should be the only way to avoid going through SumType::new - Self(TypeEnum::Sum(SumType::Simple { size }), TypeBound::Eq) + Self(TypeEnum::Sum(SumType::Unit { size }), TypeBound::Eq) } /// Report the least upper TypeBound, if there is one. @@ -288,7 +288,7 @@ impl Type { TypeEnum::Tuple(row) | TypeEnum::Sum(SumType::General { row }) => { row.iter().try_for_each(|t| t.validate(extension_registry)) } - TypeEnum::Sum(SumType::Simple { .. }) => Ok(()), // No leaves there + TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there TypeEnum::Prim(PrimType::Alias(_)) => Ok(()), TypeEnum::Prim(PrimType::Extension(custy)) => custy.validate(extension_registry), TypeEnum::Prim(PrimType::Function(ft)) => ft @@ -302,7 +302,7 @@ impl Type { /// Return the type row of variants required to define a Sum of Tuples type /// given the rows of each tuple -pub(crate) fn predicate_variants_row(variant_rows: impl IntoIterator) -> TypeRow +pub(crate) fn tuple_sum_row(variant_rows: impl IntoIterator) -> TypeRow where V: Into, { @@ -351,11 +351,11 @@ pub(crate) mod test { #[test] fn sum_construct() { let pred1 = Type::new_sum(type_row![Type::UNIT, Type::UNIT]); - let pred2 = Type::new_simple_predicate(2); + let pred2 = Type::new_unit_sum(2); assert_eq!(pred1, pred2); - let pred_direct = SumType::Simple { size: 2 }; + let pred_direct = SumType::Unit { size: 2 }; assert_eq!(pred1, pred_direct.into()) } } diff --git a/src/types/serialize.rs b/src/types/serialize.rs index f70cc57b0..f55258fc9 100644 --- a/src/types/serialize.rs +++ b/src/types/serialize.rs @@ -80,8 +80,7 @@ mod test { let t = Type::new_sum(vec![USIZE_T, FLOAT64_TYPE]); assert_eq!(ser_roundtrip(&t), t); - // A simple predicate - let t = Type::new_simple_predicate(4); + let t = Type::new_unit_sum(4); assert_eq!(ser_roundtrip(&t), t); } } diff --git a/src/values.rs b/src/values.rs index f9419ede0..987e8a5be 100644 --- a/src/values.rs +++ b/src/values.rs @@ -96,14 +96,14 @@ impl Value { Self::Tuple { vs: vec![] } } - /// Constant Sum over units, used as predicates. - pub fn simple_predicate(tag: usize) -> Self { + /// Constant Sum of a unit value, used to control branches. + pub fn unit_sum(tag: usize) -> Self { Self::sum(tag, Self::unit()) } - /// Constant Sum over Tuples with just one variant of unit type - pub fn simple_unary_predicate() -> Self { - Self::simple_predicate(0) + /// Constant Sum with just one variant of unit type + pub fn unary_unit_sum() -> Self { + Self::unit_sum(0) } /// Tuple of values. @@ -113,7 +113,7 @@ impl Value { } } - /// Sum value (could be of any compatible type, e.g. a predicate) + /// Sum value (could be of any compatible type - e.g., if `value` was a Tuple, a TupleSum type) pub fn sum(tag: usize, value: Value) -> Self { Self::Sum { tag,