From 8613c8643e7da42aa7e03d583540ba80ccd75264 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 18 Jul 2023 16:39:44 +0300 Subject: [PATCH 01/26] feat: Initial work on rewriting closures to regular functions with hidden env This commit implements the following mechanism: On a line where a lambda expression is encountered, we initialize a tuple for the captured lambda environment and we rewrite the lambda to a regular function taking this environment as an additional parameter. All calls to the closure are then modified to insert this hidden parameter. In other words, the following code: ``` let x = some_value; let closure = |a| x + a; println(closure(10)); println(closure(20)); ``` is rewritten to: ``` fn closure(env: (Field,), a: Field) -> Field { env.0 + a } let x = some_value; let closure_env = (x,); println(closure(closure_env, 10)); println(closure(closure_env, 20)); ``` In the presence of nested closures, we propagate the captured variables implicitly through all intermediate closures: ``` let x = some_value; let closure = |a, c| # here, `x` is initialized from the hidden env of the outer closure let inner_closure = |b| a + b + x inner_closure(c) ``` To make these transforms possible, the following changes were made to the logic of the HIR resolver and the monomorphization pass: * In the HIR resolver pass, the code determines the precise list of variables captured by each lambda. Along with the list, we compute the index of each captured var within the parent closure's environment (when the capture is propagated). * Introduction of a new `Closure` type in order to be able to recognize the call-sites that need the automatic environment variable treatment. It's a bit unfortunate that the Closure type is defined within the `AST` modules that are used to describe the output of the monomorphization pass, because we aim to eliminate all closures during the pass. A better solution would have been possible if the type check pass after HIR resolution was outputting types specific to the HIR pass (then the closures would exist only within this separate non-simplified type system). * The majority of the work is in the Lambda processing step in the monomorphizer which performs the necessary transformations based on the above information. Remaining things to do: * There are a number of pending TODO items for various minor unresolved loose ends in the code. * There are a lot of possible additional tests to be written. * Update docs --- .../src/hir/resolution/resolver.rs | 178 ++++++++++++++---- .../noirc_frontend/src/hir/type_check/expr.rs | 83 +++++--- .../noirc_frontend/src/hir/type_check/mod.rs | 23 +++ crates/noirc_frontend/src/hir_def/expr.rs | 16 ++ crates/noirc_frontend/src/hir_def/types.rs | 17 ++ .../src/monomorphization/ast.rs | 56 +++++- .../src/monomorphization/mod.rs | 133 +++++++++++-- crates/noirc_frontend/src/node_interner.rs | 3 +- 8 files changed, 429 insertions(+), 80 deletions(-) diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index 8b4f97dbd8e..e5363311a49 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -12,10 +12,10 @@ // // XXX: Resolver does not check for unused functions use crate::hir_def::expr::{ - HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression, - HirConstructorExpression, HirExpression, HirForExpression, HirIdent, HirIfExpression, - HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess, - HirMethodCallExpression, HirPrefixExpression, + HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCapturedVar, + HirCastExpression, HirConstructorExpression, HirExpression, HirForExpression, HirIdent, + HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, + HirMemberAccess, HirMethodCallExpression, HirPrefixExpression, }; use crate::token::Attribute; use regex::Regex; @@ -58,6 +58,11 @@ type Scope = GenericScope; type ScopeTree = GenericScopeTree; type ScopeForest = GenericScopeForest; +pub struct LambdaContext { + captures: Vec, + scope_index: usize, +} + /// The primary jobs of the Resolver are to validate that every variable found refers to exactly 1 /// definition in scope, and to convert the AST into the HIR. /// @@ -81,12 +86,10 @@ pub struct Resolver<'a> { /// were declared in. generics: Vec<(Rc, TypeVariable, Span)>, - /// Lambdas share the function scope of the function they're defined in, - /// so to identify whether they use any variables from the parent function - /// we keep track of the scope index a variable is declared in. When a lambda - /// is declared we push a scope and set this lambda_index to the scope index. - /// Any variable from a scope less than that must be from the parent function. - lambda_index: usize, + /// When resolving lambda expressions, we need to keep track of the variables + /// that are captured. We do this in order to create the hidden environment + /// parameter for the lambda function. + lambda_stack: Vec, } /// ResolverMetas are tagged onto each definition to track how many times they are used @@ -112,7 +115,7 @@ impl<'a> Resolver<'a> { self_type: None, generics: Vec::new(), errors: Vec::new(), - lambda_index: 0, + lambda_stack: Vec::new(), file, } } @@ -125,10 +128,6 @@ impl<'a> Resolver<'a> { self.errors.push(err); } - fn current_lambda_index(&self) -> usize { - self.scopes.current_scope_index() - } - /// Resolving a function involves interning the metadata /// interning any statements inside of the function /// and interning the function itself @@ -279,25 +278,25 @@ impl<'a> Resolver<'a> { // // If a variable is not found, then an error is logged and a dummy id // is returned, for better error reporting UX - fn find_variable_or_default(&mut self, name: &Ident) -> HirIdent { + fn find_variable_or_default(&mut self, name: &Ident) -> (HirIdent, usize) { self.find_variable(name).unwrap_or_else(|error| { self.push_err(error); let id = DefinitionId::dummy_id(); let location = Location::new(name.span(), self.file); - HirIdent { location, id } + (HirIdent { location, id }, 0) }) } - fn find_variable(&mut self, name: &Ident) -> Result { + fn find_variable(&mut self, name: &Ident) -> Result<(HirIdent, usize), ResolverError> { // Find the definition for this Ident let scope_tree = self.scopes.current_scope_tree(); let variable = scope_tree.find(&name.0.contents); let location = Location::new(name.span(), self.file); - if let Some((variable_found, _)) = variable { + if let Some((variable_found, scope)) = variable { variable_found.num_times_used += 1; let id = variable_found.ident.id; - Ok(HirIdent { location, id }) + Ok((HirIdent { location, id }, scope)) } else { Err(ResolverError::VariableNotDeclared { name: name.0.contents.clone(), @@ -517,24 +516,24 @@ impl<'a> Resolver<'a> { } } - fn get_ident_from_path(&mut self, path: Path) -> HirIdent { + fn get_ident_from_path(&mut self, path: Path) -> (HirIdent, usize) { let location = Location::new(path.span(), self.file); let error = match path.as_ident().map(|ident| self.find_variable(ident)) { - Some(Ok(ident)) => return ident, + Some(Ok(found)) => return found, // Try to look it up as a global, but still issue the first error if we fail Some(Err(error)) => match self.lookup_global(path) { - Ok(id) => return HirIdent { location, id }, + Ok(id) => return (HirIdent { location, id }, 0), Err(_) => error, }, None => match self.lookup_global(path) { - Ok(id) => return HirIdent { location, id }, + Ok(id) => return (HirIdent { location, id }, 0), Err(error) => error, }, }; self.push_err(error); let id = DefinitionId::dummy_id(); - HirIdent { location, id } + (HirIdent { location, id }, 0) } /// Translates an UnresolvedType to a Type @@ -837,12 +836,15 @@ impl<'a> Resolver<'a> { Self::find_numeric_generics_in_type(field, found); } } + Type::Function(parameters, return_type) => { for parameter in parameters { Self::find_numeric_generics_in_type(parameter, found); } Self::find_numeric_generics_in_type(return_type, found); } + Type::Closure(func) => Self::find_numeric_generics_in_type(func, found), + Type::Struct(struct_type, generics) => { for (i, generic) in generics.iter().enumerate() { if let Type::NamedGeneric(type_variable, name) = generic { @@ -915,7 +917,7 @@ impl<'a> Resolver<'a> { fn resolve_lvalue(&mut self, lvalue: LValue) -> HirLValue { match lvalue { LValue::Ident(ident) => { - HirLValue::Ident(self.find_variable_or_default(&ident), Type::Error) + HirLValue::Ident(self.find_variable_or_default(&ident).0, Type::Error) } LValue::MemberAccess { object, field_name } => { let object = Box::new(self.resolve_lvalue(*object)); @@ -965,7 +967,52 @@ impl<'a> Resolver<'a> { // Otherwise, then it is referring to an Identifier // This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10; // If the expression is a singular indent, we search the resolver's current scope as normal. - let hir_ident = self.get_ident_from_path(path); + let (hir_ident, var_scope_index) = self.get_ident_from_path(path); + + if hir_ident.id != DefinitionId::dummy_id() { + match self.interner.definition(hir_ident.id).kind { + DefinitionKind::Function(_) => {} + DefinitionKind::Global(_) => {} + DefinitionKind::GenericType(_) => {} + // We ignore the above definition kinds because only local variables can be captured by closures. + DefinitionKind::Local(_) => { + let mut transitive_capture_index: Option = None; + + for lambda_index in 0..self.lambda_stack.len() { + if self.lambda_stack[lambda_index].scope_index > var_scope_index { + // Beware: the same variable may be captured multiple times, so we check + // for its presence before adding the capture below. + let pos = self.lambda_stack[lambda_index] + .captures + .iter() + .position(|capture| capture.ident.id == hir_ident.id); + + if pos.is_none() { + self.lambda_stack[lambda_index].captures.push( + HirCapturedVar { + ident: hir_ident, + transitive_capture_index, + }, + ); + } + + if lambda_index + 1 < self.lambda_stack.len() { + // There is more than one closure between the current scope and + // the scope of the variable, so this is a propagated capture. + // We need to track the transitive capture index as we go up in + // the closure stack. + transitive_capture_index = Some(pos.unwrap_or( + // If this was a fresh capture, we added it to the end of + // the captures vector: + self.lambda_stack[lambda_index].captures.len() - 1, + )) + } + } + } + } + } + } + HirExpression::Ident(hir_ident) } ExpressionKind::Prefix(prefix) => { @@ -1087,8 +1134,10 @@ impl<'a> Resolver<'a> { // We must stay in the same function scope as the parent function to allow for closures // to capture variables. This is currently limited to immutable variables. ExpressionKind::Lambda(lambda) => self.in_new_scope(|this| { - let new_index = this.current_lambda_index(); - let old_index = std::mem::replace(&mut this.lambda_index, new_index); + let scope_index = this.scopes.current_scope_index(); + + this.lambda_stack + .push(LambdaContext { captures: Vec::new(), scope_index: scope_index }); let parameters = vecmap(lambda.parameters, |(pattern, typ)| { let parameter = DefinitionKind::Local(None); @@ -1098,8 +1147,14 @@ impl<'a> Resolver<'a> { let return_type = this.resolve_inferred_type(lambda.return_type); let body = this.resolve_expression(lambda.body); - this.lambda_index = old_index; - HirExpression::Lambda(HirLambda { parameters, return_type, body }) + let lambda_context = this.lambda_stack.pop().unwrap(); + + HirExpression::Lambda(HirLambda { + parameters, + return_type, + body, + captures: lambda_context.captures, + }) }), }; @@ -1411,6 +1466,7 @@ pub fn verify_mutable_reference(interner: &NodeInterner, rhs: ExprId) -> Result< #[cfg(test)] mod test { + use core::panic; use std::collections::HashMap; use fm::FileId; @@ -1434,7 +1490,9 @@ mod test { // and functions can be forward declared fn resolve_src_code(src: &str, func_namespace: Vec<&str>) -> Vec { let (program, errors) = parse_program(src); - assert!(errors.is_empty()); + if !errors.is_empty() { + panic!("Unexpected parse errors in test code: {:?}", errors); + } let mut interner = NodeInterner::default(); @@ -1656,9 +1714,61 @@ mod test { x } "#; + let errors = resolve_src_code(src, vec!["main", "foo"]); + if !errors.is_empty() { + println!("Unexpected errors: {:?}", errors); + assert!(false); // there should be no errors + } + } + + fn resolve_basic_closure() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = |y| y + x; + closure(x) + } + "#; + + let errors = resolve_src_code(src, vec!["main", "foo"]); + if !errors.is_empty() { + println!("Unexpected errors: {:?}", errors); + assert!(false); // there should be no errors + } + } + + #[test] + fn resolve_complex_closures() { + let src = r#" + fn main(x: Field) -> pub Field { + let closure_without_captures = |x| x + x; + let a = closure_without_captures(1); + + let closure_capturing_a_param = |y| y + x; + let b = closure_capturing_a_param(2); + + let closure_capturing_a_local_var = |y| y + b; + let c = closure_capturing_a_local_var(3); + + let closure_with_transitive_captures = |y| { + let d = 5; + let nested_closure = |z| { + let doubly_nested_closure = |w| w + x + b; + a + z + y + d + x + doubly_nested_closure(4) + x + y + }; + let res = nested_closure(5); + res + }; + + a + b + c + closure_with_transitive_captures(6) + } + "#; let errors = resolve_src_code(src, vec!["main", "foo"]); assert!(errors.is_empty()); + if !errors.is_empty() { + println!("Unexpected errors: {:?}", errors); + assert!(false); // there should be no errors + } } #[test] @@ -1694,6 +1804,10 @@ mod test { } } + // TODO: Create a more sophisticated set of search functions over the HIR, so we can check + // that the correct variables are captured in each closure + + fn path_unresolved_error(err: ResolverError, expected_unresolved_path: &str) { match err { ResolverError::PathResolutionError(PathResolutionError::Unresolved(name)) => { diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 24ac5f3443e..c3106b2b56a 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -279,11 +279,19 @@ impl<'interner> TypeChecker<'interner> { Type::Tuple(vecmap(&elements, |elem| self.check_expression(elem))) } HirExpression::Lambda(lambda) => { - let params = vecmap(lambda.parameters, |(pattern, typ)| { - self.bind_pattern(&pattern, typ.clone()); + let captured_vars = vecmap(lambda.captures, |capture| { + let typ = self.interner.id_type(capture.ident.id); typ }); + let env_type = Type::Tuple(captured_vars); + let mut params = vec![env_type]; + + for (pattern, typ) in lambda.parameters { + self.bind_pattern(&pattern, typ.clone()); + params.push(typ); + } + let actual_return = self.check_expression(&lambda.body); let span = self.interner.expr_span(&lambda.body); @@ -294,7 +302,9 @@ impl<'interner> TypeChecker<'interner> { expr_span: span, } }); - Type::Function(params, Box::new(lambda.return_type)) + + let function_type = Type::Function(params, Box::new(lambda.return_type)); + Type::Closure(Box::new(function_type)) } }; @@ -870,11 +880,43 @@ impl<'interner> TypeChecker<'interner> { } } + fn bind_function_type_impl( + &mut self, + fn_params: &Vec, + fn_ret: &Type, + callsite_args: &Vec<(Type, ExprId, Span)>, + span: Span, + skip_params: usize, + ) -> Type { + let real_fn_params_count = fn_params.len() - skip_params; + + if real_fn_params_count != callsite_args.len() { + self.errors.push(TypeCheckError::ParameterCountMismatch { + expected: real_fn_params_count, + found: callsite_args.len(), + span: span + }); + return Type::Error; + } + + for (param, (arg, _, arg_span)) in fn_params.iter().skip(skip_params).zip(callsite_args) { + arg.make_subtype_of(param, *arg_span, &mut self.errors, || { + TypeCheckError::TypeMismatch { + expected_typ: param.to_string(), + expr_typ: arg.to_string(), + expr_span: *arg_span, + } + }); + } + + fn_ret.clone() + } + fn bind_function_type( &mut self, function: Type, args: Vec<(Type, ExprId, Span)>, - span: Span, + span: Span ) -> Type { // Could do a single unification for the entire function type, but matching beforehand // lets us issue a more precise error on the individual argument that fails to type check. @@ -894,31 +936,14 @@ impl<'interner> TypeChecker<'interner> { ret } Type::Function(parameters, ret) => { - if parameters.len() != args.len() { - self.errors.push(TypeCheckError::ParameterCountMismatch { - expected: parameters.len(), - found: args.len(), - span, - }); - return Type::Error; - } - - for (param, (arg, arg_id, arg_span)) in parameters.iter().zip(args) { - arg.make_subtype_with_coercions( - param, - arg_id, - self.interner, - &mut self.errors, - || TypeCheckError::TypeMismatch { - expected_typ: param.to_string(), - expr_typ: arg.to_string(), - expr_span: arg_span, - }, - ); - } - - *ret - } + self.bind_function_type_impl( + parameters.as_ref(), + ret.as_ref(), + args.as_ref(), + span, + 0, + ) + }, Type::Error => Type::Error, found => { self.errors.push(TypeCheckError::ExpectedFunction { found, span }); diff --git a/crates/noirc_frontend/src/hir/type_check/mod.rs b/crates/noirc_frontend/src/hir/type_check/mod.rs index 26d0e36abf9..9ab581cddca 100644 --- a/crates/noirc_frontend/src/hir/type_check/mod.rs +++ b/crates/noirc_frontend/src/hir/type_check/mod.rs @@ -152,6 +152,7 @@ impl<'interner> TypeChecker<'interner> { #[cfg(test)] mod test { use std::collections::HashMap; + use std::vec; use fm::FileId; use iter_extended::vecmap; @@ -314,7 +315,29 @@ mod test { type_check_src_code(src, vec![String::from("main"), String::from("foo")]); } + #[test] + fn basic_closure() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = |y| y + x; + closure(x) + } + "#; + + type_check_src_code(src, vec![String::from("main"), String::from("foo")]); + } + #[test] + fn closure_with_no_args() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = || x; + closure() + } + "#; + + type_check_src_code(src, vec![String::from("main")]); + } // This is the same Stub that is in the resolver, maybe we can pull this out into a test module and re-use? struct TestPathResolver(HashMap); diff --git a/crates/noirc_frontend/src/hir_def/expr.rs b/crates/noirc_frontend/src/hir_def/expr.rs index db7db0a803d..fd980328f5f 100644 --- a/crates/noirc_frontend/src/hir_def/expr.rs +++ b/crates/noirc_frontend/src/hir_def/expr.rs @@ -197,9 +197,25 @@ impl HirBlockExpression { } } +/// A variable captured inside a closure +#[derive(Debug, Clone)] +pub struct HirCapturedVar { + pub ident: HirIdent, + + /// This will be None when the capture refers to a local variable declared + /// in the same scope as the closure. In a closure-inside-another-closure + /// scenarios, we might have a transitive captures of variables that must + /// be propagated during the construction of each closure. In this case, + /// we store the index of the captured variable in the environment of our + /// direct parent closure. We do this in order to simplify the HIR to AST + /// transformation in the monomorphization pass. + pub transitive_capture_index: Option, +} + #[derive(Debug, Clone)] pub struct HirLambda { pub parameters: Vec<(HirPattern, Type)>, pub return_type: Type, pub body: ExprId, + pub captures: Vec, } diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index ff0a4e53fae..76c41a2c86f 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -73,6 +73,11 @@ pub enum Type { /// A functions with arguments, and a return type. Function(Vec, Box), + /// A closure (a pair of a function pointer and a tuple of captured variables). + /// Stores the underlying function type, which has been modifies such that the + /// first parameter is the type of the captured variables tuple. + Closure(Box), + /// &mut T MutableReference(Box), @@ -701,6 +706,7 @@ impl Type { parameters.iter().any(|parameter| parameter.contains_numeric_typevar(target_id)) || return_type.contains_numeric_typevar(target_id) } + Type::Closure(func) => func.contains_numeric_typevar(target_id), Type::Struct(struct_type, generics) => { generics.iter().enumerate().any(|(i, generic)| { if named_generic_id_matches_target(generic) { @@ -801,6 +807,9 @@ impl std::fmt::Display for Type { let args = vecmap(args, ToString::to_string); write!(f, "fn({}) -> {}", args.join(", "), ret) } + Type::Closure(func) => { + write!(f, "closure {}", func) // i.e. we produce a string such as "closure fn(args) -> ret" + } Type::MutableReference(element) => { write!(f, "&mut {element}") } @@ -1506,6 +1515,7 @@ impl Type { Type::NamedGeneric(..) => unreachable!(), Type::Forall(..) => unreachable!(), Type::Function(_, _) => unreachable!(), + Type::Closure(_) => unreachable!(), Type::MutableReference(_) => unreachable!("&mut cannot be used in the abi"), Type::NotConstant => unreachable!(), } @@ -1625,6 +1635,10 @@ impl Type { let ret = Box::new(ret.substitute(type_bindings)); Type::Function(args, ret) } + Type::Closure(func) => { + let func = Box::new(func.substitute(type_bindings)); + Type::Closure(func) + } Type::MutableReference(element) => { Type::MutableReference(Box::new(element.substitute(type_bindings))) } @@ -1663,6 +1677,7 @@ impl Type { Type::Function(args, ret) => { args.iter().any(|arg| arg.occurs(target_id)) || ret.occurs(target_id) } + Type::Closure(func) => func.occurs(target_id), Type::MutableReference(element) => element.occurs(target_id), Type::FieldElement(_) @@ -1711,6 +1726,8 @@ impl Type { let ret = Box::new(ret.follow_bindings()); Function(args, ret) } + Closure(func) => Closure(Box::new(func.follow_bindings())), + MutableReference(element) => MutableReference(Box::new(element.follow_bindings())), // Expect that this function should only be called on instantiated types diff --git a/crates/noirc_frontend/src/monomorphization/ast.rs b/crates/noirc_frontend/src/monomorphization/ast.rs index 7ad05f09231..c017d1d9102 100644 --- a/crates/noirc_frontend/src/monomorphization/ast.rs +++ b/crates/noirc_frontend/src/monomorphization/ast.rs @@ -29,7 +29,6 @@ pub enum Expression { Tuple(Vec), ExtractTupleField(Box, usize), Call(Call), - Let(Let), Constrain(Box, Location), Assign(Assign), @@ -103,6 +102,13 @@ pub struct Binary { pub location: Location, } +#[derive(Debug, Clone)] +pub struct Lambda { + pub function: Ident, + pub env: Ident, + pub typ: Type, // TODO: Perhaps this is not necessary +} + #[derive(Debug, Clone)] pub struct If { pub condition: Box, @@ -225,6 +231,54 @@ impl Type { } } +pub fn type_of_lvalue(lvalue: &LValue) -> Type { + match lvalue { + LValue::Ident(ident) => ident.typ.clone(), + LValue::Index { element_type, .. } => element_type.clone(), + LValue::MemberAccess { object, field_index } => { + let tuple_type = type_of_lvalue(object.as_ref()); + match tuple_type { + Type::Tuple(fields) => fields[*field_index].clone(), + _ => unreachable!("ICE: Member access on non-tuple type"), + } + } + LValue::Dereference { element_type, .. } => element_type.clone(), + } +} + +pub fn type_of(expr: &Expression) -> Type { + match expr { + Expression::Ident(ident) => ident.typ.clone(), + Expression::Literal(lit) => match lit { + Literal::Integer(_, typ) => typ.clone(), + Literal::Bool(_) => Type::Bool, + Literal::Str(str) => Type::String(str.len() as u64), + Literal::Array(array) => { + // TODO + Type::Array(array.contents.len() as u64, Box::new(Type::Unit)) + }, + Literal::FmtStr(_, _, _) => unimplemented!() + }, + Expression::Block(stmts) => type_of(stmts.last().unwrap()), + Expression::Unary(unary) => unary.result_type.clone(), + Expression::Binary(_binary) => unreachable!("TODO: How do we get the type of a Binary op"), + Expression::Index(index) => index.element_type.clone(), + Expression::Cast(cast) => cast.r#type.clone(), + Expression::For(_for_expr) => unreachable!("TODO: How do we get the type of a for loop?"), + Expression::If(if_expr) => if_expr.typ.clone(), + Expression::Tuple(elements) => Type::Tuple(elements.iter().map(type_of).collect()), + Expression::ExtractTupleField(tuple, index) => match tuple.as_ref() { + Expression::Tuple(fields) => type_of(&fields[*index]), + _ => unreachable!("ICE: Tuple field access on non-tuple type"), + }, + Expression::Call(call) => call.return_type.clone(), + Expression::Let(let_stmt) => type_of(let_stmt.expression.as_ref()), + Expression::Constrain(contraint, _) => type_of(contraint.as_ref()), + Expression::Assign(assign) => type_of_lvalue(&assign.lvalue), + Expression::Semi(expr) => type_of(expr.as_ref()), // TODO: Is this correct? + } +} + #[derive(Debug, Clone)] pub struct Program { pub functions: Vec, diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index dbe2ee080bf..19508a26ed2 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -30,6 +30,11 @@ use self::ast::{Definition, FuncId, Function, LocalId, Program}; pub mod ast; pub mod printer; +struct LambdaContext { + env_ident: Box, + captures: Vec, +} + /// The context struct for the monomorphization pass. /// /// This struct holds the FIFO queue of functions to monomorphize, which is added to @@ -58,6 +63,8 @@ struct Monomorphizer<'interner> { /// Used to reference existing definitions in the HIR interner: &'interner NodeInterner, + lambda_envs_stack: Vec, + next_local_id: u32, next_function_id: u32, } @@ -103,6 +110,7 @@ impl<'interner> Monomorphizer<'interner> { next_local_id: 0, next_function_id: 0, interner, + lambda_envs_stack: Vec::new(), } } @@ -541,6 +549,15 @@ impl<'interner> Monomorphizer<'interner> { ast::Expression::Block(definitions) } + /// Find a captured variable in the innermost closure + fn lookup_captured(&mut self, id: node_interner::DefinitionId) -> Option { + let ctx = self.lambda_envs_stack.last()?; + ctx.captures + .iter() + .position(|capture| capture.ident.id == id) + .map(|index| ast::Expression::ExtractTupleField(ctx.env_ident.clone(), index)) + } + /// A local (ie non-global) ident only fn local_ident(&mut self, ident: &HirIdent) -> Option { let definition = self.interner.definition(ident.id); @@ -568,10 +585,10 @@ impl<'interner> Monomorphizer<'interner> { ast::Expression::Ident(ident) } DefinitionKind::Global(expr_id) => self.expr(*expr_id), - DefinitionKind::Local(_) => { + DefinitionKind::Local(_) => self.lookup_captured(ident.id).unwrap_or_else(|| { let ident = self.local_ident(&ident).unwrap(); ast::Expression::Ident(ident) - } + }), DefinitionKind::GenericType(type_variable) => { let value = match &*type_variable.borrow() { TypeBinding::Unbound(_) => { @@ -663,6 +680,19 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::Function(args, ret) } + HirType::Closure(func) => { + match func.as_ref() { + HirType::Function(arguments, return_type) => { + let converted_args = vecmap(arguments, Self::convert_type); + let converted_ret = Box::new(Self::convert_type(&return_type)); + let fn_type = ast::Type::Function(converted_args, converted_ret); + let env_type = ast::Type::Tuple(vec![]); // TODO compute this + ast::Type::Tuple(vec![env_type, fn_type]) + } + _ => unreachable!("Unexpected closure type {}", func), + } + } + HirType::MutableReference(element) => { let element = Self::convert_type(element); ast::Type::MutableReference(Box::new(element)) @@ -677,19 +707,23 @@ impl<'interner> Monomorphizer<'interner> { } } + fn is_function_closure(&self, func: &ast::Expression) -> bool { + matches!(ast::type_of(func), ast::Type::Tuple(_)) + } + fn function_call( &mut self, call: HirCallExpression, id: node_interner::ExprId, ) -> ast::Expression { - let func = Box::new(self.expr(call.func)); + let original_func = Box::new(self.expr(call.func)); let mut arguments = vecmap(&call.arguments, |id| self.expr(*id)); let hir_arguments = vecmap(&call.arguments, |id| self.interner.expression(id)); let return_type = self.interner.id_type(id); let return_type = Self::convert_type(&return_type); let location = call.location; - if let ast::Expression::Ident(ident) = func.as_ref() { + if let ast::Expression::Ident(ident) = original_func.as_ref() { if let Definition::Oracle(name) = &ident.definition { if name.as_str() == "println" { // Oracle calls are required to be wrapped in an unconstrained function @@ -699,6 +733,19 @@ impl<'interner> Monomorphizer<'interner> { } } + let is_closure = self.is_function_closure(&*original_func); + + let func = if is_closure { + Box::new(ast::Expression::ExtractTupleField(Box::new((*original_func).clone()), 1usize)) + } else { + original_func.clone() + }; + + if is_closure { + let env_argument = + ast::Expression::ExtractTupleField(Box::new((*original_func).clone()), 0usize); + arguments.insert(0, env_argument); + } self.try_evaluate_call(&func, &return_type).unwrap_or(ast::Expression::Call(ast::Call { func, arguments, @@ -924,27 +971,79 @@ impl<'interner> Monomorphizer<'interner> { Param(pattern, typ, noirc_abi::AbiVisibility::Private) })); - let parameters = self.parameters(parameters); - let body = self.expr(lambda.body); + let converted_parameters = self.parameters(parameters); let id = self.next_function_id(); - let return_type = ret_type.clone(); let name = lambda_name.to_owned(); - let unconstrained = false; + let return_type = ret_type.clone(); - let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; - self.push_function(id, function); + let env_local_id = self.next_local_id(); + let env_name = "env"; + let env_tuple = ast::Expression::Tuple(vecmap(&lambda.captures, |capture| { + match capture.transitive_capture_index { + Some(field_index) => match self.lambda_envs_stack.last() { + Some(lambda_ctx) => ast::Expression::ExtractTupleField( + lambda_ctx.env_ident.clone(), + field_index, + ), + None => unreachable!( + "Expected to find a parent closure environment, but found none" + ), + }, + None => { + let ident = self.local_ident(&capture.ident).unwrap(); + ast::Expression::Ident(ident) + } + } + })); + let env_typ = ast::type_of(&env_tuple); - let typ = ast::Type::Function(parameter_types, Box::new(ret_type)); + let env_let_stmt = ast::Expression::Let(ast::Let { + id: env_local_id, + mutable: true, + name: env_name.to_string(), + expression: Box::new(env_tuple), + }); - let name = lambda_name.to_owned(); - ast::Expression::Ident(ast::Ident { + let location = None; // TODO: This should match the location of the lambda expression + let mutable = false; + let definition = Definition::Local(env_local_id); + + let env_ident = ast::Expression::Ident(ast::Ident { + location, + mutable, + definition, + name: env_name.to_string(), + typ: env_typ.clone(), + }); + + // TODO: Is this costly? Can we avoid the copies somehow? + self.lambda_envs_stack.push(LambdaContext { + env_ident: Box::new(env_ident.clone()), + captures: lambda.captures, + }); + let body = self.expr(lambda.body); + self.lambda_envs_stack.pop(); + + let lambda_fn_typ: ast::Type = ast::Type::Function(parameter_types, Box::new(ret_type)); + let lambda_fn = ast::Expression::Ident(ast::Ident { definition: Definition::Function(id), mutable: false, - location: None, - name, - typ, - }) + location: None, // TODO: This should match the location of the lambda expression + name: name.clone(), + typ: lambda_fn_typ, + }); + + let mut parameters = vec![]; + parameters.push((env_local_id, true, env_name.to_string(), env_typ)); + parameters.extend(converted_parameters); + + let unconstrained = false; + let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; + self.push_function(id, function); + + let lambda_value = ast::Expression::Tuple(vec![env_ident, lambda_fn]); + ast::Expression::Block(vec![env_let_stmt, lambda_value]) } /// Implements std::unsafe::zeroed by returning an appropriate zeroed diff --git a/crates/noirc_frontend/src/node_interner.rs b/crates/noirc_frontend/src/node_interner.rs index f5fea5c1ea7..c6ff1d98e3f 100644 --- a/crates/noirc_frontend/src/node_interner.rs +++ b/crates/noirc_frontend/src/node_interner.rs @@ -683,6 +683,7 @@ fn get_type_method_key(typ: &Type) -> Option { | Type::Error | Type::NotConstant | Type::Struct(_, _) - | Type::FmtString(_, _) => None, + | Type::FmtString(_, _) + | Type::Closure(_) => None, // TODO: Is this correct? How do we add methods to functions? Can we do the same for closures? } } From cca011145e6aed2899be4dfb207dc080f23f2ab7 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 20 Jul 2023 16:42:53 +0300 Subject: [PATCH 02/26] refactor: use panic, instead of println+assert Co-authored-by: jfecher --- crates/noirc_frontend/src/hir/resolution/resolver.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index e5363311a49..3fe4669f1b6 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -1731,8 +1731,7 @@ mod test { let errors = resolve_src_code(src, vec!["main", "foo"]); if !errors.is_empty() { - println!("Unexpected errors: {:?}", errors); - assert!(false); // there should be no errors + panic!("Unexpected errors: {:?}", errors); } } From e586d9a095c2f8048b018c4e6b326303e294c82f Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 19 Jul 2023 20:47:45 +0300 Subject: [PATCH 03/26] test: add an initial monomorphization rewrite test a lot of the machinery is copied from similar existing tests the original authors also note some of those can be refactored in something reusable --- .../src/monomorphization/mod.rs | 158 ++++++++++++++++++ 1 file changed, 158 insertions(+) diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 19508a26ed2..592426ecb95 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -1171,3 +1171,161 @@ fn undo_instantiation_bindings(bindings: TypeBindings) { *var.borrow_mut() = TypeBinding::Unbound(id); } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use fm::FileId; + use iter_extended::vecmap; + + use crate::{ + graph::CrateId, + hir::{ + def_map::{ + CrateDefMap, LocalModuleId, ModuleData, ModuleDefId, ModuleId, ModuleOrigin, + }, + resolution::{ + import::PathResolutionError, path_resolver::PathResolver, resolver::Resolver, + }, + }, + hir_def::function::HirFunction, + node_interner::{FuncId, NodeInterner}, + parse_program, + }; + + use super::monomorphize; + + // TODO: refactor into a more general test utility? + // mostly copied from hir / type_check / mod.rs and adapted a bit + fn type_check_src_code(src: &str, func_namespace: Vec) -> (FuncId, NodeInterner) { + let (program, errors) = parse_program(src); + let mut interner = NodeInterner::default(); + + // Using assert_eq here instead of assert(errors.is_empty()) displays + // the whole vec if the assert fails rather than just two booleans + assert_eq!(errors, vec![]); + + let main_id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition("main".into(), main_id); + + let func_ids = vecmap(&func_namespace, |name| { + let id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition(name.into(), id); + id + }); + + let mut path_resolver = TestPathResolver(HashMap::new()); + for (name, id) in func_namespace.into_iter().zip(func_ids.clone()) { + path_resolver.insert_func(name.to_owned(), id); + } + + let mut def_maps: HashMap = HashMap::new(); + let file = FileId::default(); + + let mut modules = arena::Arena::new(); + modules.insert(ModuleData::new(None, ModuleOrigin::File(file), false)); + + def_maps.insert( + CrateId::dummy_id(), + CrateDefMap { + root: path_resolver.local_module_id(), + modules, + krate: CrateId::dummy_id(), + extern_prelude: HashMap::new(), + }, + ); + + let func_meta = vecmap(program.functions, |nf| { + let resolver = Resolver::new(&mut interner, &path_resolver, &def_maps, file); + let (hir_func, func_meta, _resolver_errors) = resolver.resolve_function(nf, main_id); + // TODO: not sure why, we do get an error here, + // but otherwise seem to get an ok monomorphization result + // assert_eq!(resolver_errors, vec![]); + (hir_func, func_meta) + }); + + println!("Before update_fn"); + + for ((hir_func, meta), func_id) in func_meta.into_iter().zip(func_ids.clone()) { + interner.update_fn(func_id, hir_func); + interner.push_fn_meta(meta, func_id); + } + + println!("Before type_check_func"); + + // Type check section + let errors = crate::hir::type_check::type_check_func( + &mut interner, + func_ids.first().cloned().unwrap(), + ); + assert_eq!(errors, vec![]); + (func_ids.first().cloned().unwrap(), interner) + } + + // TODO: refactor into a more general test utility? + // TestPathResolver struct and impls copied from hir / type_check / mod.rs + struct TestPathResolver(HashMap); + + impl PathResolver for TestPathResolver { + fn resolve( + &self, + _def_maps: &HashMap, + path: crate::Path, + ) -> Result { + // Not here that foo::bar and hello::foo::bar would fetch the same thing + let name = path.segments.last().unwrap(); + let mod_def = self.0.get(&name.0.contents).cloned(); + mod_def.ok_or_else(move || PathResolutionError::Unresolved(name.clone())) + } + + fn local_module_id(&self) -> LocalModuleId { + // This is not LocalModuleId::dummy since we need to use this to index into a Vec + // later and do not want to push u32::MAX number of elements before we do. + LocalModuleId(arena::Index::from_raw_parts(0, 0)) + } + + fn module_id(&self) -> ModuleId { + ModuleId { krate: CrateId::dummy_id(), local_id: self.local_module_id() } + } + } + + impl TestPathResolver { + fn insert_func(&mut self, name: String, func_id: FuncId) { + self.0.insert(name, func_id.into()); + } + } + + // a helper test method + // TODO: maybe just compare trimmed src/expected + // for easier formatting? + fn check_rewrite(src: &str, expected: &str) { + let (func, interner) = type_check_src_code(src, vec!["main".to_string()]); + let program = monomorphize(func, &interner); + // println!("[{}]", program); + assert!(format!("{}", program) == expected); + } + + #[test] + fn simple_closure_with_no_captured_variables() { + let src = r#" + fn main() -> Field { + let closure = |x| x; + closure(0) + } + "#; + + let expected_rewrite = r#"fn main$f0() -> Field { + let closure$2 = { + let env$1 = (); + (env$l1, lambda$f1) + }; + closure$l2.1(closure$l2.0, 0) +} +fn lambda$f1(mut env$l1: (), x$l0: Field) -> Field { + x$l0 +} +"#; + check_rewrite(src, expected_rewrite); + } +} From e642e71e318f682c90b03e3005a99d0c1d28dbcc Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 20 Jul 2023 17:00:33 +0300 Subject: [PATCH 04/26] fix: address some PR comments: comment/refactor/small fixes --- .../src/hir/resolution/resolver.rs | 69 ++++++++++--------- .../src/monomorphization/ast.rs | 4 +- .../src/monomorphization/mod.rs | 7 +- 3 files changed, 41 insertions(+), 39 deletions(-) diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index 3fe4669f1b6..28338597730 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -60,6 +60,8 @@ type ScopeForest = GenericScopeForest; pub struct LambdaContext { captures: Vec, + /// the index in the scope tree + /// (sometimes being filled by ScopeTree's find method) scope_index: usize, } @@ -935,6 +937,39 @@ impl<'a> Resolver<'a> { } } + fn resolve_local_variable(&mut self, hir_ident: HirIdent, var_scope_index: usize) { + let mut transitive_capture_index: Option = None; + + for lambda_index in 0..self.lambda_stack.len() { + if self.lambda_stack[lambda_index].scope_index > var_scope_index { + // Beware: the same variable may be captured multiple times, so we check + // for its presence before adding the capture below. + let pos = self.lambda_stack[lambda_index] + .captures + .iter() + .position(|capture| capture.ident.id == hir_ident.id); + + if pos.is_none() { + self.lambda_stack[lambda_index] + .captures + .push(HirCapturedVar { ident: hir_ident, transitive_capture_index }); + } + + if lambda_index + 1 < self.lambda_stack.len() { + // There is more than one closure between the current scope and + // the scope of the variable, so this is a propagated capture. + // We need to track the transitive capture index as we go up in + // the closure stack. + transitive_capture_index = Some(pos.unwrap_or( + // If this was a fresh capture, we added it to the end of + // the captures vector: + self.lambda_stack[lambda_index].captures.len() - 1, + )) + } + } + } + } + pub fn resolve_expression(&mut self, expr: Expression) -> ExprId { let hir_expr = match expr.kind { ExpressionKind::Literal(literal) => HirExpression::Literal(match literal { @@ -976,39 +1011,7 @@ impl<'a> Resolver<'a> { DefinitionKind::GenericType(_) => {} // We ignore the above definition kinds because only local variables can be captured by closures. DefinitionKind::Local(_) => { - let mut transitive_capture_index: Option = None; - - for lambda_index in 0..self.lambda_stack.len() { - if self.lambda_stack[lambda_index].scope_index > var_scope_index { - // Beware: the same variable may be captured multiple times, so we check - // for its presence before adding the capture below. - let pos = self.lambda_stack[lambda_index] - .captures - .iter() - .position(|capture| capture.ident.id == hir_ident.id); - - if pos.is_none() { - self.lambda_stack[lambda_index].captures.push( - HirCapturedVar { - ident: hir_ident, - transitive_capture_index, - }, - ); - } - - if lambda_index + 1 < self.lambda_stack.len() { - // There is more than one closure between the current scope and - // the scope of the variable, so this is a propagated capture. - // We need to track the transitive capture index as we go up in - // the closure stack. - transitive_capture_index = Some(pos.unwrap_or( - // If this was a fresh capture, we added it to the end of - // the captures vector: - self.lambda_stack[lambda_index].captures.len() - 1, - )) - } - } - } + self.resolve_local_variable(hir_ident, var_scope_index); } } } diff --git a/crates/noirc_frontend/src/monomorphization/ast.rs b/crates/noirc_frontend/src/monomorphization/ast.rs index c017d1d9102..95b95c38525 100644 --- a/crates/noirc_frontend/src/monomorphization/ast.rs +++ b/crates/noirc_frontend/src/monomorphization/ast.rs @@ -264,7 +264,7 @@ pub fn type_of(expr: &Expression) -> Type { Expression::Binary(_binary) => unreachable!("TODO: How do we get the type of a Binary op"), Expression::Index(index) => index.element_type.clone(), Expression::Cast(cast) => cast.r#type.clone(), - Expression::For(_for_expr) => unreachable!("TODO: How do we get the type of a for loop?"), + Expression::For(_for_expr) => Type::Unit, Expression::If(if_expr) => if_expr.typ.clone(), Expression::Tuple(elements) => Type::Tuple(elements.iter().map(type_of).collect()), Expression::ExtractTupleField(tuple, index) => match tuple.as_ref() { @@ -275,7 +275,7 @@ pub fn type_of(expr: &Expression) -> Type { Expression::Let(let_stmt) => type_of(let_stmt.expression.as_ref()), Expression::Constrain(contraint, _) => type_of(contraint.as_ref()), Expression::Assign(assign) => type_of_lvalue(&assign.lvalue), - Expression::Semi(expr) => type_of(expr.as_ref()), // TODO: Is this correct? + Expression::Semi(_expr) => Type::Unit, } } diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 592426ecb95..3b8901b079b 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -971,7 +971,7 @@ impl<'interner> Monomorphizer<'interner> { Param(pattern, typ, noirc_abi::AbiVisibility::Private) })); - let converted_parameters = self.parameters(parameters); + let mut converted_parameters = self.parameters(parameters); let id = self.next_function_id(); let name = lambda_name.to_owned(); @@ -1000,7 +1000,7 @@ impl<'interner> Monomorphizer<'interner> { let env_let_stmt = ast::Expression::Let(ast::Let { id: env_local_id, - mutable: true, + mutable: false, name: env_name.to_string(), expression: Box::new(env_tuple), }); @@ -1017,7 +1017,6 @@ impl<'interner> Monomorphizer<'interner> { typ: env_typ.clone(), }); - // TODO: Is this costly? Can we avoid the copies somehow? self.lambda_envs_stack.push(LambdaContext { env_ident: Box::new(env_ident.clone()), captures: lambda.captures, @@ -1036,7 +1035,7 @@ impl<'interner> Monomorphizer<'interner> { let mut parameters = vec![]; parameters.push((env_local_id, true, env_name.to_string(), env_typ)); - parameters.extend(converted_parameters); + parameters.append(&mut converted_parameters); let unconstrained = false; let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; From 77794e8e1c9c20f01750509c22d2e1bbe8625ba2 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 27 Jul 2023 22:01:06 +0300 Subject: [PATCH 05/26] fix: use an unified Function object, fix some problems, comments --- .../closures_mut_ref/Nargo.toml | 6 + .../closures_mut_ref/Prover.toml | 1 + .../closures_mut_ref/src/main.nr | 19 ++ .../fibonacci_by_ref/Nargo.toml | 6 + .../fibonacci_by_ref/Prover.toml | 2 + .../fibonacci_by_ref/src/main.nr | 15 ++ .../higher_order_fn_selector/Nargo.toml | 6 + .../higher_order_fn_selector/src/main.nr | 49 ++++++ .../higher_order_functions/Nargo.toml | 6 + .../higher_order_functions/Prover.toml | 0 .../higher_order_functions/src/main.nr | 88 ++++++++++ .../higher_order_functions/target/c.json | 1 + .../higher_order_functions/target/main.json | 1 + .../higher_order_functions/target/witness.tr | Bin 0 -> 112 bytes .../inner_outer_cl/Nargo.toml | 6 + .../inner_outer_cl/src/main.nr | 10 ++ .../ret_fn_ret_cl/Nargo.toml | 6 + .../ret_fn_ret_cl/Prover.toml | 1 + .../ret_fn_ret_cl/src/main.nr | 29 ++++ .../src/ssa/ssa_gen/context.rs | 2 +- crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs | 3 +- .../src/hir/resolution/resolver.rs | 155 ++++++++++++++--- .../noirc_frontend/src/hir/type_check/expr.rs | 39 +++-- .../noirc_frontend/src/hir/type_check/mod.rs | 6 +- crates/noirc_frontend/src/hir_def/function.rs | 4 +- crates/noirc_frontend/src/hir_def/types.rs | 67 ++++--- .../src/monomorphization/ast.rs | 94 +++++----- .../src/monomorphization/mod.rs | 163 +++++++++++++----- crates/noirc_frontend/src/node_interner.rs | 5 +- 29 files changed, 626 insertions(+), 164 deletions(-) create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Prover.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/src/main.nr create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Prover.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/src/main.nr create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/src/main.nr create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/Prover.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/src/main.nr create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/c.json create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/main.json create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/witness.tr create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/inner_outer_cl/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/inner_outer_cl/src/main.nr create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/ret_fn_ret_cl/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/ret_fn_ret_cl/Prover.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/ret_fn_ret_cl/src/main.nr diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Nargo.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Nargo.toml new file mode 100644 index 00000000000..c829bb160b1 --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "closures_mut_ref" +authors = [""] +compiler_version = "0.8.0" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Prover.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Prover.toml new file mode 100644 index 00000000000..11497a473bc --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Prover.toml @@ -0,0 +1 @@ +x = "0" diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/src/main.nr new file mode 100644 index 00000000000..e02cfb6880d --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/src/main.nr @@ -0,0 +1,19 @@ +use dep::std; + +fn main(mut x: Field) { + + let add1 = |z| { + *z = *z + 1; + }; + + let add2 = |z| { + *z = *z + 2; + }; + + add1(&mut x); + assert(x == 1); + + add2(&mut x); + assert(x == 3); + +} diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Nargo.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Nargo.toml new file mode 100644 index 00000000000..f5d28236db2 --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "fibonacci_by_ref" +authors = [""] +compiler_version = "0.8.0" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Prover.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Prover.toml new file mode 100644 index 00000000000..d594b02e17d --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Prover.toml @@ -0,0 +1,2 @@ +prev = "1" +cur = "2" diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/src/main.nr new file mode 100644 index 00000000000..763c5165c0f --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/src/main.nr @@ -0,0 +1,15 @@ +fn fib_fn(a: Field, b: Field, res: &mut Field) { + *res = a + b; +} + +fn main(mut prev: Field, mut cur: Field) { + + let mut fib = prev + cur; + for i in 1..10 { + prev = cur; + cur = fib; + fib_fn(prev, cur, &mut fib); + assert(prev + cur == fib); + } + +} diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/Nargo.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/Nargo.toml new file mode 100644 index 00000000000..3c2277e35a5 --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "higher_order_fn_selector" +authors = [""] +compiler_version = "0.8.0" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/src/main.nr new file mode 100644 index 00000000000..3b8ec51ba00 --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/src/main.nr @@ -0,0 +1,49 @@ +fn f(x: &mut Field) -> Field { + *x = *x - 1; + 1 +} + +fn g(x: &mut Field) -> Field { + *x *= 2; + 1 +} + +fn h(x: &mut Field) -> Field { + *x *= 3; + 1 +} + +use dep::std; + +fn selector(flag:&mut bool) -> fn(&mut Field) -> Field { //TODO: Can we have fn(&mut Field) -> () return type? + let mut my_func = f; + + if *flag { + my_func = g; + } + else { + my_func = h; + }; + + // Flip the flag for the next function call + *flag = !(*flag); + my_func +} + +fn main() { + + let mut flag: bool = true; + + let mut x: Field = 100; + let returned_func = selector(&mut flag); + let status = returned_func(&mut x); + + assert(x == 200); + + let mut y: Field = 100; + let returned_func2 = selector(&mut flag); + let status2 = returned_func2(&mut y); + + assert(y == 300); + +} diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/Nargo.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/Nargo.toml new file mode 100644 index 00000000000..cf7526abc7f --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "higher_order_functions" +authors = [""] +compiler_version = "0.1" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/Prover.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/Prover.toml new file mode 100644 index 00000000000..e69de29bb2d diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/src/main.nr new file mode 100644 index 00000000000..a6e328b09af --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/src/main.nr @@ -0,0 +1,88 @@ +use dep::std; + +fn main() -> pub Field { + let f = if 3 * 7 > 200 as u32 { foo } else { bar }; + assert(f()[1] == 2); + // Lambdas: + assert(twice(|x| x * 2, 5) == 20); + assert((|x, y| x + y + 1)(2, 3) == 6); + + // nested lambdas + assert((|a, b| { + a + (|c| c + 2)(b) + })(0, 1) == 3); + + + // Closures: + let a = 42; + let g = || a; + assert(g() == 42); + + // TODO: enable this again after fixing #2054 + // https://github.com/noir-lang/noir/issues/2054 + // by @jfecher's PR https://github.com/noir-lang/noir/pull/2057 + + // Mutable variables cannot be captured, but you can + // copy them into immutable variables and capture those: + // let mut x = 2; + // x = x + 1; + // let z = x; + + // Add extra mutations to ensure we can mutate x without the + // captured z changing. + // x = x + 1; + // TODO: this behavior changed in the new ssa backend: + // now even z is changed, and it wasn't in the previous backend + // assert(z == 2); + // fails! + // decide what to do after opening an issue about the simpler + // variable alias case + // + // assert((|y| y + z)(1) == 4); + + let ret = twice(add1, 3); + + test_array_functions(); + ret +} + +/// Test the array functions in std::array +fn test_array_functions() { + let myarray: [i32; 3] = [1, 2, 3]; + assert(myarray.any(|n| n > 2)); + + let evens: [i32; 3] = [2, 4, 6]; + assert(evens.all(|n| n > 1)); + + assert(evens.fold(0, |a, b| a + b) == 12); + assert(evens.reduce(|a, b| a + b) == 12); + + // TODO: is this a sort_via issue with the new backend, + // or something more general? + // + // currently it fails only with `--experimental-ssa` with + // "not yet implemented: Cast into signed" + // but it worked with the original ssa backend + // (before dropping it) + // + // let descending = myarray.sort_via(|a, b| a > b); + // assert(descending == [3, 2, 1]); + + assert(evens.map(|n| n / 2) == myarray); +} + +fn foo() -> [u32; 2] { + [1, 3] +} + +fn bar() -> [u32; 2] { + [3, 2] +} + +fn add1(x: Field) -> Field { + x + 1 +} + +fn twice(f: fn(Field) -> Field, x: Field) -> Field { + f(f(x)) +} diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/c.json b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/c.json new file mode 100644 index 00000000000..c1233b8160b --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/c.json @@ -0,0 +1 @@ +{"backend":"acvm-backend-barretenberg","abi":{"parameters":[],"param_witnesses":{},"return_type":null,"return_witnesses":[]},"bytecode":[155,194,56,97,194,4,0],"proving_key":null,"verification_key":null} \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/main.json b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/main.json new file mode 100644 index 00000000000..8d7a1566313 --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/main.json @@ -0,0 +1 @@ +{"backend":"acvm-backend-barretenberg","abi":{"parameters":[{"name":"x","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"},{"name":"y","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"},{"name":"z","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"}],"param_witnesses":{"x":[1],"y":[2],"z":[3]},"return_type":null,"return_witnesses":[]},"bytecode":"H4sIAAAAAAAA/9WUTW6DMBSEJ/yFhoY26bYLjoAxBLPrVYpK7n+EgmoHamWXeShYQsYSvJ+Z9/kDwCf+1m58ArsXi3PgnUN7dt/u7P9fdi8fW8rlATduCW89GFe5l2iMES90YBd+EyTyjIjtGYIm+HF1eanroa0GpdV3WXW9acq66S9GGdWY5qcyWg+mNm3Xd23ZqVoP6tp0+moDJ5AxNOTUWdk6VUTsOSb6wtRPCuDYziaZAzGA92OMFCsAPCUqMAOcQg5gZwIb4BdsA+A9seeU6AtTPymAUzubZA7EAD6MMTKsAPCUqMAMcAY5gJ0JbIBfsQ2AD8SeM6IvTP2kAM7sbJI5EAP4OMbIsQLAU6ICM8A55AB2JrABfsM2AD4Se86Jvjy5freeQ2LPObGud6J+Ce5ADz6LzJqX9Z4W75HdgzszkQj0BC+Pr6PohSpl0kkg7hm84Zfq+8z36N/l9OyaLtcv2EfpKJUUAAA=","proving_key":null,"verification_key":null} \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/witness.tr b/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/witness.tr new file mode 100644 index 0000000000000000000000000000000000000000..a539f87a55498eeaff3e546ac9126cea0091fa70 GIT binary patch literal 112 zcmV-$0FVD4iwFP!00002|E<$W3cw%?h2hTg=t&aVF5LAhrT4#sir&CKAZGQE2Z Field { + x +} + +fn ret_fn() -> fn(Field) -> Field { + let y = 1; + let inner_closure = |z| -> Field{ + z + y + }; + std::println(inner_closure(1)); + f +} + +fn ret_closure() -> fn(Field) -> Field { + let cl = |z: Field| -> Field { + z + }; + cl +} + +fn main(x : Field) { + let result_fn = ret_fn(); + assert(result_fn(x) == x); // Works + + let result_cl = ret_closure(); + assert(result_cl(x) == x); +} diff --git a/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs b/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs index 3e0bbff2a83..c3578e5ee7e 100644 --- a/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -218,7 +218,7 @@ impl<'a> FunctionContext<'a> { } ast::Type::Unit => panic!("convert_non_tuple_type called on a unit type"), ast::Type::Tuple(_) => panic!("convert_non_tuple_type called on a tuple: {typ}"), - ast::Type::Function(_, _) => Type::Function, + ast::Type::Function(_, _, _) => Type::Function, ast::Type::Slice(element) => { let element_types = Self::convert_type(element).flatten(); Type::Slice(Rc::new(element_types)) diff --git a/crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs index 0c0dd35211b..c89254f50f1 100644 --- a/crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -86,7 +86,8 @@ impl<'a> FunctionContext<'a> { /// Codegen any non-tuple expression so that we can unwrap the Values /// tree to return a single value for use with most SSA instructions. fn codegen_non_tuple_expression(&mut self, expr: &Expression) -> ValueId { - self.codegen_expression(expr).into_leaf().eval(self) + let e = self.codegen_expression(expr); + e.into_leaf().eval(self) } /// Codegen a reference to an ident. diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index 28338597730..12525cb3ba8 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -364,7 +364,8 @@ impl<'a> Resolver<'a> { UnresolvedType::Function(args, ret) => { let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); let ret = Box::new(self.resolve_type_inner(*ret, new_variables)); - Type::Function(args, ret) + let env = Box::new(Type::Unit); + Type::Function(args, ret, env) } UnresolvedType::MutableReference(element) => { Type::MutableReference(Box::new(self.resolve_type_inner(*element, new_variables))) @@ -706,7 +707,7 @@ impl<'a> Resolver<'a> { }); } - let mut typ = Type::Function(parameter_types, return_type); + let mut typ = Type::Function(parameter_types, return_type, Box::new(Type::Unit)); if !generics.is_empty() { typ = Type::Forall(generics, Box::new(typ)); @@ -839,13 +840,12 @@ impl<'a> Resolver<'a> { } } - Type::Function(parameters, return_type) => { + Type::Function(parameters, return_type, _env) => { for parameter in parameters { Self::find_numeric_generics_in_type(parameter, found); } Self::find_numeric_generics_in_type(return_type, found); } - Type::Closure(func) => Self::find_numeric_generics_in_type(func, found), Type::Struct(struct_type, generics) => { for (i, generic) in generics.iter().enumerate() { @@ -964,7 +964,7 @@ impl<'a> Resolver<'a> { // If this was a fresh capture, we added it to the end of // the captures vector: self.lambda_stack[lambda_index].captures.len() - 1, - )) + )); } } } @@ -1139,8 +1139,7 @@ impl<'a> Resolver<'a> { ExpressionKind::Lambda(lambda) => self.in_new_scope(|this| { let scope_index = this.scopes.current_scope_index(); - this.lambda_stack - .push(LambdaContext { captures: Vec::new(), scope_index: scope_index }); + this.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index }); let parameters = vecmap(lambda.parameters, |(pattern, typ)| { let parameter = DefinitionKind::Local(None); @@ -1478,10 +1477,14 @@ mod test { use crate::hir::def_map::{ModuleData, ModuleId, ModuleOrigin}; use crate::hir::resolution::errors::ResolverError; use crate::hir::resolution::import::PathResolutionError; + use crate::hir::resolution::resolver::StmtId; use crate::graph::CrateId; + use crate::hir_def::expr::HirExpression; use crate::hir_def::function::HirFunction; + use crate::hir_def::stmt::HirStatement; use crate::node_interner::{FuncId, NodeInterner}; + use crate::ParsedModule; use crate::{ hir::def_map::{CrateDefMap, LocalModuleId, ModuleDefId}, parse_program, Path, @@ -1491,24 +1494,15 @@ mod test { // func_namespace is used to emulate the fact that functions can be imported // and functions can be forward declared - fn resolve_src_code(src: &str, func_namespace: Vec<&str>) -> Vec { + fn init_src_code_resolution( + src: &str, + ) -> (ParsedModule, NodeInterner, HashMap, FileId, TestPathResolver) { let (program, errors) = parse_program(src); if !errors.is_empty() { panic!("Unexpected parse errors in test code: {:?}", errors); } - let mut interner = NodeInterner::default(); - - let func_ids = vecmap(&func_namespace, |name| { - let id = interner.push_fn(HirFunction::empty()); - interner.push_function_definition(name.to_string(), id); - id - }); - - let mut path_resolver = TestPathResolver(HashMap::new()); - for (name, id) in func_namespace.into_iter().zip(func_ids) { - path_resolver.insert_func(name.to_owned(), id); - } + let interner: NodeInterner = NodeInterner::default(); let mut def_maps: HashMap = HashMap::new(); let file = FileId::default(); @@ -1516,6 +1510,8 @@ mod test { let mut modules = arena::Arena::new(); modules.insert(ModuleData::new(None, ModuleOrigin::File(file), false)); + let path_resolver = TestPathResolver(HashMap::new()); + def_maps.insert( CrateId::dummy_id(), CrateDefMap { @@ -1526,10 +1522,30 @@ mod test { }, ); + (program, interner, def_maps, file, path_resolver) + } + + // func_namespace is used to emulate the fact that functions can be imported + // and functions can be forward declared + fn resolve_src_code(src: &str, func_namespace: Vec<&str>) -> Vec { + let (program, mut interner, def_maps, file, mut path_resolver) = + init_src_code_resolution(src); + + let func_ids = vecmap(&func_namespace, |name| { + let id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition(name.to_string(), id); + id + }); + + for (name, id) in func_namespace.into_iter().zip(func_ids) { + path_resolver.insert_func(name.to_owned(), id); + } + let mut errors = Vec::new(); for func in program.functions { let id = interner.push_fn(HirFunction::empty()); interner.push_function_definition(func.name().to_string(), id); + let resolver = Resolver::new(&mut interner, &path_resolver, &def_maps, file); let (_, _, err) = resolver.resolve_function(func, id, ModuleId::dummy_id()); errors.extend(err); @@ -1538,6 +1554,81 @@ mod test { errors } + fn get_program_captures(src: &str) -> Vec> { + let (program, mut interner, def_maps, file, mut path_resolver) = + init_src_code_resolution(src); + + let mut all_captures: Vec> = Vec::new(); + for func in program.functions { + let id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition(func.name().clone().to_string(), id); + path_resolver.insert_func(func.name().to_owned(), id); + + let resolver = Resolver::new(&mut interner, &path_resolver, &def_maps, file); + let (hir_func, _, _) = resolver.resolve_function(func, id, ModuleId::dummy_id()); + + // Iterate over function statements and apply filtering function + parse_statement_blocks( + hir_func.block(&interner).statements(), + &interner, + &mut all_captures, + ); + } + all_captures + } + + fn parse_statement_blocks( + stmts: &[StmtId], + interner: &NodeInterner, + result: &mut Vec>, + ) { + let mut expr: HirExpression; + + for stmt_id in stmts.iter() { + let hir_stmt = interner.statement(stmt_id); + match hir_stmt { + HirStatement::Expression(expr_id) => { + expr = interner.expression(&expr_id); + } + HirStatement::Let(let_stmt) => { + expr = interner.expression(&let_stmt.expression); + } + HirStatement::Assign(assign_stmt) => { + expr = interner.expression(&assign_stmt.expression); + } + HirStatement::Constrain(constr_stmt) => { + expr = interner.expression(&constr_stmt.0); + } + HirStatement::Semi(semi_expr) => { + expr = interner.expression(&semi_expr); + } + HirStatement::Error => panic!("Invalid HirStatement!"), + } + get_lambda_captures(expr, &interner, result); // TODO: dyn filter function as parameter + } + } + + fn get_lambda_captures( + expr: HirExpression, + interner: &NodeInterner, + result: &mut Vec>, + ) { + if let HirExpression::Lambda(lambda_expr) = expr { + let mut cur_capture = Vec::new(); + + for capture in lambda_expr.captures.iter() { + cur_capture.push(interner.definition(capture.ident.id).name.clone()); + } + result.push(cur_capture); + + // Check for other captures recursively within the lambda body + let hir_body_expr = interner.expression(&lambda_expr.body); + if let HirExpression::Block(block_expr) = hir_body_expr.clone() { + parse_statement_blocks(block_expr.statements(), interner, result); + } + } + } + #[test] fn resolve_empty_function() { let src = " @@ -1771,8 +1862,27 @@ mod test { println!("Unexpected errors: {:?}", errors); assert!(false); // there should be no errors } - } + let expected_captures = vec![ + vec![], + vec!["x".to_string()], + vec!["b".to_string()], + vec!["x".to_string(), "b".to_string(), "a".to_string()], + vec![ + "x".to_string(), + "b".to_string(), + "a".to_string(), + "y".to_string(), + "d".to_string(), + ], + vec!["x".to_string(), "b".to_string()], + ]; + + let parsed_captures = get_program_captures(src); + + assert_eq!(expected_captures, parsed_captures); + } + #[test] fn resolve_fmt_strings() { let src = r#" @@ -1806,10 +1916,9 @@ mod test { } } - // TODO: Create a more sophisticated set of search functions over the HIR, so we can check + // possible TODO: Create a more sophisticated set of search functions over the HIR, so we can check // that the correct variables are captured in each closure - fn path_unresolved_error(err: ResolverError, expected_unresolved_path: &str) { match err { ResolverError::PathResolutionError(PathResolutionError::Unresolved(name)) => { diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index c3106b2b56a..4ef92e2c5d6 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -279,13 +279,11 @@ impl<'interner> TypeChecker<'interner> { Type::Tuple(vecmap(&elements, |elem| self.check_expression(elem))) } HirExpression::Lambda(lambda) => { - let captured_vars = vecmap(lambda.captures, |capture| { - let typ = self.interner.id_type(capture.ident.id); - typ - }); + let captured_vars = + vecmap(lambda.captures, |capture| self.interner.id_type(capture.ident.id)); let env_type = Type::Tuple(captured_vars); - let mut params = vec![env_type]; + let mut params = vec![env_type.clone()]; for (pattern, typ) in lambda.parameters { self.bind_pattern(&pattern, typ.clone()); @@ -303,8 +301,9 @@ impl<'interner> TypeChecker<'interner> { } }); - let function_type = Type::Function(params, Box::new(lambda.return_type)); - Type::Closure(Box::new(function_type)) + let function_type = + Type::Function(params, Box::new(lambda.return_type), Box::new(env_type)); + function_type } }; @@ -329,9 +328,9 @@ impl<'interner> TypeChecker<'interner> { argument_types: &mut [(Type, ExprId, noirc_errors::Span)], ) { let expected_object_type = match function_type { - Type::Function(args, _) => args.get(0), + Type::Function(args, _, _) => args.get(0), Type::Forall(_, typ) => match typ.as_ref() { - Type::Function(args, _) => args.get(0), + Type::Function(args, _, _) => args.get(0), typ => unreachable!("Unexpected type for function: {typ}"), }, typ => unreachable!("Unexpected type for function: {typ}"), @@ -891,10 +890,10 @@ impl<'interner> TypeChecker<'interner> { let real_fn_params_count = fn_params.len() - skip_params; if real_fn_params_count != callsite_args.len() { - self.errors.push(TypeCheckError::ParameterCountMismatch { + self.errors.push(TypeCheckError::ParameterCountMismatch { expected: real_fn_params_count, - found: callsite_args.len(), - span: span + found: callsite_args.len(), + span: span, }); return Type::Error; } @@ -928,22 +927,30 @@ impl<'interner> TypeChecker<'interner> { let ret = self.interner.next_type_variable(); let args = vecmap(args, |(arg, _, _)| arg); - let expected = Type::Function(args, Box::new(ret.clone())); + let expected = Type::Function(args, Box::new(ret.clone()), Box::new(Type::Unit)); if let Err(error) = binding.borrow_mut().bind_to(expected, span) { self.errors.push(error); } ret } - Type::Function(parameters, ret) => { + Type::Function(parameters, ret, env) => { self.bind_function_type_impl( parameters.as_ref(), ret.as_ref(), args.as_ref(), span, - 0, + match *env { + Type::Unit => 0, + Type::Tuple(_) => { + 1 // closure env + } + _ => unreachable!( + "function env internal type should be either Unit or Tuple" + ), + }, ) - }, + } Type::Error => Type::Error, found => { self.errors.push(TypeCheckError::ExpectedFunction { found, span }); diff --git a/crates/noirc_frontend/src/hir/type_check/mod.rs b/crates/noirc_frontend/src/hir/type_check/mod.rs index 9ab581cddca..1883c0abf62 100644 --- a/crates/noirc_frontend/src/hir/type_check/mod.rs +++ b/crates/noirc_frontend/src/hir/type_check/mod.rs @@ -246,7 +246,11 @@ mod test { contract_function_type: None, is_internal: None, is_unconstrained: false, - typ: Type::Function(vec![Type::field(None), Type::field(None)], Box::new(Type::Unit)), + typ: Type::Function( + vec![Type::field(None), Type::field(None)], + Box::new(Type::Unit), + Box::new(Type::Unit), + ), parameters: vec![ Param(Identifier(x), Type::field(None), noirc_abi::AbiVisibility::Private), Param(Identifier(y), Type::field(None), noirc_abi::AbiVisibility::Private), diff --git a/crates/noirc_frontend/src/hir_def/function.rs b/crates/noirc_frontend/src/hir_def/function.rs index a69e8bb08b5..225731626f0 100644 --- a/crates/noirc_frontend/src/hir_def/function.rs +++ b/crates/noirc_frontend/src/hir_def/function.rs @@ -180,9 +180,9 @@ impl FuncMeta { /// Gives the (uninstantiated) return type of this function. pub fn return_type(&self) -> &Type { match &self.typ { - Type::Function(_, ret) => ret, + Type::Function(_, ret, _env) => ret, Type::Forall(_, typ) => match typ.as_ref() { - Type::Function(_, ret) => ret, + Type::Function(_, ret, _env) => ret, _ => unreachable!(), }, _ => unreachable!(), diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index 76c41a2c86f..d8c1acb4245 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -70,13 +70,11 @@ pub enum Type { /// like `fn foo(...) {}`. Unlike TypeVariables, they cannot be bound over. NamedGeneric(TypeVariable, Rc), - /// A functions with arguments, and a return type. - Function(Vec, Box), - - /// A closure (a pair of a function pointer and a tuple of captured variables). - /// Stores the underlying function type, which has been modifies such that the - /// first parameter is the type of the captured variables tuple. - Closure(Box), + /// A functions with arguments, a return type and environment. + /// the environment should be `Unit` by default, + /// for closures it should contain a `Tuple` type with the captured + /// variable types. + Function(Vec, Box, Box), /// &mut T MutableReference(Box), @@ -702,11 +700,10 @@ impl Type { Type::Tuple(fields) => { fields.iter().any(|field| field.contains_numeric_typevar(target_id)) } - Type::Function(parameters, return_type) => { + Type::Function(parameters, return_type, _env) => { parameters.iter().any(|parameter| parameter.contains_numeric_typevar(target_id)) || return_type.contains_numeric_typevar(target_id) } - Type::Closure(func) => func.contains_numeric_typevar(target_id), Type::Struct(struct_type, generics) => { generics.iter().enumerate().any(|(i, generic)| { if named_generic_id_matches_target(generic) { @@ -803,12 +800,9 @@ impl std::fmt::Display for Type { let typevars = vecmap(typevars, |(var, _)| var.to_string()); write!(f, "forall {}. {}", typevars.join(" "), typ) } - Type::Function(args, ret) => { + Type::Function(args, ret, env) => { let args = vecmap(args, ToString::to_string); - write!(f, "fn({}) -> {}", args.join(", "), ret) - } - Type::Closure(func) => { - write!(f, "closure {}", func) // i.e. we produce a string such as "closure fn(args) -> ret" + write!(f, "fn({}) -> {} [{}]", args.join(", "), ret, env) } Type::MutableReference(element) => { write!(f, "&mut {element}") @@ -1205,13 +1199,14 @@ impl Type { } } - (Function(params_a, ret_a), Function(params_b, ret_b)) => { + (Function(params_a, ret_a, env_a), Function(params_b, ret_b, env_b)) => { if params_a.len() == params_b.len() { for (a, b) in params_a.iter().zip(params_b) { a.try_unify(b, span)?; } - ret_b.try_unify(ret_a, span) + ret_b.try_unify(ret_a, span)?; + env_a.try_unify(env_b, span) } else { Err(SpanKind::None) } @@ -1412,9 +1407,16 @@ impl Type { } } - (Function(params_a, ret_a), Function(params_b, ret_b)) => { - if params_a.len() == params_b.len() { - for (a, b) in params_a.iter().zip(params_b) { + (Function(params_a, ret_a, env_a), Function(params_b, ret_b, _env_b)) => { + let skip_params = match *env_a.clone() { + Type::Unit => 0, + Type::Tuple(_) => { + 1 // closure env + } + _ => unreachable!("function env internal type should be either Unit or Tuple"), + }; + if params_a.len() - skip_params == params_b.len() { + for (a, b) in params_a.iter().skip(skip_params).zip(params_b) { a.is_subtype_of(b, span)?; } @@ -1514,8 +1516,7 @@ impl Type { Type::TypeVariable(_, _) => unreachable!(), Type::NamedGeneric(..) => unreachable!(), Type::Forall(..) => unreachable!(), - Type::Function(_, _) => unreachable!(), - Type::Closure(_) => unreachable!(), + Type::Function(_, _, _) => unreachable!(), Type::MutableReference(_) => unreachable!("&mut cannot be used in the abi"), Type::NotConstant => unreachable!(), } @@ -1630,14 +1631,11 @@ impl Type { let typ = Box::new(typ.substitute(type_bindings)); Type::Forall(typevars.clone(), typ) } - Type::Function(args, ret) => { + Type::Function(args, ret, env) => { let args = vecmap(args, |arg| arg.substitute(type_bindings)); let ret = Box::new(ret.substitute(type_bindings)); - Type::Function(args, ret) - } - Type::Closure(func) => { - let func = Box::new(func.substitute(type_bindings)); - Type::Closure(func) + let env = Box::new(env.substitute(type_bindings)); + Type::Function(args, ret, env) } Type::MutableReference(element) => { Type::MutableReference(Box::new(element.substitute(type_bindings))) @@ -1674,10 +1672,11 @@ impl Type { Type::Forall(typevars, typ) => { !typevars.iter().any(|(id, _)| *id == target_id) && typ.occurs(target_id) } - Type::Function(args, ret) => { - args.iter().any(|arg| arg.occurs(target_id)) || ret.occurs(target_id) + Type::Function(args, ret, env) => { + args.iter().any(|arg| arg.occurs(target_id)) + || ret.occurs(target_id) + || env.occurs(target_id) } - Type::Closure(func) => func.occurs(target_id), Type::MutableReference(element) => element.occurs(target_id), Type::FieldElement(_) @@ -1721,12 +1720,12 @@ impl Type { self.clone() } - Function(args, ret) => { + Function(args, ret, env) => { let args = vecmap(args, |arg| arg.follow_bindings()); let ret = Box::new(ret.follow_bindings()); - Function(args, ret) + let env = Box::new(env.follow_bindings()); + Function(args, ret, env) } - Closure(func) => Closure(Box::new(func.follow_bindings())), MutableReference(element) => MutableReference(Box::new(element.follow_bindings())), @@ -1768,7 +1767,7 @@ fn convert_array_expression_to_slice( interner.push_expr_location(func, location.span, location.file); interner.push_expr_type(&call, target_type.clone()); - interner.push_expr_type(&func, Type::Function(vec![array_type], Box::new(target_type))); + interner.push_expr_type(&func, Type::Function(vec![array_type], Box::new(target_type), Box::new(Type::Unit))); } impl BinaryTypeOperator { diff --git a/crates/noirc_frontend/src/monomorphization/ast.rs b/crates/noirc_frontend/src/monomorphization/ast.rs index 95b95c38525..e756111aa0a 100644 --- a/crates/noirc_frontend/src/monomorphization/ast.rs +++ b/crates/noirc_frontend/src/monomorphization/ast.rs @@ -219,7 +219,7 @@ pub enum Type { Tuple(Vec), Slice(Box), MutableReference(Box), - Function(/*args:*/ Vec, /*ret:*/ Box), + Function(/*args:*/ Vec, /*ret:*/ Box, /*env:*/ Box), } impl Type { @@ -231,51 +231,59 @@ impl Type { } } -pub fn type_of_lvalue(lvalue: &LValue) -> Type { - match lvalue { - LValue::Ident(ident) => ident.typ.clone(), - LValue::Index { element_type, .. } => element_type.clone(), - LValue::MemberAccess { object, field_index } => { - let tuple_type = type_of_lvalue(object.as_ref()); - match tuple_type { - Type::Tuple(fields) => fields[*field_index].clone(), - _ => unreachable!("ICE: Member access on non-tuple type"), +impl Expression { + pub fn type_of(&self) -> Type { + match self { + Expression::Ident(ident) => ident.typ.clone(), + Expression::Literal(lit) => match lit { + Literal::Integer(_, typ) => typ.clone(), + Literal::Bool(_) => Type::Bool, + Literal::Str(str) => Type::String(str.len() as u64), + Literal::Array(array) => { + // temp + Type::Array(array.contents.len() as u64, Box::new(Type::Unit)) + }, + Literal::FmtStr(_, _, _) => unimplemented!() + }, + Expression::Block(stmts) => (stmts.last().unwrap()).type_of(), + Expression::Unary(unary) => unary.result_type.clone(), + Expression::Binary(_binary) => { + unreachable!("TODO: How do we get the type of a Binary op") } + Expression::Index(index) => index.element_type.clone(), + Expression::Cast(cast) => cast.r#type.clone(), + Expression::For(_for_expr) => Type::Unit, + Expression::If(if_expr) => if_expr.typ.clone(), + Expression::Tuple(elements) => { + Type::Tuple(elements.iter().map(|e| e.type_of()).collect()) + } + Expression::ExtractTupleField(tuple, index) => match tuple.as_ref() { + Expression::Tuple(fields) => (&fields[*index]).type_of(), + _ => unreachable!("ICE: Tuple field access on non-tuple type"), + }, + Expression::Call(call) => call.return_type.clone(), + Expression::Let(let_stmt) => let_stmt.expression.as_ref().type_of(), + Expression::Constrain(constraint, _) => constraint.as_ref().type_of(), + Expression::Assign(assign) => (&assign.lvalue).type_of(), + Expression::Semi(_expr) => Type::Unit, } - LValue::Dereference { element_type, .. } => element_type.clone(), } } -pub fn type_of(expr: &Expression) -> Type { - match expr { - Expression::Ident(ident) => ident.typ.clone(), - Expression::Literal(lit) => match lit { - Literal::Integer(_, typ) => typ.clone(), - Literal::Bool(_) => Type::Bool, - Literal::Str(str) => Type::String(str.len() as u64), - Literal::Array(array) => { - // TODO - Type::Array(array.contents.len() as u64, Box::new(Type::Unit)) - }, - Literal::FmtStr(_, _, _) => unimplemented!() - }, - Expression::Block(stmts) => type_of(stmts.last().unwrap()), - Expression::Unary(unary) => unary.result_type.clone(), - Expression::Binary(_binary) => unreachable!("TODO: How do we get the type of a Binary op"), - Expression::Index(index) => index.element_type.clone(), - Expression::Cast(cast) => cast.r#type.clone(), - Expression::For(_for_expr) => Type::Unit, - Expression::If(if_expr) => if_expr.typ.clone(), - Expression::Tuple(elements) => Type::Tuple(elements.iter().map(type_of).collect()), - Expression::ExtractTupleField(tuple, index) => match tuple.as_ref() { - Expression::Tuple(fields) => type_of(&fields[*index]), - _ => unreachable!("ICE: Tuple field access on non-tuple type"), - }, - Expression::Call(call) => call.return_type.clone(), - Expression::Let(let_stmt) => type_of(let_stmt.expression.as_ref()), - Expression::Constrain(contraint, _) => type_of(contraint.as_ref()), - Expression::Assign(assign) => type_of_lvalue(&assign.lvalue), - Expression::Semi(_expr) => Type::Unit, +impl LValue { + pub fn type_of(&self) -> Type { + match self { + LValue::Ident(ident) => ident.typ.clone(), + LValue::Index { element_type, .. } => element_type.clone(), + LValue::MemberAccess { object, field_index } => { + let tuple_type = object.as_ref().type_of(); + match tuple_type { + Type::Tuple(fields) => fields[*field_index].clone(), + _ => unreachable!("ICE: Member access on non-tuple type"), + } + } + LValue::Dereference { element_type, .. } => element_type.clone(), + } } } @@ -378,9 +386,9 @@ impl std::fmt::Display for Type { let elements = vecmap(elements, ToString::to_string); write!(f, "({})", elements.join(", ")) } - Type::Function(args, ret) => { + Type::Function(args, ret, env) => { let args = vecmap(args, ToString::to_string); - write!(f, "fn({}) -> {}", args.join(", "), ret) + write!(f, "fn({}) -> {} [{}]", args.join(", "), ret, env) } Type::Slice(element) => write!(f, "[{element}"), Type::MutableReference(element) => write!(f, "&mut {element}"), diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 3b8901b079b..ecf620d88e1 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -581,8 +581,19 @@ impl<'interner> Monomorphizer<'interner> { let definition = self.lookup_function(*func_id, expr_id, &typ); let typ = Self::convert_type(&typ); - let ident = ast::Ident { location, mutable, definition, name, typ }; - ast::Expression::Ident(ident) + let ident = ast::Ident { location, mutable, definition, name, typ: typ.clone() }; + let ident_expression = ast::Expression::Ident(ident).clone(); + if self.is_function_closure_type(&typ) { + ast::Expression::Tuple(vec![ + ast::Expression::ExtractTupleField( + Box::new(ident_expression.clone()), + 0usize, + ), + ast::Expression::ExtractTupleField(Box::new(ident_expression), 1usize), + ]) + } else { + ident_expression + } } DefinitionKind::Global(expr_id) => self.expr(*expr_id), DefinitionKind::Local(_) => self.lookup_captured(ident.id).unwrap_or_else(|| { @@ -674,23 +685,11 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::Tuple(fields) } - HirType::Function(args, ret) => { + HirType::Function(args, ret, env) => { let args = vecmap(args, Self::convert_type); let ret = Box::new(Self::convert_type(ret)); - ast::Type::Function(args, ret) - } - - HirType::Closure(func) => { - match func.as_ref() { - HirType::Function(arguments, return_type) => { - let converted_args = vecmap(arguments, Self::convert_type); - let converted_ret = Box::new(Self::convert_type(&return_type)); - let fn_type = ast::Type::Function(converted_args, converted_ret); - let env_type = ast::Type::Tuple(vec![]); // TODO compute this - ast::Type::Tuple(vec![env_type, fn_type]) - } - _ => unreachable!("Unexpected closure type {}", func), - } + let env = Box::new(Self::convert_type(env)); + ast::Type::Function(args, ret, env) } HirType::MutableReference(element) => { @@ -708,7 +707,27 @@ impl<'interner> Monomorphizer<'interner> { } fn is_function_closure(&self, func: &ast::Expression) -> bool { - matches!(ast::type_of(func), ast::Type::Tuple(_)) + let t = func.type_of(); + if self.is_function_closure_type(&t) { + true + } else if let ast::Type::Tuple(elements) = t { + if elements.len() == 2 { + matches!(elements[1], ast::Type::Function(_, _, _)) + } else { + false + } + } else { + false + } + } + + fn is_function_closure_type(&self, t: &ast::Type) -> bool { + if let ast::Type::Function(_, _, env) = t { + let e = (*env).clone(); + matches!(*e, ast::Type::Tuple(_captures)) + } else { + false + } } fn function_call( @@ -716,9 +735,10 @@ impl<'interner> Monomorphizer<'interner> { call: HirCallExpression, id: node_interner::ExprId, ) -> ast::Expression { - let original_func = Box::new(self.expr(call.func)); + let original_func = Box::new(self.expr((call.func).clone())); let mut arguments = vecmap(&call.arguments, |id| self.expr(*id)); let hir_arguments = vecmap(&call.arguments, |id| self.interner.expression(id)); + let func: Box; let return_type = self.interner.id_type(id); let return_type = Self::convert_type(&return_type); let location = call.location; @@ -733,25 +753,42 @@ impl<'interner> Monomorphizer<'interner> { } } - let is_closure = self.is_function_closure(&*original_func); + let mut block_expressions = vec![]; - let func = if is_closure { - Box::new(ast::Expression::ExtractTupleField(Box::new((*original_func).clone()), 1usize)) + let is_closure = self.is_function_closure(&original_func); + if is_closure { + let extracted_func: ast::Expression; + let hir_call_func = self.interner.expression(&call.func); + if let HirExpression::Lambda(l) = hir_call_func { + let (setup, closure_variable) = self.lambda_with_setup(l); + block_expressions.push(setup); + extracted_func = closure_variable; + } else { + extracted_func = *original_func; + } + func = Box::new(ast::Expression::ExtractTupleField( + Box::new(extracted_func.clone()), + 1usize, + )); + let env_argument = ast::Expression::ExtractTupleField(Box::new(extracted_func), 0usize); + arguments.insert(0, env_argument); } else { - original_func.clone() + func = original_func.clone() }; - if is_closure { - let env_argument = - ast::Expression::ExtractTupleField(Box::new((*original_func).clone()), 0usize); - arguments.insert(0, env_argument); - } - self.try_evaluate_call(&func, &return_type).unwrap_or(ast::Expression::Call(ast::Call { + let call = self.try_evaluate_call(&func, &return_type).unwrap_or(ast::Expression::Call(ast::Call { func, arguments, return_type, location, - })) + })); + + if !block_expressions.is_empty() { + block_expressions.push(call); + ast::Expression::Block(block_expressions) + } else { + call + } } /// Adds a function argument that contains type metadata that is required to tell @@ -961,7 +998,21 @@ impl<'interner> Monomorphizer<'interner> { } } - fn lambda(&mut self, lambda: HirLambda) -> ast::Expression { + fn lambda_with_setup(&mut self, lambda: HirLambda) -> (ast::Expression, ast::Expression) { + // returns (, ) + // which can be used directly in callsites or transformed + // directly to a single `Expression` + // for other cases by `lambda` which is called by `expr` + // + // it solves the problem of detecting special cases where + // we call something like + // `{let env$.. = ..;}.1({let env$.. = ..;}.0, ..)` + // which was leading to redefinition errors + // + // instead of detecting and extracting + // patterns in the resulting tree, + // which seems more fragile, we directly reuse the return parameters + // of this function in those cases let ret_type = Self::convert_type(&lambda.return_type); let lambda_name = "lambda"; let parameter_types = vecmap(&lambda.parameters, |(_, typ)| Self::convert_type(typ)); @@ -996,7 +1047,7 @@ impl<'interner> Monomorphizer<'interner> { } } })); - let env_typ = ast::type_of(&env_tuple); + let env_typ = (&env_tuple).type_of(); let env_let_stmt = ast::Expression::Let(ast::Let { id: env_local_id, @@ -1024,17 +1075,18 @@ impl<'interner> Monomorphizer<'interner> { let body = self.expr(lambda.body); self.lambda_envs_stack.pop(); - let lambda_fn_typ: ast::Type = ast::Type::Function(parameter_types, Box::new(ret_type)); + let lambda_fn_typ: ast::Type = + ast::Type::Function(parameter_types, Box::new(ret_type), Box::new(env_typ.clone())); let lambda_fn = ast::Expression::Ident(ast::Ident { definition: Definition::Function(id), mutable: false, location: None, // TODO: This should match the location of the lambda expression name: name.clone(), - typ: lambda_fn_typ, + typ: lambda_fn_typ.clone(), }); let mut parameters = vec![]; - parameters.push((env_local_id, true, env_name.to_string(), env_typ)); + parameters.push((env_local_id, true, env_name.to_string(), env_typ.clone())); parameters.append(&mut converted_parameters); let unconstrained = false; @@ -1042,7 +1094,32 @@ impl<'interner> Monomorphizer<'interner> { self.push_function(id, function); let lambda_value = ast::Expression::Tuple(vec![env_ident, lambda_fn]); - ast::Expression::Block(vec![env_let_stmt, lambda_value]) + let block_local_id = self.next_local_id(); + let block_ident_name = "closure_variable"; + let block_let_stmt = ast::Expression::Let(ast::Let { + id: block_local_id, + mutable: false, + name: block_ident_name.to_string(), + expression: Box::new(ast::Expression::Block(vec![env_let_stmt, lambda_value])), + }); + + let closure_definition = Definition::Local(block_local_id); + + let closure_ident = ast::Expression::Ident(ast::Ident { + location, + mutable: false, + definition: closure_definition, + name: block_ident_name.to_string(), + typ: ast::Type::Tuple(vec![env_typ, lambda_fn_typ]), + }); + + (block_let_stmt, closure_ident) + // ast::Expression::Block(vec![block_let_stmt, closure_ident]) + } + + fn lambda(&mut self, lambda: HirLambda) -> ast::Expression { + let (setup, closure_variable) = self.lambda_with_setup(lambda); + ast::Expression::Block(vec![setup, closure_variable]) } /// Implements std::unsafe::zeroed by returning an appropriate zeroed @@ -1082,8 +1159,8 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::Tuple(fields) => { ast::Expression::Tuple(vecmap(fields, |field| self.zeroed_value_of_type(field))) } - ast::Type::Function(parameter_types, ret_type) => { - self.create_zeroed_function(parameter_types, ret_type) + ast::Type::Function(parameter_types, ret_type, env) => { + self.create_zeroed_function(parameter_types, ret_type, env) } ast::Type::Slice(element_type) => { ast::Expression::Literal(ast::Literal::Array(ast::ArrayLiteral { @@ -1110,6 +1187,7 @@ impl<'interner> Monomorphizer<'interner> { &mut self, parameter_types: &[ast::Type], ret_type: &ast::Type, + env_type: &ast::Type, ) -> ast::Expression { let lambda_name = "zeroed_lambda"; @@ -1132,7 +1210,11 @@ impl<'interner> Monomorphizer<'interner> { mutable: false, location: None, name: lambda_name.to_owned(), - typ: ast::Type::Function(parameter_types.to_owned(), Box::new(ret_type.clone())), + typ: ast::Type::Function( + parameter_types.to_owned(), + Box::new(ret_type.clone()), + Box::new(env_type.clone()), + ), }) } } @@ -1237,7 +1319,8 @@ mod tests { let func_meta = vecmap(program.functions, |nf| { let resolver = Resolver::new(&mut interner, &path_resolver, &def_maps, file); - let (hir_func, func_meta, _resolver_errors) = resolver.resolve_function(nf, main_id); + let (hir_func, func_meta, _resolver_errors) = + resolver.resolve_function(nf, main_id, ModuleId::dummy_id()); // TODO: not sure why, we do get an error here, // but otherwise seem to get an ok monomorphization result // assert_eq!(resolver_errors, vec![]); diff --git a/crates/noirc_frontend/src/node_interner.rs b/crates/noirc_frontend/src/node_interner.rs index c6ff1d98e3f..7a1d53f4dbf 100644 --- a/crates/noirc_frontend/src/node_interner.rs +++ b/crates/noirc_frontend/src/node_interner.rs @@ -672,7 +672,7 @@ fn get_type_method_key(typ: &Type) -> Option { Type::String(_) => Some(String), Type::Unit => Some(Unit), Type::Tuple(_) => Some(Tuple), - Type::Function(_, _) => Some(Function), + Type::Function(_, _, _) => Some(Function), Type::MutableReference(element) => get_type_method_key(element), // We do not support adding methods to these types @@ -683,7 +683,6 @@ fn get_type_method_key(typ: &Type) -> Option { | Type::Error | Type::NotConstant | Type::Struct(_, _) - | Type::FmtString(_, _) - | Type::Closure(_) => None, // TODO: Is this correct? How do we add methods to functions? Can we do the same for closures? + | Type::FmtString(_, _) => None } } From 33c2c5cefabcdb6153143ba9db73bd7cfa375874 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Fri, 28 Jul 2023 16:41:50 +0300 Subject: [PATCH 06/26] fix: fix code, addressing `cargo clippy` warnings --- crates/noirc_frontend/src/hir/type_check/expr.rs | 6 ++---- crates/noirc_frontend/src/monomorphization/ast.rs | 4 ++-- crates/noirc_frontend/src/monomorphization/mod.rs | 8 ++++---- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 4ef92e2c5d6..c98d79cf950 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -301,9 +301,7 @@ impl<'interner> TypeChecker<'interner> { } }); - let function_type = - Type::Function(params, Box::new(lambda.return_type), Box::new(env_type)); - function_type + Type::Function(params, Box::new(lambda.return_type), Box::new(env_type)) } }; @@ -893,7 +891,7 @@ impl<'interner> TypeChecker<'interner> { self.errors.push(TypeCheckError::ParameterCountMismatch { expected: real_fn_params_count, found: callsite_args.len(), - span: span, + span, }); return Type::Error; } diff --git a/crates/noirc_frontend/src/monomorphization/ast.rs b/crates/noirc_frontend/src/monomorphization/ast.rs index e756111aa0a..6c93d9f4f1a 100644 --- a/crates/noirc_frontend/src/monomorphization/ast.rs +++ b/crates/noirc_frontend/src/monomorphization/ast.rs @@ -258,13 +258,13 @@ impl Expression { Type::Tuple(elements.iter().map(|e| e.type_of()).collect()) } Expression::ExtractTupleField(tuple, index) => match tuple.as_ref() { - Expression::Tuple(fields) => (&fields[*index]).type_of(), + Expression::Tuple(fields) => fields[*index].type_of(), _ => unreachable!("ICE: Tuple field access on non-tuple type"), }, Expression::Call(call) => call.return_type.clone(), Expression::Let(let_stmt) => let_stmt.expression.as_ref().type_of(), Expression::Constrain(constraint, _) => constraint.as_ref().type_of(), - Expression::Assign(assign) => (&assign.lvalue).type_of(), + Expression::Assign(assign) => assign.lvalue.type_of(), Expression::Semi(_expr) => Type::Unit, } } diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index ecf620d88e1..15e2bbf1fcf 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -582,7 +582,7 @@ impl<'interner> Monomorphizer<'interner> { let definition = self.lookup_function(*func_id, expr_id, &typ); let typ = Self::convert_type(&typ); let ident = ast::Ident { location, mutable, definition, name, typ: typ.clone() }; - let ident_expression = ast::Expression::Ident(ident).clone(); + let ident_expression = ast::Expression::Ident(ident); if self.is_function_closure_type(&typ) { ast::Expression::Tuple(vec![ ast::Expression::ExtractTupleField( @@ -735,7 +735,7 @@ impl<'interner> Monomorphizer<'interner> { call: HirCallExpression, id: node_interner::ExprId, ) -> ast::Expression { - let original_func = Box::new(self.expr((call.func).clone())); + let original_func = Box::new(self.expr(call.func)); let mut arguments = vecmap(&call.arguments, |id| self.expr(*id)); let hir_arguments = vecmap(&call.arguments, |id| self.interner.expression(id)); let func: Box; @@ -773,7 +773,7 @@ impl<'interner> Monomorphizer<'interner> { let env_argument = ast::Expression::ExtractTupleField(Box::new(extracted_func), 0usize); arguments.insert(0, env_argument); } else { - func = original_func.clone() + func = original_func.clone(); }; let call = self.try_evaluate_call(&func, &return_type).unwrap_or(ast::Expression::Call(ast::Call { @@ -1047,7 +1047,7 @@ impl<'interner> Monomorphizer<'interner> { } } })); - let env_typ = (&env_tuple).type_of(); + let env_typ = env_tuple.type_of(); let env_let_stmt = ast::Expression::Let(ast::Let { id: env_local_id, From 0c0c12315d37f506701f117c395caa39ff4e23fc Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Fri, 28 Jul 2023 17:47:00 +0300 Subject: [PATCH 07/26] fix: replace type_of usage and remove it, as hinted in review --- .../src/monomorphization/ast.rs | 56 ------------------- .../src/monomorphization/mod.rs | 28 +++++++--- 2 files changed, 19 insertions(+), 65 deletions(-) diff --git a/crates/noirc_frontend/src/monomorphization/ast.rs b/crates/noirc_frontend/src/monomorphization/ast.rs index 6c93d9f4f1a..42396543d26 100644 --- a/crates/noirc_frontend/src/monomorphization/ast.rs +++ b/crates/noirc_frontend/src/monomorphization/ast.rs @@ -231,62 +231,6 @@ impl Type { } } -impl Expression { - pub fn type_of(&self) -> Type { - match self { - Expression::Ident(ident) => ident.typ.clone(), - Expression::Literal(lit) => match lit { - Literal::Integer(_, typ) => typ.clone(), - Literal::Bool(_) => Type::Bool, - Literal::Str(str) => Type::String(str.len() as u64), - Literal::Array(array) => { - // temp - Type::Array(array.contents.len() as u64, Box::new(Type::Unit)) - }, - Literal::FmtStr(_, _, _) => unimplemented!() - }, - Expression::Block(stmts) => (stmts.last().unwrap()).type_of(), - Expression::Unary(unary) => unary.result_type.clone(), - Expression::Binary(_binary) => { - unreachable!("TODO: How do we get the type of a Binary op") - } - Expression::Index(index) => index.element_type.clone(), - Expression::Cast(cast) => cast.r#type.clone(), - Expression::For(_for_expr) => Type::Unit, - Expression::If(if_expr) => if_expr.typ.clone(), - Expression::Tuple(elements) => { - Type::Tuple(elements.iter().map(|e| e.type_of()).collect()) - } - Expression::ExtractTupleField(tuple, index) => match tuple.as_ref() { - Expression::Tuple(fields) => fields[*index].type_of(), - _ => unreachable!("ICE: Tuple field access on non-tuple type"), - }, - Expression::Call(call) => call.return_type.clone(), - Expression::Let(let_stmt) => let_stmt.expression.as_ref().type_of(), - Expression::Constrain(constraint, _) => constraint.as_ref().type_of(), - Expression::Assign(assign) => assign.lvalue.type_of(), - Expression::Semi(_expr) => Type::Unit, - } - } -} - -impl LValue { - pub fn type_of(&self) -> Type { - match self { - LValue::Ident(ident) => ident.typ.clone(), - LValue::Index { element_type, .. } => element_type.clone(), - LValue::MemberAccess { object, field_index } => { - let tuple_type = object.as_ref().type_of(); - match tuple_type { - Type::Tuple(fields) => fields[*field_index].clone(), - _ => unreachable!("ICE: Member access on non-tuple type"), - } - } - LValue::Dereference { element_type, .. } => element_type.clone(), - } - } -} - #[derive(Debug, Clone)] pub struct Program { pub functions: Vec, diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 15e2bbf1fcf..058aaa73200 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -19,6 +19,7 @@ use crate::{ expr::*, function::{FuncMeta, Param, Parameters}, stmt::{HirAssignStatement, HirLValue, HirLetStatement, HirPattern, HirStatement}, + types }, node_interner::{self, DefinitionKind, NodeInterner, StmtId}, token::Attribute, @@ -356,7 +357,7 @@ impl<'interner> Monomorphizer<'interner> { } HirExpression::Constructor(constructor) => self.constructor(constructor, expr), - HirExpression::Lambda(lambda) => self.lambda(lambda), + HirExpression::Lambda(lambda) => self.lambda(lambda, expr), HirExpression::MethodCall(_) => { unreachable!("Encountered HirExpression::MethodCall during monomorphization") @@ -706,8 +707,8 @@ impl<'interner> Monomorphizer<'interner> { } } - fn is_function_closure(&self, func: &ast::Expression) -> bool { - let t = func.type_of(); + fn is_function_closure(&self, raw_func_id: node_interner::ExprId) -> bool { + let t = Self::convert_type(&self.interner.id_type(raw_func_id)); if self.is_function_closure_type(&t) { true } else if let ast::Type::Tuple(elements) = t { @@ -755,12 +756,12 @@ impl<'interner> Monomorphizer<'interner> { let mut block_expressions = vec![]; - let is_closure = self.is_function_closure(&original_func); + let is_closure = self.is_function_closure(call.func); if is_closure { let extracted_func: ast::Expression; let hir_call_func = self.interner.expression(&call.func); if let HirExpression::Lambda(l) = hir_call_func { - let (setup, closure_variable) = self.lambda_with_setup(l); + let (setup, closure_variable) = self.lambda_with_setup(l, call.func); block_expressions.push(setup); extracted_func = closure_variable; } else { @@ -998,7 +999,11 @@ impl<'interner> Monomorphizer<'interner> { } } - fn lambda_with_setup(&mut self, lambda: HirLambda) -> (ast::Expression, ast::Expression) { + fn lambda_with_setup( + &mut self, + lambda: HirLambda, + expr: node_interner::ExprId, + ) -> (ast::Expression, ast::Expression) { // returns (, ) // which can be used directly in callsites or transformed // directly to a single `Expression` @@ -1047,7 +1052,12 @@ impl<'interner> Monomorphizer<'interner> { } } })); - let env_typ = env_tuple.type_of(); + let expr_type = self.interner.id_type(expr); + let env_typ = if let types::Type::Function(_, _, function_env_type) = expr_type { + Self::convert_type(&function_env_type) + } else { + unreachable!("expected a Function type for a Lambda node") + }; let env_let_stmt = ast::Expression::Let(ast::Let { id: env_local_id, @@ -1117,8 +1127,8 @@ impl<'interner> Monomorphizer<'interner> { // ast::Expression::Block(vec![block_let_stmt, closure_ident]) } - fn lambda(&mut self, lambda: HirLambda) -> ast::Expression { - let (setup, closure_variable) = self.lambda_with_setup(lambda); + fn lambda(&mut self, lambda: HirLambda, expr: node_interner::ExprId) -> ast::Expression { + let (setup, closure_variable) = self.lambda_with_setup(lambda, expr); ast::Expression::Block(vec![setup, closure_variable]) } From b2b964be2c57f4e9bc5e29a14e177acab63ed6e7 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Fri, 28 Jul 2023 17:59:59 +0300 Subject: [PATCH 08/26] test: move closure-related tests to test_data --- .../closures_mut_ref/Nargo.toml | 0 .../closures_mut_ref/Prover.toml | 0 .../closures_mut_ref/src/main.nr | 0 .../fibonacci_by_ref/Nargo.toml | 0 .../fibonacci_by_ref/Prover.toml | 0 .../fibonacci_by_ref/src/main.nr | 0 .../higher_order_fn_selector/Nargo.toml | 0 .../higher_order_fn_selector/src/main.nr | 0 .../higher_order_functions/Nargo.toml | 0 .../higher_order_functions/Prover.toml | 0 .../higher_order_functions/src/main.nr | 0 .../higher_order_functions/target/c.json | 0 .../higher_order_functions/target/main.json | 0 .../higher_order_functions/target/witness.tr | Bin .../inner_outer_cl/Nargo.toml | 0 .../inner_outer_cl/src/main.nr | 0 .../ret_fn_ret_cl/Nargo.toml | 0 .../ret_fn_ret_cl/Prover.toml | 0 .../ret_fn_ret_cl/src/main.nr | 0 19 files changed, 0 insertions(+), 0 deletions(-) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/closures_mut_ref/Nargo.toml (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/closures_mut_ref/Prover.toml (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/closures_mut_ref/src/main.nr (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/fibonacci_by_ref/Nargo.toml (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/fibonacci_by_ref/Prover.toml (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/fibonacci_by_ref/src/main.nr (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/higher_order_fn_selector/Nargo.toml (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/higher_order_fn_selector/src/main.nr (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/higher_order_functions/Nargo.toml (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/higher_order_functions/Prover.toml (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/higher_order_functions/src/main.nr (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/higher_order_functions/target/c.json (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/higher_order_functions/target/main.json (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/higher_order_functions/target/witness.tr (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/inner_outer_cl/Nargo.toml (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/inner_outer_cl/src/main.nr (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/ret_fn_ret_cl/Nargo.toml (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/ret_fn_ret_cl/Prover.toml (100%) rename crates/nargo_cli/tests/{test_data_ssa_refactor => test_data}/ret_fn_ret_cl/src/main.nr (100%) diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Nargo.toml b/crates/nargo_cli/tests/test_data/closures_mut_ref/Nargo.toml similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Nargo.toml rename to crates/nargo_cli/tests/test_data/closures_mut_ref/Nargo.toml diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Prover.toml b/crates/nargo_cli/tests/test_data/closures_mut_ref/Prover.toml similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/Prover.toml rename to crates/nargo_cli/tests/test_data/closures_mut_ref/Prover.toml diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/src/main.nr b/crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/closures_mut_ref/src/main.nr rename to crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Nargo.toml b/crates/nargo_cli/tests/test_data/fibonacci_by_ref/Nargo.toml similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Nargo.toml rename to crates/nargo_cli/tests/test_data/fibonacci_by_ref/Nargo.toml diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Prover.toml b/crates/nargo_cli/tests/test_data/fibonacci_by_ref/Prover.toml similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/Prover.toml rename to crates/nargo_cli/tests/test_data/fibonacci_by_ref/Prover.toml diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/src/main.nr b/crates/nargo_cli/tests/test_data/fibonacci_by_ref/src/main.nr similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/fibonacci_by_ref/src/main.nr rename to crates/nargo_cli/tests/test_data/fibonacci_by_ref/src/main.nr diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/Nargo.toml b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/Nargo.toml similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/Nargo.toml rename to crates/nargo_cli/tests/test_data/higher_order_fn_selector/Nargo.toml diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/src/main.nr b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_fn_selector/src/main.nr rename to crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/Nargo.toml b/crates/nargo_cli/tests/test_data/higher_order_functions/Nargo.toml similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/Nargo.toml rename to crates/nargo_cli/tests/test_data/higher_order_functions/Nargo.toml diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/Prover.toml b/crates/nargo_cli/tests/test_data/higher_order_functions/Prover.toml similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/Prover.toml rename to crates/nargo_cli/tests/test_data/higher_order_functions/Prover.toml diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/src/main.nr b/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/src/main.nr rename to crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/c.json b/crates/nargo_cli/tests/test_data/higher_order_functions/target/c.json similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/c.json rename to crates/nargo_cli/tests/test_data/higher_order_functions/target/c.json diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/main.json b/crates/nargo_cli/tests/test_data/higher_order_functions/target/main.json similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/main.json rename to crates/nargo_cli/tests/test_data/higher_order_functions/target/main.json diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/witness.tr b/crates/nargo_cli/tests/test_data/higher_order_functions/target/witness.tr similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/higher_order_functions/target/witness.tr rename to crates/nargo_cli/tests/test_data/higher_order_functions/target/witness.tr diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/inner_outer_cl/Nargo.toml b/crates/nargo_cli/tests/test_data/inner_outer_cl/Nargo.toml similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/inner_outer_cl/Nargo.toml rename to crates/nargo_cli/tests/test_data/inner_outer_cl/Nargo.toml diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/inner_outer_cl/src/main.nr b/crates/nargo_cli/tests/test_data/inner_outer_cl/src/main.nr similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/inner_outer_cl/src/main.nr rename to crates/nargo_cli/tests/test_data/inner_outer_cl/src/main.nr diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/ret_fn_ret_cl/Nargo.toml b/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Nargo.toml similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/ret_fn_ret_cl/Nargo.toml rename to crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Nargo.toml diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/ret_fn_ret_cl/Prover.toml b/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Prover.toml similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/ret_fn_ret_cl/Prover.toml rename to crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Prover.toml diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/ret_fn_ret_cl/src/main.nr b/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/src/main.nr similarity index 100% rename from crates/nargo_cli/tests/test_data_ssa_refactor/ret_fn_ret_cl/src/main.nr rename to crates/nargo_cli/tests/test_data/ret_fn_ret_cl/src/main.nr From 35121eebf2a508beb76b54e5a1e656c0dcaf81b7 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Fri, 28 Jul 2023 19:47:07 +0300 Subject: [PATCH 09/26] test: update closure rewrite test output --- crates/noirc_frontend/src/monomorphization/mod.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 058aaa73200..11f59538f30 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -1408,11 +1408,14 @@ mod tests { "#; let expected_rewrite = r#"fn main$f0() -> Field { - let closure$2 = { - let env$1 = (); - (env$l1, lambda$f1) + let closure$3 = { + let closure_variable$2 = { + let env$1 = (); + (env$l1, lambda$f1) + }; + closure_variable$l2 }; - closure$l2.1(closure$l2.0, 0) + closure$l3.1(closure$l3.0, 0) } fn lambda$f1(mut env$l1: (), x$l0: Field) -> Field { x$l0 From c0b2567630c10bb89b14fd26f255bbe8de322fc3 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 31 Jul 2023 13:24:47 +0300 Subject: [PATCH 10/26] chore: apply cargo fmt changes --- crates/noirc_frontend/src/hir/type_check/expr.rs | 2 +- crates/noirc_frontend/src/hir_def/types.rs | 5 ++++- crates/noirc_frontend/src/monomorphization/mod.rs | 11 ++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index c98d79cf950..126a4400985 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -913,7 +913,7 @@ impl<'interner> TypeChecker<'interner> { &mut self, function: Type, args: Vec<(Type, ExprId, Span)>, - span: Span + span: Span, ) -> Type { // Could do a single unification for the entire function type, but matching beforehand // lets us issue a more precise error on the individual argument that fails to type check. diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index d8c1acb4245..7b92d833bc0 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -1767,7 +1767,10 @@ fn convert_array_expression_to_slice( interner.push_expr_location(func, location.span, location.file); interner.push_expr_type(&call, target_type.clone()); - interner.push_expr_type(&func, Type::Function(vec![array_type], Box::new(target_type), Box::new(Type::Unit))); + interner.push_expr_type( + &func, + Type::Function(vec![array_type], Box::new(target_type), Box::new(Type::Unit)), + ); } impl BinaryTypeOperator { diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 11f59538f30..869870022b0 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -19,7 +19,7 @@ use crate::{ expr::*, function::{FuncMeta, Param, Parameters}, stmt::{HirAssignStatement, HirLValue, HirLetStatement, HirPattern, HirStatement}, - types + types, }, node_interner::{self, DefinitionKind, NodeInterner, StmtId}, token::Attribute, @@ -777,12 +777,9 @@ impl<'interner> Monomorphizer<'interner> { func = original_func.clone(); }; - let call = self.try_evaluate_call(&func, &return_type).unwrap_or(ast::Expression::Call(ast::Call { - func, - arguments, - return_type, - location, - })); + let call = self + .try_evaluate_call(&func, &return_type) + .unwrap_or(ast::Expression::Call(ast::Call { func, arguments, return_type, location })); if !block_expressions.is_empty() { block_expressions.push(call); From 45a3a52c963f54f0d5a8bbea5698581e72058e18 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 31 Jul 2023 14:23:19 +0300 Subject: [PATCH 11/26] test: capture some variables in some tests, fix warnings, add a TODO add a TODO about returning closures --- .../test_data/closures_mut_ref/src/main.nr | 7 ++-- .../test_data/fibonacci_by_ref/src/main.nr | 3 +- .../higher_order_fn_selector/src/main.nr | 8 ++--- .../test_data/inner_outer_cl/src/main.nr | 8 +++-- .../tests/test_data/ret_fn_ret_cl/src/main.nr | 32 ++++++++++++------- 5 files changed, 35 insertions(+), 23 deletions(-) diff --git a/crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr b/crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr index e02cfb6880d..ae990e004fd 100644 --- a/crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr +++ b/crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr @@ -1,13 +1,14 @@ use dep::std; fn main(mut x: Field) { - + let one = 1; let add1 = |z| { - *z = *z + 1; + *z = *z + one; }; + let two = 2; let add2 = |z| { - *z = *z + 2; + *z = *z + two; }; add1(&mut x); diff --git a/crates/nargo_cli/tests/test_data/fibonacci_by_ref/src/main.nr b/crates/nargo_cli/tests/test_data/fibonacci_by_ref/src/main.nr index 763c5165c0f..d972e795822 100644 --- a/crates/nargo_cli/tests/test_data/fibonacci_by_ref/src/main.nr +++ b/crates/nargo_cli/tests/test_data/fibonacci_by_ref/src/main.nr @@ -5,11 +5,10 @@ fn fib_fn(a: Field, b: Field, res: &mut Field) { fn main(mut prev: Field, mut cur: Field) { let mut fib = prev + cur; - for i in 1..10 { + for _ in 1..10 { prev = cur; cur = fib; fib_fn(prev, cur, &mut fib); assert(prev + cur == fib); } - } diff --git a/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr index 3b8ec51ba00..4eabe059be0 100644 --- a/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr +++ b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr @@ -1,3 +1,5 @@ +use dep::std; + fn f(x: &mut Field) -> Field { *x = *x - 1; 1 @@ -13,8 +15,6 @@ fn h(x: &mut Field) -> Field { 1 } -use dep::std; - fn selector(flag:&mut bool) -> fn(&mut Field) -> Field { //TODO: Can we have fn(&mut Field) -> () return type? let mut my_func = f; @@ -36,13 +36,13 @@ fn main() { let mut x: Field = 100; let returned_func = selector(&mut flag); - let status = returned_func(&mut x); + let _status = returned_func(&mut x); assert(x == 200); let mut y: Field = 100; let returned_func2 = selector(&mut flag); - let status2 = returned_func2(&mut y); + let _status2 = returned_func2(&mut y); assert(y == 300); diff --git a/crates/nargo_cli/tests/test_data/inner_outer_cl/src/main.nr b/crates/nargo_cli/tests/test_data/inner_outer_cl/src/main.nr index dcf97d709a9..ce847b56b93 100644 --- a/crates/nargo_cli/tests/test_data/inner_outer_cl/src/main.nr +++ b/crates/nargo_cli/tests/test_data/inner_outer_cl/src/main.nr @@ -1,10 +1,12 @@ fn main() { + let z1 = 0; + let z2 = 1; let cl_outer = |x| { let cl_inner = |y| { - x + y + x + y + z2 }; - cl_inner(1) + cl_inner(1) + z1 }; let result = cl_outer(1); - assert(result == 2); + assert(result == 3); } diff --git a/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/src/main.nr b/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/src/main.nr index 974c5321f64..d3a3346b541 100644 --- a/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/src/main.nr +++ b/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/src/main.nr @@ -1,29 +1,39 @@ use dep::std; fn f(x: Field) -> Field { - x + x + 1 } fn ret_fn() -> fn(Field) -> Field { - let y = 1; - let inner_closure = |z| -> Field{ - z + y - }; - std::println(inner_closure(1)); f } -fn ret_closure() -> fn(Field) -> Field { +// TODO: in the advanced implicitly generic function with closures branch +// which would support higher-order functions in a better way +// support returning closures: +// +// fn ret_closure() -> fn(Field) -> Field { +// let y = 1; +// let inner_closure = |z| -> Field{ +// z + y +// }; +// inner_closure +// } + +fn ret_lambda() -> fn(Field) -> Field { let cl = |z: Field| -> Field { - z + z + 1 }; cl } fn main(x : Field) { let result_fn = ret_fn(); - assert(result_fn(x) == x); // Works + assert(result_fn(x) == x + 1); + + // let result_closure = ret_closure(); + // assert(result_closure(x) == x + 1); - let result_cl = ret_closure(); - assert(result_cl(x) == x); + let result_lambda = ret_lambda(); + assert(result_lambda(x) == x + 1); } From 0233a1ad06379b60a6ff7a38eaea65ff469422ee Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Mon, 31 Jul 2023 14:35:27 +0300 Subject: [PATCH 12/26] test: add simplification of #1088 as a resolve test, enable another test --- .../src/hir/resolution/resolver.rs | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index 12525cb3ba8..dae01eee8f4 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -1815,6 +1815,7 @@ mod test { } } + #[test] fn resolve_basic_closure() { let src = r#" fn main(x : Field) -> pub Field { @@ -1829,6 +1830,29 @@ mod test { } } + #[test] + fn resolve_simplified_closure() { + // based on bug https://github.com/noir-lang/noir/issues/1088 + + let src = r#"fn do_closure(x: Field) -> Field { + let y = x; + let ret_capture = || { + y + }; + ret_capture() + } + + fn main(x: Field) { + assert(do_closure(x) == 100); + } + + "#; + let parsed_captures = get_program_captures(src); + let mut expected_captures = vec![]; + expected_captures.push(vec!["y".to_string()]); + assert_eq!(expected_captures, parsed_captures); + } + #[test] fn resolve_complex_closures() { let src = r#" From 1ab754f53ea72469d5d52cc822fa702265a93d27 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 1 Aug 2023 15:23:54 +0300 Subject: [PATCH 13/26] fix: fix unify for closures, fix display for fn/closure types --- crates/noirc_frontend/src/hir_def/types.rs | 36 ++++++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index 7b92d833bc0..6946d6bc596 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -801,8 +801,14 @@ impl std::fmt::Display for Type { write!(f, "forall {}. {}", typevars.join(" "), typ) } Type::Function(args, ret, env) => { - let args = vecmap(args, ToString::to_string); - write!(f, "fn({}) -> {} [{}]", args.join(", "), ret, env) + let (params_skip_count, closure_env_text) = match **env { + Type::Unit => (0, "".to_string()), + _ => (1, format!(" with closure environment {env}")), + }; + + let args = vecmap(args.iter().skip(params_skip_count), ToString::to_string); + + write!(f, "fn({}) -> {ret}{closure_env_text}", args.join(", ")) } Type::MutableReference(element) => { write!(f, "&mut {element}") @@ -1200,13 +1206,31 @@ impl Type { } (Function(params_a, ret_a, env_a), Function(params_b, ret_b, env_b)) => { - if params_a.len() == params_b.len() { - for (a, b) in params_a.iter().zip(params_b) { + let skip_params_count_a = match **env_a { + // non-closure function: + Type::Unit => 0, + // possibly a closure: so we transform the function to pass env as a first arg + // which means we should skip the first arg now in param checking + _ => 1, + }; + let real_fn_param_count_a = params_a.len() - skip_params_count_a; + + let skip_params_count_b = match **env_b { + Type::Unit => 0, + _ => 1, + }; + let real_fn_param_count_b = params_b.len() - skip_params_count_b; + + if real_fn_param_count_a == real_fn_param_count_b { + for (a, b) in params_a + .iter() + .skip(skip_params_count_a) + .zip(params_b.iter().skip(skip_params_count_b)) + { a.try_unify(b, span)?; } - ret_b.try_unify(ret_a, span)?; - env_a.try_unify(env_b, span) + ret_b.try_unify(ret_a, span) } else { Err(SpanKind::None) } From 8021212181b1dc3c6c29e33aed8d137c7e165808 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Tue, 1 Aug 2023 18:49:43 +0300 Subject: [PATCH 14/26] test: update closure tests after resolving mutable bug --- .../higher_order_functions/src/main.nr | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr b/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr index a6e328b09af..eeaa5cf9f64 100644 --- a/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr +++ b/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr @@ -18,27 +18,23 @@ fn main() -> pub Field { let g = || a; assert(g() == 42); - // TODO: enable this again after fixing #2054 - // https://github.com/noir-lang/noir/issues/2054 - // by @jfecher's PR https://github.com/noir-lang/noir/pull/2057 - - // Mutable variables cannot be captured, but you can - // copy them into immutable variables and capture those: - // let mut x = 2; - // x = x + 1; - // let z = x; + // When you copy mutable variables, + // the capture of the copies shouldn't change: + let mut x = 2; + x = x + 1; + let z = x; // Add extra mutations to ensure we can mutate x without the // captured z changing. - // x = x + 1; - // TODO: this behavior changed in the new ssa backend: - // now even z is changed, and it wasn't in the previous backend - // assert(z == 2); - // fails! - // decide what to do after opening an issue about the simpler - // variable alias case - // - // assert((|y| y + z)(1) == 4); + x = x + 1; + assert((|y| y + z)(1) == 4); + + // When you capture mutable variables, + // again, the captured variable doesn't change: + let closure_capturing_mutable = (|y| y + x); + assert(closure_capturing_mutable(1) == 5); + x += 1; + assert(closure_capturing_mutable(1) == 5); let ret = twice(add1, 3); From 7c041ce320221f431f3d36bdc3b1520bab69561a Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 2 Aug 2023 13:22:18 +0300 Subject: [PATCH 15/26] fix: address some review comments for closure PR: fixes/cleanup --- .../tests/test_data/fibonacci_by_ref/Nargo.toml | 6 ------ .../tests/test_data/fibonacci_by_ref/Prover.toml | 2 -- .../tests/test_data/fibonacci_by_ref/src/main.nr | 14 -------------- .../test_data/higher_order_functions/src/main.nr | 3 +++ .../noirc_frontend/src/hir/resolution/resolver.rs | 2 +- crates/noirc_frontend/src/hir_def/types.rs | 3 ++- crates/noirc_frontend/src/monomorphization/ast.rs | 7 +++++-- crates/noirc_frontend/src/node_interner.rs | 2 +- 8 files changed, 12 insertions(+), 27 deletions(-) delete mode 100644 crates/nargo_cli/tests/test_data/fibonacci_by_ref/Nargo.toml delete mode 100644 crates/nargo_cli/tests/test_data/fibonacci_by_ref/Prover.toml delete mode 100644 crates/nargo_cli/tests/test_data/fibonacci_by_ref/src/main.nr diff --git a/crates/nargo_cli/tests/test_data/fibonacci_by_ref/Nargo.toml b/crates/nargo_cli/tests/test_data/fibonacci_by_ref/Nargo.toml deleted file mode 100644 index f5d28236db2..00000000000 --- a/crates/nargo_cli/tests/test_data/fibonacci_by_ref/Nargo.toml +++ /dev/null @@ -1,6 +0,0 @@ -[package] -name = "fibonacci_by_ref" -authors = [""] -compiler_version = "0.8.0" - -[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/fibonacci_by_ref/Prover.toml b/crates/nargo_cli/tests/test_data/fibonacci_by_ref/Prover.toml deleted file mode 100644 index d594b02e17d..00000000000 --- a/crates/nargo_cli/tests/test_data/fibonacci_by_ref/Prover.toml +++ /dev/null @@ -1,2 +0,0 @@ -prev = "1" -cur = "2" diff --git a/crates/nargo_cli/tests/test_data/fibonacci_by_ref/src/main.nr b/crates/nargo_cli/tests/test_data/fibonacci_by_ref/src/main.nr deleted file mode 100644 index d972e795822..00000000000 --- a/crates/nargo_cli/tests/test_data/fibonacci_by_ref/src/main.nr +++ /dev/null @@ -1,14 +0,0 @@ -fn fib_fn(a: Field, b: Field, res: &mut Field) { - *res = a + b; -} - -fn main(mut prev: Field, mut cur: Field) { - - let mut fib = prev + cur; - for _ in 1..10 { - prev = cur; - cur = fib; - fib_fn(prev, cur, &mut fib); - assert(prev + cur == fib); - } -} diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr b/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr index eeaa5cf9f64..fefd23b7dbc 100644 --- a/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr +++ b/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr @@ -61,6 +61,9 @@ fn test_array_functions() { // but it worked with the original ssa backend // (before dropping it) // + // opened #2121 for it + // https://github.com/noir-lang/noir/issues/2121 + // let descending = myarray.sort_via(|a, b| a > b); // assert(descending == [3, 2, 1]); diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index dae01eee8f4..681c853899f 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -1906,7 +1906,7 @@ mod test { assert_eq!(expected_captures, parsed_captures); } - + #[test] fn resolve_fmt_strings() { let src = r#" diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index 6946d6bc596..ba9ff6e2d65 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -700,9 +700,10 @@ impl Type { Type::Tuple(fields) => { fields.iter().any(|field| field.contains_numeric_typevar(target_id)) } - Type::Function(parameters, return_type, _env) => { + Type::Function(parameters, return_type, env) => { parameters.iter().any(|parameter| parameter.contains_numeric_typevar(target_id)) || return_type.contains_numeric_typevar(target_id) + || env.contains_numeric_typevar(target_id) } Type::Struct(struct_type, generics) => { generics.iter().enumerate().any(|(i, generic)| { diff --git a/crates/noirc_frontend/src/monomorphization/ast.rs b/crates/noirc_frontend/src/monomorphization/ast.rs index 42396543d26..33c3bbebff4 100644 --- a/crates/noirc_frontend/src/monomorphization/ast.rs +++ b/crates/noirc_frontend/src/monomorphization/ast.rs @@ -106,7 +106,6 @@ pub struct Binary { pub struct Lambda { pub function: Ident, pub env: Ident, - pub typ: Type, // TODO: Perhaps this is not necessary } #[derive(Debug, Clone)] @@ -332,7 +331,11 @@ impl std::fmt::Display for Type { } Type::Function(args, ret, env) => { let args = vecmap(args, ToString::to_string); - write!(f, "fn({}) -> {} [{}]", args.join(", "), ret, env) + let closure_env_text = match **env { + Type::Unit => "".to_string(), + _ => format!(" with closure environment {env}"), + }; + write!(f, "fn({}) -> {}{}", args.join(", "), ret, closure_env_text) } Type::Slice(element) => write!(f, "[{element}"), Type::MutableReference(element) => write!(f, "&mut {element}"), diff --git a/crates/noirc_frontend/src/node_interner.rs b/crates/noirc_frontend/src/node_interner.rs index 7a1d53f4dbf..6b3d2757c14 100644 --- a/crates/noirc_frontend/src/node_interner.rs +++ b/crates/noirc_frontend/src/node_interner.rs @@ -683,6 +683,6 @@ fn get_type_method_key(typ: &Type) -> Option { | Type::Error | Type::NotConstant | Type::Struct(_, _) - | Type::FmtString(_, _) => None + | Type::FmtString(_, _) => None, } } From 241e152e6888ab29a11fe7a7b26e6bd56cfdbaeb Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 2 Aug 2023 13:48:14 +0300 Subject: [PATCH 16/26] refactor: cleanup, remove a line Co-authored-by: jfecher --- crates/noirc_frontend/src/monomorphization/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 869870022b0..549a52825d1 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -1121,7 +1121,6 @@ impl<'interner> Monomorphizer<'interner> { }); (block_let_stmt, closure_ident) - // ast::Expression::Block(vec![block_let_stmt, closure_ident]) } fn lambda(&mut self, lambda: HirLambda, expr: node_interner::ExprId) -> ast::Expression { From 7d7d274abe060d91a0c9bed5d4bdf1b023910538 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 2 Aug 2023 13:49:46 +0300 Subject: [PATCH 17/26] refactor: cleanup Co-authored-by: jfecher --- crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs index c89254f50f1..0c0dd35211b 100644 --- a/crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -86,8 +86,7 @@ impl<'a> FunctionContext<'a> { /// Codegen any non-tuple expression so that we can unwrap the Values /// tree to return a single value for use with most SSA instructions. fn codegen_non_tuple_expression(&mut self, expr: &Expression) -> ValueId { - let e = self.codegen_expression(expr); - e.into_leaf().eval(self) + self.codegen_expression(expr).into_leaf().eval(self) } /// Codegen a reference to an ident. From 22c9c828bd984e65e211f55733bd96f1b6d5f79c Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 2 Aug 2023 13:58:34 +0300 Subject: [PATCH 18/26] fix: fix bind_function_type env_type handling type variable binding --- crates/noirc_frontend/src/hir/type_check/expr.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 126a4400985..8905ac9f651 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -925,7 +925,8 @@ impl<'interner> TypeChecker<'interner> { let ret = self.interner.next_type_variable(); let args = vecmap(args, |(arg, _, _)| arg); - let expected = Type::Function(args, Box::new(ret.clone()), Box::new(Type::Unit)); + let env_type = self.interner.next_type_variable(); + let expected = Type::Function(args, Box::new(ret.clone()), Box::new(env_type)); if let Err(error) = binding.borrow_mut().bind_to(expected, span) { self.errors.push(error); From 81903d8c04293b98f01d91b29903dab243021655 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 2 Aug 2023 14:08:40 +0300 Subject: [PATCH 19/26] test: improve higher_order_fn_selector test --- .../higher_order_fn_selector/src/main.nr | 28 ++++++------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr index 4eabe059be0..767cff0c409 100644 --- a/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr +++ b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr @@ -1,28 +1,18 @@ use dep::std; -fn f(x: &mut Field) -> Field { - *x = *x - 1; - 1 -} - -fn g(x: &mut Field) -> Field { +fn g(x: &mut Field) -> () { *x *= 2; - 1 } -fn h(x: &mut Field) -> Field { +fn h(x: &mut Field) -> () { *x *= 3; - 1 } -fn selector(flag:&mut bool) -> fn(&mut Field) -> Field { //TODO: Can we have fn(&mut Field) -> () return type? - let mut my_func = f; - - if *flag { - my_func = g; - } - else { - my_func = h; +fn selector(flag: &mut bool) -> fn(&mut Field) -> () { + let my_func = if *flag { + g + } else { + h }; // Flip the flag for the next function call @@ -36,13 +26,13 @@ fn main() { let mut x: Field = 100; let returned_func = selector(&mut flag); - let _status = returned_func(&mut x); + returned_func(&mut x); assert(x == 200); let mut y: Field = 100; let returned_func2 = selector(&mut flag); - let _status2 = returned_func2(&mut y); + returned_func2(&mut y); assert(y == 300); From fb4e2cd817b8c70505b2181a4ce35ec08862636c Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 2 Aug 2023 15:24:44 +0300 Subject: [PATCH 20/26] fix: remove skip_params/additional param logic from typechecking/display --- .../noirc_frontend/src/hir/type_check/expr.rs | 35 ++++---------- crates/noirc_frontend/src/hir_def/types.rs | 46 ++++--------------- 2 files changed, 19 insertions(+), 62 deletions(-) diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 8905ac9f651..0cfad16915d 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -283,12 +283,11 @@ impl<'interner> TypeChecker<'interner> { vecmap(lambda.captures, |capture| self.interner.id_type(capture.ident.id)); let env_type = Type::Tuple(captured_vars); - let mut params = vec![env_type.clone()]; - for (pattern, typ) in lambda.parameters { + let params: Vec = vecmap(lambda.parameters, |(pattern, typ)| { self.bind_pattern(&pattern, typ.clone()); - params.push(typ); - } + typ + }); let actual_return = self.check_expression(&lambda.body); @@ -883,20 +882,17 @@ impl<'interner> TypeChecker<'interner> { fn_ret: &Type, callsite_args: &Vec<(Type, ExprId, Span)>, span: Span, - skip_params: usize, ) -> Type { - let real_fn_params_count = fn_params.len() - skip_params; - - if real_fn_params_count != callsite_args.len() { + if fn_params.len() != callsite_args.len() { self.errors.push(TypeCheckError::ParameterCountMismatch { - expected: real_fn_params_count, + expected: fn_params.len(), found: callsite_args.len(), span, }); return Type::Error; } - for (param, (arg, _, arg_span)) in fn_params.iter().skip(skip_params).zip(callsite_args) { + for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) { arg.make_subtype_of(param, *arg_span, &mut self.errors, || { TypeCheckError::TypeMismatch { expected_typ: param.to_string(), @@ -933,22 +929,9 @@ impl<'interner> TypeChecker<'interner> { } ret } - Type::Function(parameters, ret, env) => { - self.bind_function_type_impl( - parameters.as_ref(), - ret.as_ref(), - args.as_ref(), - span, - match *env { - Type::Unit => 0, - Type::Tuple(_) => { - 1 // closure env - } - _ => unreachable!( - "function env internal type should be either Unit or Tuple" - ), - }, - ) + Type::Function(parameters, ret, _env) => { + // ignoring env for subtype on purpose + self.bind_function_type_impl(parameters.as_ref(), ret.as_ref(), args.as_ref(), span) } Type::Error => Type::Error, found => { diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index ba9ff6e2d65..d77b8033ba1 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -802,12 +802,12 @@ impl std::fmt::Display for Type { write!(f, "forall {}. {}", typevars.join(" "), typ) } Type::Function(args, ret, env) => { - let (params_skip_count, closure_env_text) = match **env { - Type::Unit => (0, "".to_string()), - _ => (1, format!(" with closure environment {env}")), + let closure_env_text = match **env { + Type::Unit => "".to_string(), + _ => format!(" with closure environment {env}"), }; - let args = vecmap(args.iter().skip(params_skip_count), ToString::to_string); + let args = vecmap(args.iter(), ToString::to_string); write!(f, "fn({}) -> {ret}{closure_env_text}", args.join(", ")) } @@ -1206,28 +1206,9 @@ impl Type { } } - (Function(params_a, ret_a, env_a), Function(params_b, ret_b, env_b)) => { - let skip_params_count_a = match **env_a { - // non-closure function: - Type::Unit => 0, - // possibly a closure: so we transform the function to pass env as a first arg - // which means we should skip the first arg now in param checking - _ => 1, - }; - let real_fn_param_count_a = params_a.len() - skip_params_count_a; - - let skip_params_count_b = match **env_b { - Type::Unit => 0, - _ => 1, - }; - let real_fn_param_count_b = params_b.len() - skip_params_count_b; - - if real_fn_param_count_a == real_fn_param_count_b { - for (a, b) in params_a - .iter() - .skip(skip_params_count_a) - .zip(params_b.iter().skip(skip_params_count_b)) - { + (Function(params_a, ret_a, _env_a), Function(params_b, ret_b, _env_b)) => { + if params_a.len() == params_b.len() { + for (a, b) in params_a.iter().zip(params_b.iter()) { a.try_unify(b, span)?; } @@ -1432,16 +1413,9 @@ impl Type { } } - (Function(params_a, ret_a, env_a), Function(params_b, ret_b, _env_b)) => { - let skip_params = match *env_a.clone() { - Type::Unit => 0, - Type::Tuple(_) => { - 1 // closure env - } - _ => unreachable!("function env internal type should be either Unit or Tuple"), - }; - if params_a.len() - skip_params == params_b.len() { - for (a, b) in params_a.iter().skip(skip_params).zip(params_b) { + (Function(params_a, ret_a, _env_a), Function(params_b, ret_b, _env_b)) => { + if params_a.len() == params_b.len() { + for (a, b) in params_a.iter().zip(params_b) { a.is_subtype_of(b, span)?; } From e8bc164b19d83df49ac8df1cf4b76535fa398312 Mon Sep 17 00:00:00 2001 From: Alex Vitkov Date: Wed, 2 Aug 2023 14:08:05 +0300 Subject: [PATCH 21/26] fix: don't use closure capture logic for lambdas without captures --- .../noirc_frontend/src/hir/type_check/expr.rs | 6 ++- .../src/monomorphization/mod.rs | 47 +++++++++++++++++-- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 0cfad16915d..5f259e84fdf 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -282,7 +282,11 @@ impl<'interner> TypeChecker<'interner> { let captured_vars = vecmap(lambda.captures, |capture| self.interner.id_type(capture.ident.id)); - let env_type = Type::Tuple(captured_vars); + let env_type: Type = if captured_vars.is_empty() { + Type::Unit + } else { + Type::Tuple(captured_vars.clone()) + }; let params: Vec = vecmap(lambda.parameters, |(pattern, typ)| { self.bind_pattern(&pattern, typ.clone()); diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 549a52825d1..33b5543d2ec 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -996,6 +996,48 @@ impl<'interner> Monomorphizer<'interner> { } } + fn lambda(&mut self, lambda: HirLambda, expr: node_interner::ExprId) -> ast::Expression { + if lambda.captures.is_empty() { + self.lambda_no_capture(lambda) + } else { + let (setup, closure_variable) = self.lambda_with_setup(lambda, expr); + ast::Expression::Block(vec![setup, closure_variable]) + } + } + + fn lambda_no_capture(&mut self, lambda: HirLambda) -> ast::Expression { + let ret_type = Self::convert_type(&lambda.return_type); + let lambda_name = "lambda"; + let parameter_types = vecmap(&lambda.parameters, |(_, typ)| Self::convert_type(typ)); + + // Manually convert to Parameters type so we can reuse the self.parameters method + let parameters = Parameters(vecmap(lambda.parameters, |(pattern, typ)| { + Param(pattern, typ, noirc_abi::AbiVisibility::Private) + })); + + let parameters = self.parameters(parameters); + let body = self.expr(lambda.body); + + let id = self.next_function_id(); + let return_type = ret_type.clone(); + let name = lambda_name.to_owned(); + let unconstrained = false; + + let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; + self.push_function(id, function); + + let typ = ast::Type::Function(parameter_types, Box::new(ret_type), Box::new(ast::Type::Unit)); + + let name = lambda_name.to_owned(); + ast::Expression::Ident(ast::Ident { + definition: Definition::Function(id), + mutable: false, + location: None, + name, + typ, + }) + } + fn lambda_with_setup( &mut self, lambda: HirLambda, @@ -1123,11 +1165,6 @@ impl<'interner> Monomorphizer<'interner> { (block_let_stmt, closure_ident) } - fn lambda(&mut self, lambda: HirLambda, expr: node_interner::ExprId) -> ast::Expression { - let (setup, closure_variable) = self.lambda_with_setup(lambda, expr); - ast::Expression::Block(vec![setup, closure_variable]) - } - /// Implements std::unsafe::zeroed by returning an appropriate zeroed /// ast literal or collection node for the given type. Note that for functions /// there is no obvious zeroed value so this should be considered unsafe to use. From a5fde6166b7b6aa1efd1edee48e079be6fabb7f8 Mon Sep 17 00:00:00 2001 From: Alex Vitkov Date: Wed, 2 Aug 2023 16:02:18 +0300 Subject: [PATCH 22/26] fix: apply cargo fmt & clippy --- crates/noirc_frontend/src/hir/type_check/expr.rs | 2 +- crates/noirc_frontend/src/monomorphization/mod.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 5f259e84fdf..24db3d90b5c 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -285,7 +285,7 @@ impl<'interner> TypeChecker<'interner> { let env_type: Type = if captured_vars.is_empty() { Type::Unit } else { - Type::Tuple(captured_vars.clone()) + Type::Tuple(captured_vars) }; let params: Vec = vecmap(lambda.parameters, |(pattern, typ)| { diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 33b5543d2ec..332c727dfc6 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -1026,7 +1026,8 @@ impl<'interner> Monomorphizer<'interner> { let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; self.push_function(id, function); - let typ = ast::Type::Function(parameter_types, Box::new(ret_type), Box::new(ast::Type::Unit)); + let typ = + ast::Type::Function(parameter_types, Box::new(ret_type), Box::new(ast::Type::Unit)); let name = lambda_name.to_owned(); ast::Expression::Ident(ast::Ident { From fbbffd737e35ad76b4fb0395a9a239f5572a6954 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 2 Aug 2023 16:20:51 +0300 Subject: [PATCH 23/26] chore: apply cargo fmt --- crates/noirc_frontend/src/hir/type_check/expr.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 24db3d90b5c..86f5b5f0911 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -282,11 +282,8 @@ impl<'interner> TypeChecker<'interner> { let captured_vars = vecmap(lambda.captures, |capture| self.interner.id_type(capture.ident.id)); - let env_type: Type = if captured_vars.is_empty() { - Type::Unit - } else { - Type::Tuple(captured_vars) - }; + let env_type: Type = + if captured_vars.is_empty() { Type::Unit } else { Type::Tuple(captured_vars) }; let params: Vec = vecmap(lambda.parameters, |(pattern, typ)| { self.bind_pattern(&pattern, typ.clone()); From 5beb74728ca4131e1b03f34313b23078ae6b737e Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 2 Aug 2023 22:26:32 +0300 Subject: [PATCH 24/26] test: fix closure rewrite test: actually capture --- crates/noirc_frontend/src/monomorphization/mod.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 332c727dfc6..c8167baf6bb 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -1436,23 +1436,25 @@ mod tests { fn simple_closure_with_no_captured_variables() { let src = r#" fn main() -> Field { - let closure = |x| x; - closure(0) + let x = 1; + let closure = || x; + closure() } "#; let expected_rewrite = r#"fn main$f0() -> Field { + let x$0 = 1; let closure$3 = { let closure_variable$2 = { - let env$1 = (); + let env$1 = (x$l0); (env$l1, lambda$f1) }; closure_variable$l2 }; - closure$l3.1(closure$l3.0, 0) + closure$l3.1(closure$l3.0) } -fn lambda$f1(mut env$l1: (), x$l0: Field) -> Field { - x$l0 +fn lambda$f1(mut env$l1: (Field)) -> Field { + env$l1.0 } "#; check_rewrite(src, expected_rewrite); From 4a581a993ca6943fa88be641a8c5973518de65bc Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 2 Aug 2023 22:28:20 +0300 Subject: [PATCH 25/26] chore: remove type annotation for `params` --- crates/noirc_frontend/src/hir/type_check/expr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 86f5b5f0911..6c111a1d6a0 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -285,7 +285,7 @@ impl<'interner> TypeChecker<'interner> { let env_type: Type = if captured_vars.is_empty() { Type::Unit } else { Type::Tuple(captured_vars) }; - let params: Vec = vecmap(lambda.parameters, |(pattern, typ)| { + let params = vecmap(lambda.parameters, |(pattern, typ)| { self.bind_pattern(&pattern, typ.clone()); typ }); From 5cc34290dc94299230f02a420c92c6a977eed497 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 2 Aug 2023 22:30:09 +0300 Subject: [PATCH 26/26] chore: run cargo fmt --- crates/noirc_frontend/src/hir/def_collector/dc_crate.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs b/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs index 76fbea289be..2beebf6871c 100644 --- a/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -12,8 +12,8 @@ use crate::hir::type_check::{type_check_func, TypeChecker}; use crate::hir::Context; use crate::node_interner::{FuncId, NodeInterner, StmtId, StructId, TypeAliasId}; use crate::{ - ExpressionKind, Generics, Ident, LetStatement, NoirFunction, NoirStruct, NoirTypeAlias, - ParsedModule, Shared, Type, TypeBinding, UnresolvedGenerics, UnresolvedType, Literal, + ExpressionKind, Generics, Ident, LetStatement, Literal, NoirFunction, NoirStruct, + NoirTypeAlias, ParsedModule, Shared, Type, TypeBinding, UnresolvedGenerics, UnresolvedType, }; use fm::FileId; use iter_extended::vecmap;