Skip to content

Commit cb6b125

Browse files
committed
Unroll small Brillig loops
1 parent f991f44 commit cb6b125

File tree

1 file changed

+186
-7
lines changed

1 file changed

+186
-7
lines changed

compiler/noirc_evaluator/src/ssa/opt/unrolling.rs

+186-7
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,9 @@ impl Loops {
209209
let mut unroll_errors = vec![];
210210
while let Some(next_loop) = self.yet_to_unroll.pop() {
211211
if function.runtime().is_brillig() {
212-
// TODO (#6470): Decide whether to unroll this loop.
213-
continue;
212+
if !next_loop.is_small_loop(function, &self.cfg) {
213+
continue;
214+
}
214215
}
215216
// If we've previously modified a block in this loop we need to refresh the context.
216217
// This happens any time we have nested loops.
@@ -593,21 +594,54 @@ impl Loop {
593594
/// of unrolled instructions times the number of iterations would result in smaller bytecode
594595
/// than if we keep the loops with their overheads.
595596
fn is_small_loop(&self, function: &Function, cfg: &ControlFlowGraph) -> bool {
597+
self.boilerplate_stats(function, cfg).map(|s| s.is_small()).unwrap_or_default()
598+
}
599+
600+
/// Collect boilerplate stats if we can figure out the upper and lower bounds of the loop.
601+
fn boilerplate_stats(
602+
&self,
603+
function: &Function,
604+
cfg: &ControlFlowGraph,
605+
) -> Option<BoilerplateStats> {
596606
let Ok(Some((lower, upper))) = self.get_const_bounds(function, cfg) else {
597-
return false;
607+
return None;
598608
};
599609
let Some(lower) = lower.try_to_u64() else {
600-
return false;
610+
return None;
601611
};
602612
let Some(upper) = upper.try_to_u64() else {
603-
return false;
613+
return None;
604614
};
605-
let num_iterations = (upper - lower) as usize;
606615
let refs = self.find_pre_header_reference_values(function, cfg);
607616
let (loads, stores) = self.count_loads_and_stores(function, &refs);
608617
let all_instructions = self.count_all_instructions(function);
609618
let useful_instructions = all_instructions - loads - stores - LOOP_BOILERPLATE_COUNT;
610-
useful_instructions * num_iterations < all_instructions
619+
Some(BoilerplateStats {
620+
iterations: (upper - lower) as usize,
621+
loads,
622+
stores,
623+
all_instructions,
624+
useful_instructions,
625+
})
626+
}
627+
}
628+
629+
#[derive(Debug)]
630+
struct BoilerplateStats {
631+
iterations: usize,
632+
loads: usize,
633+
stores: usize,
634+
all_instructions: usize,
635+
useful_instructions: usize,
636+
}
637+
638+
impl BoilerplateStats {
639+
/// A small loop is where if we unroll it into the pre-header then considering the
640+
/// number of iterations we still end up with a smaller bytecode than if we leave
641+
/// the blocks intact with all the boilerplate involved in jumping, and the extra
642+
/// reference access instructions.
643+
fn is_small(&self) -> bool {
644+
self.useful_instructions * self.iterations < self.all_instructions
611645
}
612646
}
613647

@@ -1014,6 +1048,105 @@ mod tests {
10141048
assert!(loop0.is_small_loop(function, &loops.cfg));
10151049
}
10161050

1051+
#[test]
1052+
fn test_brillig_unroll_small_loop() {
1053+
let ssa = brillig_unroll_test_case();
1054+
1055+
// Example taken from an equivalent ACIR program (i.e. remove the `unconstrained`) and run
1056+
// `cargo run -q -p nargo_cli -- --program-dir . compile --show-ssa`
1057+
let expected = "
1058+
brillig(inline) fn main f0 {
1059+
b0(v0: u32):
1060+
v1 = allocate -> &mut u32
1061+
store u32 0 at v1
1062+
v3 = load v1 -> u32
1063+
store v3 at v1
1064+
v4 = load v1 -> u32
1065+
v6 = add v4, u32 1
1066+
store v6 at v1
1067+
v7 = load v1 -> u32
1068+
v9 = add v7, u32 2
1069+
store v9 at v1
1070+
v10 = load v1 -> u32
1071+
v12 = add v10, u32 3
1072+
store v12 at v1
1073+
jmp b1()
1074+
b1():
1075+
v13 = load v1 -> u32
1076+
v14 = eq v13, v0
1077+
constrain v13 == v0
1078+
return
1079+
}
1080+
";
1081+
1082+
let (ssa, errors) = ssa.try_unroll_loops();
1083+
assert_eq!(errors.len(), 0, "Unroll should have no errors");
1084+
assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled");
1085+
1086+
assert_normalized_ssa_equals(ssa, expected);
1087+
}
1088+
1089+
#[test]
1090+
fn test_brillig_unroll_6470_small() {
1091+
// Few enough iterations so that we can perform the unroll.
1092+
let ssa = brillig_unroll_test_case_6470(3);
1093+
let (ssa, errors) = ssa.try_unroll_loops();
1094+
assert_eq!(errors.len(), 0, "Unroll should have no errors");
1095+
assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled");
1096+
1097+
// The IDs are shifted by one compared to what the ACIR version printed.
1098+
let expected = "
1099+
brillig(inline) fn __validate_gt_remainder f0 {
1100+
b0(v0: [u64; 6]):
1101+
inc_rc v0
1102+
inc_rc [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64
1103+
v3 = allocate -> &mut [u64; 6]
1104+
store [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64 at v3
1105+
v5 = load v3 -> [u64; 6]
1106+
v7 = array_get v0, index u32 0 -> u64
1107+
v9 = add v7, u64 1
1108+
v10 = array_set v5, index u32 0, value v9
1109+
store v10 at v3
1110+
v11 = load v3 -> [u64; 6]
1111+
v13 = array_get v0, index u32 1 -> u64
1112+
v14 = add v13, u64 1
1113+
v15 = array_set v11, index u32 1, value v14
1114+
store v15 at v3
1115+
v16 = load v3 -> [u64; 6]
1116+
v18 = array_get v0, index u32 2 -> u64
1117+
v19 = add v18, u64 1
1118+
v20 = array_set v16, index u32 2, value v19
1119+
store v20 at v3
1120+
jmp b1()
1121+
b1():
1122+
v21 = load v3 -> [u64; 6]
1123+
dec_rc v0
1124+
return v21
1125+
}
1126+
";
1127+
assert_normalized_ssa_equals(ssa, expected);
1128+
}
1129+
1130+
#[test]
1131+
fn test_brillig_unroll_6470_large() {
1132+
// More iterations than it can unroll
1133+
let ssa = brillig_unroll_test_case_6470(6);
1134+
1135+
let function = ssa.main();
1136+
let mut loops = Loops::find_all(function);
1137+
let loop0 = loops.yet_to_unroll.pop().unwrap();
1138+
let stats = loop0.boilerplate_stats(function, &loops.cfg).unwrap();
1139+
assert_eq!(stats.is_small(), false);
1140+
1141+
let (ssa, errors) = ssa.try_unroll_loops();
1142+
assert_eq!(errors.len(), 0, "Unroll should have no errors");
1143+
assert_eq!(
1144+
ssa.main().reachable_blocks().len(),
1145+
4,
1146+
"The loop should be considered too costly to unroll"
1147+
);
1148+
}
1149+
10171150
/// Simple test loop:
10181151
/// ```text
10191152
/// unconstrained fn main(sum: u32) {
@@ -1054,4 +1187,50 @@ mod tests {
10541187
";
10551188
Ssa::from_str(src).unwrap()
10561189
}
1190+
1191+
/// Test case from #6470:
1192+
/// ```text
1193+
/// unconstrained fn __validate_gt_remainder(a_u60: [u64; 6]) -> [u64; 6] {
1194+
/// let mut result_u60: [u64; 6] = [0; 6];
1195+
///
1196+
/// for i in 0..6 {
1197+
/// result_u60[i] = a_u60[i] + 1;
1198+
/// }
1199+
///
1200+
/// result_u60
1201+
/// }
1202+
/// ```
1203+
/// The `num_iterations` parameter can be used to make it more costly to unroll.
1204+
fn brillig_unroll_test_case_6470(num_iterations: usize) -> Ssa {
1205+
let src = format!(
1206+
"
1207+
// After `static_assert` and `assert_constant`:
1208+
brillig(inline) fn __validate_gt_remainder f0 {{
1209+
b0(v0: [u64; 6]):
1210+
inc_rc v0
1211+
inc_rc [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64
1212+
v4 = allocate -> &mut [u64; 6]
1213+
store [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64 at v4
1214+
jmp b1(u32 0)
1215+
b1(v1: u32):
1216+
v7 = lt v1, u32 {num_iterations}
1217+
jmpif v7 then: b3, else: b2
1218+
b3():
1219+
v9 = load v4 -> [u64; 6]
1220+
v10 = array_get v0, index v1 -> u64
1221+
v12 = add v10, u64 1
1222+
v13 = array_set v9, index v1, value v12
1223+
v15 = add v1, u32 1
1224+
store v13 at v4
1225+
v16 = add v1, u32 1
1226+
jmp b1(v16)
1227+
b2():
1228+
v8 = load v4 -> [u64; 6]
1229+
dec_rc v0
1230+
return v8
1231+
}}
1232+
"
1233+
);
1234+
Ssa::from_str(&src).unwrap()
1235+
}
10571236
}

0 commit comments

Comments
 (0)