Skip to content

Commit a94cdd2

Browse files
committed
feat: add syntax for explicitly specifying closure types
1 parent 2c5b35d commit a94cdd2

File tree

9 files changed

+158
-80
lines changed

9 files changed

+158
-80
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[package]
2+
name = "closure_explicit_types"
3+
type = "bin"
4+
authors = [""]
5+
compiler_version = "0.10.3"
6+
7+
[dependencies]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
2+
fn ret_normal_lambda1() -> fn() -> Field {
3+
|| 10
4+
}
5+
6+
// explicitly specified empty capture group
7+
fn ret_normal_lambda2() -> fn[]() -> Field {
8+
|| 20
9+
}
10+
11+
// return lamda that captures a thing
12+
fn ret_closure1() -> fn[Field]() -> Field {
13+
let x = 20;
14+
|| x + 10
15+
}
16+
17+
// return lamda that captures two things
18+
fn ret_closure2() -> fn[Field,Field]() -> Field {
19+
let x = 20;
20+
let y = 10;
21+
|| x + y + 10
22+
}
23+
24+
// return lamda that captures two things with different types
25+
fn ret_closure3() -> fn[u32,u64]() -> u64 {
26+
let x: u32 = 20;
27+
let y: u64 = 10;
28+
|| x as u64 + y + 10
29+
}
30+
31+
// accepts closure that has 1 thing in its env, calls it and returns the result
32+
fn accepts_closure1(f: fn[Field]() -> Field) -> Field {
33+
f()
34+
}
35+
36+
// accepts closure that has 1 thing in its env and returns it
37+
fn accepts_closure2(f: fn[Field]() -> Field) -> fn[Field]() -> Field {
38+
f
39+
}
40+
41+
// accepts closure with different types in the capture group
42+
fn accepts_closure3(f: fn[u32, u64]() -> u64) -> u64 {
43+
f()
44+
}
45+
46+
fn main() {
47+
assert(ret_normal_lambda1()() == 10);
48+
assert(ret_normal_lambda2()() == 20);
49+
assert(ret_closure1()() == 30);
50+
assert(ret_closure2()() == 40);
51+
assert(ret_closure3()() == 40);
52+
53+
let x = 50;
54+
assert(accepts_closure1(|| x) == 50);
55+
assert(accepts_closure2(|| x + 10)() == 60);
56+
57+
let y: u32 = 30;
58+
let z: u64 = 40;
59+
assert(accepts_closure3(|| y as u64 + z) == 70);
60+
}

crates/noirc_frontend/src/ast/mod.rs

+17-3
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ pub enum UnresolvedType {
5050
// Note: Tuples have no visibility, instead each of their elements may have one.
5151
Tuple(Vec<UnresolvedType>),
5252

53-
Function(/*args:*/ Vec<UnresolvedType>, /*ret:*/ Box<UnresolvedType>),
53+
Function(
54+
/*args:*/ Vec<UnresolvedType>,
55+
/*ret:*/ Box<UnresolvedType>,
56+
/*env:*/ Box<UnresolvedType>,
57+
),
5458

5559
Unspecified, // This is for when the user declares a variable without specifying it's type
5660
Error,
@@ -109,9 +113,19 @@ impl std::fmt::Display for UnresolvedType {
109113
Some(len) => write!(f, "str<{len}>"),
110114
},
111115
FormatString(len, elements) => write!(f, "fmt<{len}, {elements}"),
112-
Function(args, ret) => {
116+
Function(args, ret, env) => {
113117
let args = vecmap(args, ToString::to_string);
114-
write!(f, "fn({}) -> {ret}", args.join(", "))
118+
119+
match &**env {
120+
UnresolvedType::Unit => {
121+
write!(f, "fn({}) -> {ret}", args.join(", "))
122+
}
123+
UnresolvedType::Tuple(env_types) => {
124+
let env_types = vecmap(env_types, ToString::to_string);
125+
write!(f, "fn[{}]({}) -> {ret}", env_types.join(", "), args.join(", "))
126+
}
127+
_ => unreachable!(),
128+
}
115129
}
116130
MutableReference(element) => write!(f, "&mut {element}"),
117131
Unit => write!(f, "()"),

crates/noirc_frontend/src/hir/resolution/resolver.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -361,10 +361,10 @@ impl<'a> Resolver<'a> {
361361
UnresolvedType::Tuple(fields) => {
362362
Type::Tuple(vecmap(fields, |field| self.resolve_type_inner(field, new_variables)))
363363
}
364-
UnresolvedType::Function(args, ret) => {
364+
UnresolvedType::Function(args, ret, env) => {
365365
let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables));
366366
let ret = Box::new(self.resolve_type_inner(*ret, new_variables));
367-
let env = Box::new(Type::Unit);
367+
let env = Box::new(self.resolve_type_inner(*env, new_variables));
368368
Type::Function(args, ret, env)
369369
}
370370
UnresolvedType::MutableReference(element) => {

crates/noirc_frontend/src/hir/type_check/expr.rs

+5-7
Original file line numberDiff line numberDiff line change
@@ -837,13 +837,11 @@ impl<'interner> TypeChecker<'interner> {
837837
}
838838

839839
for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) {
840-
if arg.try_unify_allow_incompat_lambdas(param).is_err() {
841-
self.errors.push(TypeCheckError::TypeMismatch {
842-
expected_typ: param.to_string(),
843-
expr_typ: arg.to_string(),
844-
expr_span: *arg_span,
845-
});
846-
}
840+
self.unify(arg, param, || TypeCheckError::TypeMismatch {
841+
expected_typ: param.to_string(),
842+
expr_typ: arg.to_string(),
843+
expr_span: *arg_span,
844+
});
847845
}
848846

849847
fn_ret.clone()

crates/noirc_frontend/src/hir/type_check/mod.rs

+20-25
Original file line numberDiff line numberDiff line change
@@ -63,33 +63,28 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
6363
let (expr_span, empty_function) = function_info(interner, function_body_id);
6464

6565
let func_span = interner.expr_span(function_body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet
66+
function_last_type.unify_with_coercions(
67+
&declared_return_type,
68+
*function_body_id,
69+
interner,
70+
&mut errors,
71+
|| {
72+
let mut error = TypeCheckError::TypeMismatchWithSource {
73+
expected: declared_return_type.clone(),
74+
actual: function_last_type.clone(),
75+
span: func_span,
76+
source: Source::Return(meta.return_type, expr_span),
77+
};
6678

67-
let result = function_last_type.try_unify_allow_incompat_lambdas(&declared_return_type);
68-
69-
if result.is_err() {
70-
function_last_type.unify_with_coercions(
71-
&declared_return_type,
72-
*function_body_id,
73-
interner,
74-
&mut errors,
75-
|| {
76-
let mut error = TypeCheckError::TypeMismatchWithSource {
77-
expected: declared_return_type.clone(),
78-
actual: function_last_type.clone(),
79-
span: func_span,
80-
source: Source::Return(meta.return_type, expr_span),
81-
};
82-
83-
if empty_function {
84-
error = error.add_context(
85-
"implicitly returns `()` as its body has no tail or `return` expression",
86-
);
87-
}
79+
if empty_function {
80+
error = error.add_context(
81+
"implicitly returns `()` as its body has no tail or `return` expression",
82+
);
83+
}
8884

89-
error
90-
},
91-
);
92-
}
85+
error
86+
},
87+
);
9388
}
9489

9590
errors

crates/noirc_frontend/src/hir_def/types.rs

-29
Original file line numberDiff line numberDiff line change
@@ -947,35 +947,6 @@ impl Type {
947947
}
948948
}
949949

950-
/// Similar to try_unify() but allows non-matching capture groups for function types
951-
pub fn try_unify_allow_incompat_lambdas(&self, other: &Type) -> Result<(), UnificationError> {
952-
use Type::*;
953-
use TypeVariableKind::*;
954-
955-
match (self, other) {
956-
(TypeVariable(binding, Normal), other) | (other, TypeVariable(binding, Normal)) => {
957-
if let TypeBinding::Bound(link) = &*binding.borrow() {
958-
return link.try_unify_allow_incompat_lambdas(other);
959-
}
960-
961-
other.try_bind_to(binding)
962-
}
963-
(Function(params_a, ret_a, _), Function(params_b, ret_b, _)) => {
964-
if params_a.len() == params_b.len() {
965-
for (a, b) in params_a.iter().zip(params_b.iter()) {
966-
a.try_unify_allow_incompat_lambdas(b)?;
967-
}
968-
969-
// no check for environments here!
970-
ret_b.try_unify_allow_incompat_lambdas(ret_a)
971-
} else {
972-
Err(UnificationError)
973-
}
974-
}
975-
_ => self.try_unify(other),
976-
}
977-
}
978-
979950
/// Similar to `unify` but if the check fails this will attempt to coerce the
980951
/// argument to the target type. When this happens, the given expression is wrapped in
981952
/// a new expression to convert its type. E.g. `array` -> `array.as_slice()`

crates/noirc_frontend/src/monomorphization/mod.rs

+26-11
Original file line numberDiff line numberDiff line change
@@ -784,15 +784,27 @@ impl<'interner> Monomorphizer<'interner> {
784784

785785
let is_closure = self.is_function_closure(call.func);
786786
if is_closure {
787-
let extracted_func: ast::Expression;
788-
let hir_call_func = self.interner.expression(&call.func);
789-
if let HirExpression::Lambda(l) = hir_call_func {
790-
let (setup, closure_variable) = self.lambda_with_setup(l, call.func);
791-
block_expressions.push(setup);
792-
extracted_func = closure_variable;
793-
} else {
794-
extracted_func = *original_func;
795-
}
787+
let local_id = self.next_local_id();
788+
789+
// store the function in a temporary variable before calling it
790+
// this is needed for example if call.func is of the form `foo()()`
791+
// without this, we would translate it to `foo().1(foo().0)`
792+
let let_stmt = ast::Expression::Let(ast::Let {
793+
id: local_id,
794+
mutable: false,
795+
name: "tmp".to_string(),
796+
expression: Box::new(*original_func),
797+
});
798+
block_expressions.push(let_stmt);
799+
800+
let extracted_func = ast::Expression::Ident(ast::Ident {
801+
location: None,
802+
definition: Definition::Local(local_id),
803+
mutable: false,
804+
name: "tmp".to_string(),
805+
typ: Self::convert_type(&self.interner.id_type(call.func)),
806+
});
807+
796808
func = Box::new(ast::Expression::ExtractTupleField(
797809
Box::new(extracted_func.clone()),
798810
1usize,
@@ -1435,7 +1447,7 @@ mod tests {
14351447
#[test]
14361448
fn simple_closure_with_no_captured_variables() {
14371449
let src = r#"
1438-
fn main() -> Field {
1450+
fn main() -> pub Field {
14391451
let x = 1;
14401452
let closure = || x;
14411453
closure()
@@ -1451,7 +1463,10 @@ mod tests {
14511463
};
14521464
closure_variable$l2
14531465
};
1454-
closure$l3.1(closure$l3.0)
1466+
{
1467+
let tmp$4 = closure$l3;
1468+
tmp$l4.1(tmp$l4.0)
1469+
}
14551470
}
14561471
fn lambda$f1(mut env$l1: (Field)) -> Field {
14571472
env$l1.0

crates/noirc_frontend/src/parser/parser.rs

+21-3
Original file line numberDiff line numberDiff line change
@@ -971,12 +971,30 @@ fn function_type<T>(type_parser: T) -> impl NoirParser<UnresolvedType>
971971
where
972972
T: NoirParser<UnresolvedType>,
973973
{
974-
let args = parenthesized(type_parser.clone().separated_by(just(Token::Comma)).allow_trailing());
974+
let types = type_parser.clone().separated_by(just(Token::Comma)).allow_trailing();
975+
let args = parenthesized(types.clone());
976+
977+
let env = just(Token::LeftBracket)
978+
.ignore_then(types)
979+
.then_ignore(just(Token::RightBracket))
980+
.or_not()
981+
.map(|args| match args {
982+
Some(args) => {
983+
if args.is_empty() {
984+
UnresolvedType::Unit
985+
} else {
986+
UnresolvedType::Tuple(args)
987+
}
988+
}
989+
None => UnresolvedType::Unit,
990+
});
991+
975992
keyword(Keyword::Fn)
976-
.ignore_then(args)
993+
.ignore_then(env)
994+
.then(args)
977995
.then_ignore(just(Token::Arrow))
978996
.then(type_parser)
979-
.map(|(args, ret)| UnresolvedType::Function(args, Box::new(ret)))
997+
.map(|((env, args), ret)| UnresolvedType::Function(args, Box::new(ret), Box::new(env)))
980998
}
981999

9821000
fn mutable_reference_type<T>(type_parser: T) -> impl NoirParser<UnresolvedType>

0 commit comments

Comments
 (0)