Skip to content

Commit e5773e4

Browse files
authored
feat(ssa refactor): Implement first-class references (#1849)
* Explore work on references * Cleanup * Implement first-class references * Fix frontend test * Remove 'Mutability' struct, it is no longer needed * Remove some extra lines * Remove another function * Revert another line * Fix test again * Fix a bug in mem2reg for nested references * Fix inconsistent .eval during ssa-gen on assign statements * Revert some code * Add check for mutating immutable self objects
1 parent d0894ad commit e5773e4

File tree

27 files changed

+611
-103
lines changed

27 files changed

+611
-103
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[package]
2+
authors = [""]
3+
compiler_version = "0.5.1"
4+
5+
[dependencies]

crates/nargo_cli/tests/test_data_ssa_refactor/references/Prover.toml

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
fn main() {
2+
let mut x = 2;
3+
add1(&mut x);
4+
assert(x == 3);
5+
6+
let mut s = S { y: x };
7+
s.add2();
8+
assert(s.y == 5);
9+
10+
// Test that normal mutable variables are still copied
11+
let mut a = 0;
12+
mutate_copy(a);
13+
assert(a == 0);
14+
15+
// Test something 3 allocations deep
16+
let mut nested_allocations = Nested { y: &mut &mut 0 };
17+
add1(*nested_allocations.y);
18+
assert(**nested_allocations.y == 1);
19+
20+
// Test nested struct allocations with a mutable reference to an array.
21+
let mut c = C {
22+
foo: 0,
23+
bar: &mut C2 {
24+
array: &mut [1, 2],
25+
},
26+
};
27+
*c.bar.array = [3, 4];
28+
assert(*c.bar.array == [3, 4]);
29+
}
30+
31+
fn add1(x: &mut Field) {
32+
*x += 1;
33+
}
34+
35+
struct S { y: Field }
36+
37+
struct Nested { y: &mut &mut Field }
38+
39+
struct C {
40+
foo: Field,
41+
bar: &mut C2,
42+
}
43+
44+
struct C2 {
45+
array: &mut [Field; 2]
46+
}
47+
48+
impl S {
49+
fn add2(&mut self) {
50+
self.y += 2;
51+
}
52+
}
53+
54+
fn mutate_copy(mut a: Field) {
55+
a = 7;
56+
}

crates/nargo_cli/tests/test_data_ssa_refactor/tuples/src/main.nr

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ fn main(x: Field, y: Field) {
1919

2020
// Test mutating tuples
2121
let mut mutable = ((0, 0), 1, 2, 3);
22-
mutable.0 = pair;
22+
mutable.0 = (x, y);
2323
mutable.2 = 7;
2424
assert(mutable.0.0 == 1);
2525
assert(mutable.0.1 == 0);

crates/noirc_evaluator/src/ssa/context.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1207,6 +1207,7 @@ impl SsaContext {
12071207
}
12081208
}
12091209
Type::Array(..) => panic!("Cannot convert an array type {t} into an ObjectType since it is unknown which array it refers to"),
1210+
Type::MutableReference(..) => panic!("Mutable reference types are unimplemented in the old ssa backend"),
12101211
Type::Unit => ObjectType::NotAnObject,
12111212
Type::Function(..) => ObjectType::Function,
12121213
Type::Tuple(_) => todo!("Conversion to ObjectType is unimplemented for tuples"),

crates/noirc_evaluator/src/ssa/ssa_gen.rs

+10
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ impl IrGenerator {
208208
self.context.new_instruction(op, rhs_type)
209209
}
210210
UnaryOp::Not => self.context.new_instruction(Operation::Not(rhs), rhs_type),
211+
UnaryOp::MutableReference | UnaryOp::Dereference => {
212+
unimplemented!("Mutable references are unimplemented in the old ssa backend")
213+
}
211214
}
212215
}
213216

@@ -248,6 +251,9 @@ impl IrGenerator {
248251
let val = self.find_variable(ident_def).unwrap();
249252
val.get_field_member(*field_index)
250253
}
254+
LValue::Dereference { .. } => {
255+
unreachable!("Mutable references are unsupported in the old ssa backend")
256+
}
251257
}
252258
}
253259

@@ -256,6 +262,7 @@ impl IrGenerator {
256262
LValue::Ident(ident) => &ident.definition,
257263
LValue::Index { array, .. } => Self::lvalue_ident_def(array.as_ref()),
258264
LValue::MemberAccess { object, .. } => Self::lvalue_ident_def(object.as_ref()),
265+
LValue::Dereference { reference, .. } => Self::lvalue_ident_def(reference.as_ref()),
259266
}
260267
}
261268

@@ -462,6 +469,9 @@ impl IrGenerator {
462469
let value = val.get_field_member(*field_index).clone();
463470
self.assign_pattern(&value, rhs)?;
464471
}
472+
LValue::Dereference { .. } => {
473+
unreachable!("Mutable references are unsupported in the old ssa backend")
474+
}
465475
}
466476
Ok(Value::dummy())
467477
}

crates/noirc_evaluator/src/ssa/value.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ impl Value {
100100
| Type::String(..)
101101
| Type::Integer(..)
102102
| Type::Bool
103-
| Type::Field => Value::Node(*iter.next().unwrap()),
103+
| Type::Field
104+
| Type::MutableReference(_) => Value::Node(*iter.next().unwrap()),
104105
}
105106
}
106107

crates/noirc_evaluator/src/ssa_refactor/ir/function.rs

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub(crate) enum RuntimeType {
1313
// Unconstrained function, to be compiled to brillig and executed by the Brillig VM
1414
Brillig,
1515
}
16+
1617
/// A function holds a list of instructions.
1718
/// These instructions are further grouped into Basic blocks
1819
///

crates/noirc_evaluator/src/ssa_refactor/opt/mem2reg.rs

+28-12
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,39 @@ impl PerBlockContext {
6464
dfg: &mut DataFlowGraph,
6565
) -> HashSet<AllocId> {
6666
let mut protected_allocations = HashSet::new();
67-
let mut loads_to_substitute = HashMap::new();
6867
let block = &dfg[self.block_id];
6968

69+
// Maps Load instruction id -> value to replace the result of the load with
70+
let mut loads_to_substitute = HashMap::new();
71+
72+
// Maps Load result id -> value to replace the result of the load with
73+
let mut load_values_to_substitute = HashMap::new();
74+
7075
for instruction_id in block.instructions() {
7176
match &dfg[*instruction_id] {
72-
Instruction::Store { address, value } => {
73-
self.last_stores.insert(*address, *value);
77+
Instruction::Store { mut address, value } => {
78+
if let Some(value) = load_values_to_substitute.get(&address) {
79+
address = *value;
80+
}
81+
82+
self.last_stores.insert(address, *value);
7483
self.store_ids.push(*instruction_id);
7584
}
76-
Instruction::Load { address } => {
77-
if let Some(last_value) = self.last_stores.get(address) {
85+
Instruction::Load { mut address } => {
86+
if let Some(value) = load_values_to_substitute.get(&address) {
87+
address = *value;
88+
}
89+
90+
if let Some(last_value) = self.last_stores.get(&address) {
91+
let result_value = *dfg
92+
.instruction_results(*instruction_id)
93+
.first()
94+
.expect("ICE: Load instructions should have single result");
95+
7896
loads_to_substitute.insert(*instruction_id, *last_value);
97+
load_values_to_substitute.insert(result_value, *last_value);
7998
} else {
80-
protected_allocations.insert(*address);
99+
protected_allocations.insert(address);
81100
}
82101
}
83102
Instruction::Call { arguments, .. } => {
@@ -103,12 +122,9 @@ impl PerBlockContext {
103122
}
104123

105124
// Substitute load result values
106-
for (instruction_id, new_value) in &loads_to_substitute {
107-
let result_value = *dfg
108-
.instruction_results(*instruction_id)
109-
.first()
110-
.expect("ICE: Load instructions should have single result");
111-
dfg.set_value_from_id(result_value, *new_value);
125+
for (result_value, new_value) in load_values_to_substitute {
126+
let result_value = dfg.resolve(result_value);
127+
dfg.set_value_from_id(result_value, new_value);
112128
}
113129

114130
// Delete load instructions

crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs

+35-5
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,17 @@ impl<'a> FunctionContext<'a> {
165165

166166
// This helper is needed because we need to take f by mutable reference,
167167
// otherwise we cannot move it multiple times each loop of vecmap.
168-
fn map_type_helper<T>(typ: &ast::Type, f: &mut impl FnMut(Type) -> T) -> Tree<T> {
168+
fn map_type_helper<T>(typ: &ast::Type, f: &mut dyn FnMut(Type) -> T) -> Tree<T> {
169169
match typ {
170170
ast::Type::Tuple(fields) => {
171171
Tree::Branch(vecmap(fields, |field| Self::map_type_helper(field, f)))
172172
}
173173
ast::Type::Unit => Tree::empty(),
174+
// A mutable reference wraps each element into a reference.
175+
// This can be multiple values if the element type is a tuple.
176+
ast::Type::MutableReference(element) => {
177+
Self::map_type_helper(element, &mut |_| f(Type::Reference))
178+
}
174179
other => Tree::Leaf(f(Self::convert_non_tuple_type(other))),
175180
}
176181
}
@@ -201,6 +206,11 @@ impl<'a> FunctionContext<'a> {
201206
ast::Type::Unit => panic!("convert_non_tuple_type called on a unit type"),
202207
ast::Type::Tuple(_) => panic!("convert_non_tuple_type called on a tuple: {typ}"),
203208
ast::Type::Function(_, _) => Type::Function,
209+
ast::Type::MutableReference(element) => {
210+
// Recursive call to panic if element is a tuple
211+
Self::convert_non_tuple_type(element);
212+
Type::Reference
213+
}
204214

205215
// How should we represent Vecs?
206216
// Are they a struct of array + length + capacity?
@@ -473,9 +483,21 @@ impl<'a> FunctionContext<'a> {
473483
let object_lvalue = Box::new(object_lvalue);
474484
LValue::MemberAccess { old_object, object_lvalue, index: *field_index }
475485
}
486+
ast::LValue::Dereference { reference, .. } => {
487+
let (reference, _) = self.extract_current_value_recursive(reference);
488+
LValue::Dereference { reference }
489+
}
476490
}
477491
}
478492

493+
pub(super) fn dereference(&mut self, values: &Values, element_type: &ast::Type) -> Values {
494+
let element_types = Self::convert_type(element_type);
495+
values.map_both(element_types, |value, element_type| {
496+
let reference = value.eval(self);
497+
self.builder.insert_load(reference, element_type).into()
498+
})
499+
}
500+
479501
/// Compile the given identifier as a reference - ie. avoid calling .eval()
480502
fn ident_lvalue(&self, ident: &ast::Ident) -> Values {
481503
match &ident.definition {
@@ -516,16 +538,19 @@ impl<'a> FunctionContext<'a> {
516538
let element = Self::get_field_ref(&old_object, *index).clone();
517539
(element, LValue::MemberAccess { old_object, object_lvalue, index: *index })
518540
}
541+
ast::LValue::Dereference { reference, element_type } => {
542+
let (reference, _) = self.extract_current_value_recursive(reference);
543+
let dereferenced = self.dereference(&reference, element_type);
544+
(dereferenced, LValue::Dereference { reference })
545+
}
519546
}
520547
}
521548

522549
/// Assigns a new value to the given LValue.
523550
/// The LValue can be created via a previous call to extract_current_value.
524551
/// This method recurs on the given LValue to create a new value to assign an allocation
525-
/// instruction within an LValue::Ident - see the comment on `extract_current_value` for more
526-
/// details.
527-
/// When first-class references are supported the nearest reference may be in any LValue
528-
/// variant rather than just LValue::Ident.
552+
/// instruction within an LValue::Ident or LValue::Dereference - see the comment on
553+
/// `extract_current_value` for more details.
529554
pub(super) fn assign_new_value(&mut self, lvalue: LValue, new_value: Values) {
530555
match lvalue {
531556
LValue::Ident(references) => self.assign(references, new_value),
@@ -538,6 +563,9 @@ impl<'a> FunctionContext<'a> {
538563
let new_object = Self::replace_field(old_object, index, new_value);
539564
self.assign_new_value(*object_lvalue, new_object);
540565
}
566+
LValue::Dereference { reference } => {
567+
self.assign(reference, new_value);
568+
}
541569
}
542570
}
543571

@@ -705,8 +733,10 @@ impl SharedContext {
705733
}
706734

707735
/// Used to remember the results of each step of extracting a value from an ast::LValue
736+
#[derive(Debug)]
708737
pub(super) enum LValue {
709738
Ident(Values),
710739
Index { old_array: ValueId, index: ValueId, array_lvalue: Box<LValue> },
711740
MemberAccess { old_object: Values, index: usize, object_lvalue: Box<LValue> },
741+
Dereference { reference: Values },
712742
}

crates/noirc_evaluator/src/ssa_refactor/ssa_gen/mod.rs

+23-4
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ impl<'a> FunctionContext<'a> {
9999
/// Codegen for identifiers
100100
fn codegen_ident(&mut self, ident: &ast::Ident) -> Values {
101101
match &ident.definition {
102-
ast::Definition::Local(id) => self.lookup(*id).map(|value| value.eval(self).into()),
102+
ast::Definition::Local(id) => self.lookup(*id),
103103
ast::Definition::Function(id) => self.get_or_queue_function(*id),
104104
ast::Definition::Oracle(name) => self.builder.import_foreign_function(name).into(),
105105
ast::Definition::Builtin(name) | ast::Definition::LowLevel(name) => {
@@ -165,14 +165,33 @@ impl<'a> FunctionContext<'a> {
165165
}
166166

167167
fn codegen_unary(&mut self, unary: &ast::Unary) -> Values {
168-
let rhs = self.codegen_non_tuple_expression(&unary.rhs);
168+
let rhs = self.codegen_expression(&unary.rhs);
169169
match unary.operator {
170-
noirc_frontend::UnaryOp::Not => self.builder.insert_not(rhs).into(),
170+
noirc_frontend::UnaryOp::Not => {
171+
let rhs = rhs.into_leaf().eval(self);
172+
self.builder.insert_not(rhs).into()
173+
}
171174
noirc_frontend::UnaryOp::Minus => {
175+
let rhs = rhs.into_leaf().eval(self);
172176
let typ = self.builder.type_of_value(rhs);
173177
let zero = self.builder.numeric_constant(0u128, typ);
174178
self.builder.insert_binary(zero, BinaryOp::Sub, rhs).into()
175179
}
180+
noirc_frontend::UnaryOp::MutableReference => {
181+
rhs.map(|rhs| {
182+
match rhs {
183+
value::Value::Normal(value) => {
184+
let alloc = self.builder.insert_allocate();
185+
self.builder.insert_store(alloc, value);
186+
Tree::Leaf(value::Value::Normal(alloc))
187+
}
188+
// NOTE: The `.into()` here converts the Value::Mutable into
189+
// a Value::Normal so it is no longer automatically dereferenced.
190+
value::Value::Mutable(reference, _) => reference.into(),
191+
}
192+
})
193+
}
194+
noirc_frontend::UnaryOp::Dereference => self.dereference(&rhs, &unary.result_type),
176195
}
177196
}
178197

@@ -343,13 +362,13 @@ impl<'a> FunctionContext<'a> {
343362
/// Generate SSA for a function call. Note that calls to built-in functions
344363
/// and intrinsics are also represented by the function call instruction.
345364
fn codegen_call(&mut self, call: &ast::Call) -> Values {
365+
let function = self.codegen_non_tuple_expression(&call.func);
346366
let arguments = call
347367
.arguments
348368
.iter()
349369
.flat_map(|argument| self.codegen_expression(argument).into_value_list(self))
350370
.collect();
351371

352-
let function = self.codegen_non_tuple_expression(&call.func);
353372
self.insert_call(function, arguments, &call.return_type)
354373
}
355374

crates/noirc_evaluator/src/ssa_refactor/ssa_gen/value.rs

+30
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,36 @@ impl<T> Tree<T> {
123123
}
124124
}
125125

126+
/// Map two trees alongside each other.
127+
/// This asserts each tree has the same internal structure.
128+
pub(super) fn map_both<U, R>(
129+
&self,
130+
other: Tree<U>,
131+
mut f: impl FnMut(T, U) -> Tree<R>,
132+
) -> Tree<R>
133+
where
134+
T: std::fmt::Debug + Clone,
135+
U: std::fmt::Debug,
136+
{
137+
self.map_both_helper(other, &mut f)
138+
}
139+
140+
fn map_both_helper<U, R>(&self, other: Tree<U>, f: &mut impl FnMut(T, U) -> Tree<R>) -> Tree<R>
141+
where
142+
T: std::fmt::Debug + Clone,
143+
U: std::fmt::Debug,
144+
{
145+
match (self, other) {
146+
(Tree::Branch(self_trees), Tree::Branch(other_trees)) => {
147+
assert_eq!(self_trees.len(), other_trees.len());
148+
let trees = self_trees.iter().zip(other_trees);
149+
Tree::Branch(vecmap(trees, |(l, r)| l.map_both_helper(r, f)))
150+
}
151+
(Tree::Leaf(self_value), Tree::Leaf(other_value)) => f(self_value.clone(), other_value),
152+
other => panic!("Found unexpected tree combination during SSA: {other:?}"),
153+
}
154+
}
155+
126156
/// Unwraps this Tree into the value of the leaf node. Panics if
127157
/// this Tree is a Branch
128158
pub(super) fn into_leaf(self) -> T {

0 commit comments

Comments
 (0)