Skip to content

Commit 7216f08

Browse files
authored
feat(ssa): Loop invariant code motion (#6563)
1 parent df8f2ee commit 7216f08

File tree

9 files changed

+415
-14
lines changed

9 files changed

+415
-14
lines changed

compiler/noirc_evaluator/src/ssa.rs

+1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ pub(crate) fn optimize_into_acir(
103103
Ssa::evaluate_static_assert_and_assert_constant,
104104
"After `static_assert` and `assert_constant`:",
105105
)?
106+
.run_pass(Ssa::loop_invariant_code_motion, "After Loop Invariant Code Motion:")
106107
.try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")?
107108
.run_pass(Ssa::simplify_cfg, "After Simplifying (2nd):")
108109
.run_pass(Ssa::flatten_cfg, "After Flattening:")

compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub(crate) struct FunctionInserter<'f> {
2525
///
2626
/// This is optional since caching arrays relies on the inserter inserting strictly
2727
/// in control-flow order. Otherwise, if arrays later in the program are cached first,
28-
/// they may be refered to by instructions earlier in the program.
28+
/// they may be referred to by instructions earlier in the program.
2929
array_cache: Option<ArrayCache>,
3030

3131
/// If this pass is loop unrolling, store the block before the loop to optionally
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,378 @@
1+
//! The loop invariant code motion pass moves code from inside a loop to before the loop
2+
//! if that code will always have the same result on every iteration of the loop.
3+
//!
4+
//! To identify a loop invariant, check whether all of an instruction's values are:
5+
//! - Outside of the loop
6+
//! - Constant
7+
//! - Already marked as loop invariants
8+
//!
9+
//! We also check that we are not hoisting instructions with side effects.
10+
use fxhash::FxHashSet as HashSet;
11+
12+
use crate::ssa::{
13+
ir::{
14+
basic_block::BasicBlockId,
15+
function::{Function, RuntimeType},
16+
function_inserter::FunctionInserter,
17+
instruction::InstructionId,
18+
value::ValueId,
19+
},
20+
Ssa,
21+
};
22+
23+
use super::unrolling::{Loop, Loops};
24+
25+
impl Ssa {
26+
#[tracing::instrument(level = "trace", skip(self))]
27+
pub(crate) fn loop_invariant_code_motion(mut self) -> Ssa {
28+
let brillig_functions = self
29+
.functions
30+
.iter_mut()
31+
.filter(|(_, func)| matches!(func.runtime(), RuntimeType::Brillig(_)));
32+
33+
for (_, function) in brillig_functions {
34+
function.loop_invariant_code_motion();
35+
}
36+
37+
self
38+
}
39+
}
40+
41+
impl Function {
42+
fn loop_invariant_code_motion(&mut self) {
43+
Loops::find_all(self).hoist_loop_invariants(self);
44+
}
45+
}
46+
47+
impl Loops {
48+
fn hoist_loop_invariants(self, function: &mut Function) {
49+
let mut context = LoopInvariantContext::new(function);
50+
51+
for loop_ in self.yet_to_unroll.iter() {
52+
let Ok(pre_header) = loop_.get_pre_header(context.inserter.function, &self.cfg) else {
53+
// If the loop does not have a preheader we skip hoisting loop invariants for this loop
54+
continue;
55+
};
56+
context.hoist_loop_invariants(loop_, pre_header);
57+
}
58+
59+
context.map_dependent_instructions();
60+
}
61+
}
62+
63+
struct LoopInvariantContext<'f> {
64+
inserter: FunctionInserter<'f>,
65+
defined_in_loop: HashSet<ValueId>,
66+
loop_invariants: HashSet<ValueId>,
67+
}
68+
69+
impl<'f> LoopInvariantContext<'f> {
70+
fn new(function: &'f mut Function) -> Self {
71+
Self {
72+
inserter: FunctionInserter::new(function),
73+
defined_in_loop: HashSet::default(),
74+
loop_invariants: HashSet::default(),
75+
}
76+
}
77+
78+
fn hoist_loop_invariants(&mut self, loop_: &Loop, pre_header: BasicBlockId) {
79+
self.set_values_defined_in_loop(loop_);
80+
81+
for block in loop_.blocks.iter() {
82+
for instruction_id in self.inserter.function.dfg[*block].take_instructions() {
83+
let hoist_invariant = self.can_hoist_invariant(instruction_id);
84+
85+
if hoist_invariant {
86+
self.inserter.push_instruction(instruction_id, pre_header);
87+
} else {
88+
self.inserter.push_instruction(instruction_id, *block);
89+
}
90+
91+
self.update_values_defined_in_loop_and_invariants(instruction_id, hoist_invariant);
92+
}
93+
}
94+
}
95+
96+
/// Gather the variables declared within the loop
97+
fn set_values_defined_in_loop(&mut self, loop_: &Loop) {
98+
for block in loop_.blocks.iter() {
99+
let params = self.inserter.function.dfg.block_parameters(*block);
100+
self.defined_in_loop.extend(params);
101+
for instruction_id in self.inserter.function.dfg[*block].instructions() {
102+
let results = self.inserter.function.dfg.instruction_results(*instruction_id);
103+
self.defined_in_loop.extend(results);
104+
}
105+
}
106+
}
107+
108+
/// Update any values defined in the loop and loop invariants after a
109+
/// analyzing and re-inserting a loop's instruction.
110+
fn update_values_defined_in_loop_and_invariants(
111+
&mut self,
112+
instruction_id: InstructionId,
113+
hoist_invariant: bool,
114+
) {
115+
let results = self.inserter.function.dfg.instruction_results(instruction_id).to_vec();
116+
// We will have new IDs after pushing instructions.
117+
// We should mark the resolved result IDs as also being defined within the loop.
118+
let results =
119+
results.into_iter().map(|value| self.inserter.resolve(value)).collect::<Vec<_>>();
120+
self.defined_in_loop.extend(results.iter());
121+
122+
// We also want the update result IDs when we are marking loop invariants as we may not
123+
// be going through the blocks of the loop in execution order
124+
if hoist_invariant {
125+
// Track already found loop invariants
126+
self.loop_invariants.extend(results.iter());
127+
}
128+
}
129+
130+
fn can_hoist_invariant(&mut self, instruction_id: InstructionId) -> bool {
131+
let mut is_loop_invariant = true;
132+
// The list of blocks for a nested loop contain any inner loops as well.
133+
// We may have already re-inserted new instructions if two loops share blocks
134+
// so we need to map all the values in the instruction which we want to check.
135+
let (instruction, _) = self.inserter.map_instruction(instruction_id);
136+
instruction.for_each_value(|value| {
137+
// If an instruction value is defined in the loop and not already a loop invariant
138+
// the instruction results are not loop invariants.
139+
//
140+
// We are implicitly checking whether the values are constant as well.
141+
// The set of values defined in the loop only contains instruction results and block parameters
142+
// which cannot be constants.
143+
is_loop_invariant &=
144+
!self.defined_in_loop.contains(&value) || self.loop_invariants.contains(&value);
145+
});
146+
is_loop_invariant && instruction.can_be_deduplicated(&self.inserter.function.dfg, false)
147+
}
148+
149+
fn map_dependent_instructions(&mut self) {
150+
let blocks = self.inserter.function.reachable_blocks();
151+
for block in blocks {
152+
for instruction_id in self.inserter.function.dfg[block].take_instructions() {
153+
self.inserter.push_instruction(instruction_id, block);
154+
}
155+
self.inserter.map_terminator_in_place(block);
156+
}
157+
}
158+
}
159+
160+
#[cfg(test)]
161+
mod test {
162+
use crate::ssa::opt::assert_normalized_ssa_equals;
163+
use crate::ssa::Ssa;
164+
165+
#[test]
166+
fn simple_loop_invariant_code_motion() {
167+
let src = "
168+
brillig(inline) fn main f0 {
169+
b0(v0: u32, v1: u32):
170+
jmp b1(u32 0)
171+
b1(v2: u32):
172+
v5 = lt v2, u32 4
173+
jmpif v5 then: b3, else: b2
174+
b3():
175+
v6 = mul v0, v1
176+
constrain v6 == u32 6
177+
v8 = add v2, u32 1
178+
jmp b1(v8)
179+
b2():
180+
return
181+
}
182+
";
183+
184+
let mut ssa = Ssa::from_str(src).unwrap();
185+
let main = ssa.main_mut();
186+
187+
let instructions = main.dfg[main.entry_block()].instructions();
188+
assert_eq!(instructions.len(), 0); // The final return is not counted
189+
190+
// `v6 = mul v0, v1` in b3 should now be `v3 = mul v0, v1` in b0
191+
let expected = "
192+
brillig(inline) fn main f0 {
193+
b0(v0: u32, v1: u32):
194+
v3 = mul v0, v1
195+
jmp b1(u32 0)
196+
b1(v2: u32):
197+
v6 = lt v2, u32 4
198+
jmpif v6 then: b3, else: b2
199+
b3():
200+
constrain v3 == u32 6
201+
v9 = add v2, u32 1
202+
jmp b1(v9)
203+
b2():
204+
return
205+
}
206+
";
207+
208+
let ssa = ssa.loop_invariant_code_motion();
209+
assert_normalized_ssa_equals(ssa, expected);
210+
}
211+
212+
#[test]
213+
fn nested_loop_invariant_code_motion() {
214+
// Check that a loop invariant in the inner loop of a nested loop
215+
// is hoisted to the parent loop's pre-header block.
216+
let src = "
217+
brillig(inline) fn main f0 {
218+
b0(v0: u32, v1: u32):
219+
jmp b1(u32 0)
220+
b1(v2: u32):
221+
v6 = lt v2, u32 4
222+
jmpif v6 then: b3, else: b2
223+
b3():
224+
jmp b4(u32 0)
225+
b4(v3: u32):
226+
v7 = lt v3, u32 4
227+
jmpif v7 then: b6, else: b5
228+
b6():
229+
v10 = mul v0, v1
230+
constrain v10 == u32 6
231+
v12 = add v3, u32 1
232+
jmp b4(v12)
233+
b5():
234+
v9 = add v2, u32 1
235+
jmp b1(v9)
236+
b2():
237+
return
238+
}
239+
";
240+
241+
let mut ssa = Ssa::from_str(src).unwrap();
242+
let main = ssa.main_mut();
243+
244+
let instructions = main.dfg[main.entry_block()].instructions();
245+
assert_eq!(instructions.len(), 0); // The final return is not counted
246+
247+
// `v10 = mul v0, v1` in b6 should now be `v4 = mul v0, v1` in b0
248+
let expected = "
249+
brillig(inline) fn main f0 {
250+
b0(v0: u32, v1: u32):
251+
v4 = mul v0, v1
252+
jmp b1(u32 0)
253+
b1(v2: u32):
254+
v7 = lt v2, u32 4
255+
jmpif v7 then: b3, else: b2
256+
b3():
257+
jmp b4(u32 0)
258+
b4(v3: u32):
259+
v8 = lt v3, u32 4
260+
jmpif v8 then: b6, else: b5
261+
b6():
262+
constrain v4 == u32 6
263+
v12 = add v3, u32 1
264+
jmp b4(v12)
265+
b5():
266+
v10 = add v2, u32 1
267+
jmp b1(v10)
268+
b2():
269+
return
270+
}
271+
";
272+
273+
let ssa = ssa.loop_invariant_code_motion();
274+
assert_normalized_ssa_equals(ssa, expected);
275+
}
276+
277+
#[test]
278+
fn hoist_invariant_with_invariant_as_argument() {
279+
// Check that an instruction which has arguments defined in the loop
280+
// but which are already marked loop invariants is still hoisted to the preheader.
281+
//
282+
// For example, in b3 we have the following instructions:
283+
// ```text
284+
// v6 = mul v0, v1
285+
// v7 = mul v6, v0
286+
// ```
287+
// `v6` should be marked a loop invariants as `v0` and `v1` are both declared outside of the loop.
288+
// As we will be hoisting `v6 = mul v0, v1` to the loop preheader we know that we can also
289+
// hoist `v7 = mul v6, v0`.
290+
let src = "
291+
brillig(inline) fn main f0 {
292+
b0(v0: u32, v1: u32):
293+
jmp b1(u32 0)
294+
b1(v2: u32):
295+
v5 = lt v2, u32 4
296+
jmpif v5 then: b3, else: b2
297+
b3():
298+
v6 = mul v0, v1
299+
v7 = mul v6, v0
300+
v8 = eq v7, u32 12
301+
constrain v7 == u32 12
302+
v9 = add v2, u32 1
303+
jmp b1(v9)
304+
b2():
305+
return
306+
}
307+
";
308+
309+
let mut ssa = Ssa::from_str(src).unwrap();
310+
let main = ssa.main_mut();
311+
312+
let instructions = main.dfg[main.entry_block()].instructions();
313+
assert_eq!(instructions.len(), 0); // The final return is not counted
314+
315+
let expected = "
316+
brillig(inline) fn main f0 {
317+
b0(v0: u32, v1: u32):
318+
v3 = mul v0, v1
319+
v4 = mul v3, v0
320+
v6 = eq v4, u32 12
321+
jmp b1(u32 0)
322+
b1(v2: u32):
323+
v9 = lt v2, u32 4
324+
jmpif v9 then: b3, else: b2
325+
b3():
326+
constrain v4 == u32 12
327+
v11 = add v2, u32 1
328+
jmp b1(v11)
329+
b2():
330+
return
331+
}
332+
";
333+
334+
let ssa = ssa.loop_invariant_code_motion();
335+
assert_normalized_ssa_equals(ssa, expected);
336+
}
337+
338+
#[test]
339+
fn do_not_hoist_instructions_with_side_effects() {
340+
// In `v12 = load v5` in `b3`, `v5` is defined outside the loop.
341+
// However, as the instruction has side effects, we want to make sure
342+
// we do not hoist the instruction to the loop preheader.
343+
let src = "
344+
brillig(inline) fn main f0 {
345+
b0(v0: u32, v1: u32):
346+
v4 = make_array [u32 0, u32 0, u32 0, u32 0, u32 0] : [u32; 5]
347+
inc_rc v4
348+
v5 = allocate -> &mut [u32; 5]
349+
store v4 at v5
350+
jmp b1(u32 0)
351+
b1(v2: u32):
352+
v7 = lt v2, u32 4
353+
jmpif v7 then: b3, else: b2
354+
b3():
355+
v12 = load v5 -> [u32; 5]
356+
v13 = array_set v12, index v0, value v1
357+
store v13 at v5
358+
v15 = add v2, u32 1
359+
jmp b1(v15)
360+
b2():
361+
v8 = load v5 -> [u32; 5]
362+
v10 = array_get v8, index u32 2 -> u32
363+
constrain v10 == u32 3
364+
return
365+
}
366+
";
367+
368+
let mut ssa = Ssa::from_str(src).unwrap();
369+
let main = ssa.main_mut();
370+
371+
let instructions = main.dfg[main.entry_block()].instructions();
372+
assert_eq!(instructions.len(), 4); // The final return is not counted
373+
374+
let ssa = ssa.loop_invariant_code_motion();
375+
// The code should be unchanged
376+
assert_normalized_ssa_equals(ssa, src);
377+
}
378+
}

0 commit comments

Comments
 (0)