@@ -209,8 +209,9 @@ impl Loops {
        let mut unroll_errors = vec![];
        while let Some(next_loop) = self.yet_to_unroll.pop() {
            if function.runtime().is_brillig() {
-                // TODO (#6470): Decide whether to unroll this loop.
-                continue;
+                if !next_loop.is_small_loop(function, &self.cfg) {
+                    continue;
+                }
            }
            // If we've previously modified a block in this loop we need to refresh the context.
            // This happens any time we have nested loops.
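
A side note on the gate above (not part of the diff): ACIR has no real control flow, so its loops must always be fully unrolled, while Brillig supports jumps and can keep a loop, which makes unrolling a pure bytecode-size trade-off. Below is a minimal sketch of the resulting decision, with the runtime and the heuristic reduced to booleans; the `should_unroll` helper is hypothetical, introduced only for illustration:

```rust
/// Hypothetical distillation of the gate above: ACIR must always unroll,
/// while Brillig unrolls only loops that pass the size heuristic.
fn should_unroll(is_brillig: bool, is_small_loop: bool) -> bool {
    !is_brillig || is_small_loop
}

fn main() {
    assert!(should_unroll(false, false)); // ACIR: always unrolled
    assert!(should_unroll(true, true)); // Brillig: small loops are unrolled
    assert!(!should_unroll(true, false)); // Brillig: large loops are kept
}
```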
@@ -593,21 +594,54 @@ impl Loop {
    /// of unrolled instructions times the number of iterations would result in smaller bytecode
    /// than if we keep the loops with their overheads.
    fn is_small_loop(&self, function: &Function, cfg: &ControlFlowGraph) -> bool {
+        self.boilerplate_stats(function, cfg).map(|s| s.is_small()).unwrap_or_default()
+    }
+
+    /// Collect boilerplate stats if we can figure out the upper and lower bounds of the loop.
+    fn boilerplate_stats(
+        &self,
+        function: &Function,
+        cfg: &ControlFlowGraph,
+    ) -> Option<BoilerplateStats> {
        let Ok(Some((lower, upper))) = self.get_const_bounds(function, cfg) else {
-            return false;
+            return None;
        };
        let Some(lower) = lower.try_to_u64() else {
-            return false;
+            return None;
        };
        let Some(upper) = upper.try_to_u64() else {
-            return false;
+            return None;
        };
-        let num_iterations = (upper - lower) as usize;
        let refs = self.find_pre_header_reference_values(function, cfg);
        let (loads, stores) = self.count_loads_and_stores(function, &refs);
        let all_instructions = self.count_all_instructions(function);
        let useful_instructions = all_instructions - loads - stores - LOOP_BOILERPLATE_COUNT;
-        useful_instructions * num_iterations < all_instructions
+        Some(BoilerplateStats {
+            iterations: (upper - lower) as usize,
+            loads,
+            stores,
+            all_instructions,
+            useful_instructions,
+        })
+    }
+}
+
+#[derive(Debug)]
+struct BoilerplateStats {
+    iterations: usize,
+    loads: usize,
+    stores: usize,
+    all_instructions: usize,
+    useful_instructions: usize,
+}
+
+impl BoilerplateStats {
+    /// A loop is small if unrolling it into the pre-header, repeated for each
+    /// iteration, still yields less bytecode than keeping the blocks intact,
+    /// with all the jump boilerplate and the extra reference access instructions.
+    fn is_small(&self) -> bool {
+        self.useful_instructions * self.iterations < self.all_instructions
    }
}
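
To see the new heuristic in numbers, here is a self-contained sketch of the arithmetic behind `BoilerplateStats::is_small`; it is a trimmed mirror of the struct from the hunk above, and the instruction counts and the boilerplate constant of 5 are illustrative assumptions, not values read out of the compiler:

```rust
// Trimmed mirror of BoilerplateStats, fed with assumed numbers.
struct BoilerplateStats {
    iterations: usize,
    all_instructions: usize,
    useful_instructions: usize,
}

impl BoilerplateStats {
    // Unrolling repeats only the useful body once per iteration; keeping the
    // loop keeps every instruction, boilerplate and reference traffic included.
    fn is_small(&self) -> bool {
        self.useful_instructions * self.iterations < self.all_instructions
    }
}

fn main() {
    const LOOP_BOILERPLATE_COUNT: usize = 5; // assumed value, for illustration only
    // Hypothetical loop: 12 instructions in total, 2 loads, 2 stores, 3 iterations.
    let (all, loads, stores) = (12, 2, 2);
    let stats = BoilerplateStats {
        iterations: 3,
        all_instructions: all,
        useful_instructions: all - loads - stores - LOOP_BOILERPLATE_COUNT,
    };
    // useful = 12 - 2 - 2 - 5 = 3, and 3 * 3 = 9 < 12, so unrolling shrinks the code.
    assert!(stats.is_small());
    // With 6 iterations the same body would cost 3 * 6 = 18 > 12, so the loop stays.
}
```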
@@ -1014,6 +1048,105 @@ mod tests {
        assert!(loop0.is_small_loop(function, &loops.cfg));
    }

+    #[test]
+    fn test_brillig_unroll_small_loop() {
+        let ssa = brillig_unroll_test_case();
+
+        // Expected output taken from an equivalent ACIR program (i.e. with `unconstrained`
+        // removed), compiled with `cargo run -q -p nargo_cli -- --program-dir . compile --show-ssa`.
+        let expected = "
+        brillig(inline) fn main f0 {
+          b0(v0: u32):
+            v1 = allocate -> &mut u32
+            store u32 0 at v1
+            v3 = load v1 -> u32
+            store v3 at v1
+            v4 = load v1 -> u32
+            v6 = add v4, u32 1
+            store v6 at v1
+            v7 = load v1 -> u32
+            v9 = add v7, u32 2
+            store v9 at v1
+            v10 = load v1 -> u32
+            v12 = add v10, u32 3
+            store v12 at v1
+            jmp b1()
+          b1():
+            v13 = load v1 -> u32
+            v14 = eq v13, v0
+            constrain v13 == v0
+            return
+        }
+        ";
+
+        let (ssa, errors) = ssa.try_unroll_loops();
+        assert_eq!(errors.len(), 0, "Unroll should have no errors");
+        assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled");
+
+        assert_normalized_ssa_equals(ssa, expected);
+    }
+
+    #[test]
+    fn test_brillig_unroll_6470_small() {
+        // Few enough iterations that we can perform the unroll.
+        let ssa = brillig_unroll_test_case_6470(3);
+        let (ssa, errors) = ssa.try_unroll_loops();
+        assert_eq!(errors.len(), 0, "Unroll should have no errors");
+        assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled");
+
+        // The IDs are shifted by one compared to what the ACIR version printed.
+        let expected = "
+        brillig(inline) fn __validate_gt_remainder f0 {
+          b0(v0: [u64; 6]):
+            inc_rc v0
+            inc_rc [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64
+            v3 = allocate -> &mut [u64; 6]
+            store [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64 at v3
+            v5 = load v3 -> [u64; 6]
+            v7 = array_get v0, index u32 0 -> u64
+            v9 = add v7, u64 1
+            v10 = array_set v5, index u32 0, value v9
+            store v10 at v3
+            v11 = load v3 -> [u64; 6]
+            v13 = array_get v0, index u32 1 -> u64
+            v14 = add v13, u64 1
+            v15 = array_set v11, index u32 1, value v14
+            store v15 at v3
+            v16 = load v3 -> [u64; 6]
+            v18 = array_get v0, index u32 2 -> u64
+            v19 = add v18, u64 1
+            v20 = array_set v16, index u32 2, value v19
+            store v20 at v3
+            jmp b1()
+          b1():
+            v21 = load v3 -> [u64; 6]
+            dec_rc v0
+            return v21
+        }
+        ";
+        assert_normalized_ssa_equals(ssa, expected);
+    }
+
+    #[test]
+    fn test_brillig_unroll_6470_large() {
+        // More iterations than the heuristic is willing to unroll.
+        let ssa = brillig_unroll_test_case_6470(6);
+
+        let function = ssa.main();
+        let mut loops = Loops::find_all(function);
+        let loop0 = loops.yet_to_unroll.pop().unwrap();
+        let stats = loop0.boilerplate_stats(function, &loops.cfg).unwrap();
+        assert_eq!(stats.is_small(), false);
+
+        let (ssa, errors) = ssa.try_unroll_loops();
+        assert_eq!(errors.len(), 0, "Unroll should have no errors");
+        assert_eq!(
+            ssa.main().reachable_blocks().len(),
+            4,
+            "The loop should be considered too costly to unroll"
+        );
+    }
+
    /// Simple test loop:
    /// ```text
    /// unconstrained fn main(sum: u32) {
@@ -1054,4 +1187,50 @@ mod tests {
        ";
        Ssa::from_str(src).unwrap()
    }
+
+    /// Test case from #6470:
+    /// ```text
+    /// unconstrained fn __validate_gt_remainder(a_u60: [u64; 6]) -> [u64; 6] {
+    ///     let mut result_u60: [u64; 6] = [0; 6];
+    ///
+    ///     for i in 0..6 {
+    ///         result_u60[i] = a_u60[i] + 1;
+    ///     }
+    ///
+    ///     result_u60
+    /// }
+    /// ```
+    /// The `num_iterations` parameter can be used to make the loop more costly to unroll.
+    fn brillig_unroll_test_case_6470(num_iterations: usize) -> Ssa {
+        let src = format!(
+            "
+        // After `static_assert` and `assert_constant`:
+        brillig(inline) fn __validate_gt_remainder f0 {{
+          b0(v0: [u64; 6]):
+            inc_rc v0
+            inc_rc [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64
+            v4 = allocate -> &mut [u64; 6]
+            store [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64 at v4
+            jmp b1(u32 0)
+          b1(v1: u32):
+            v7 = lt v1, u32 {num_iterations}
+            jmpif v7 then: b3, else: b2
+          b3():
+            v9 = load v4 -> [u64; 6]
+            v10 = array_get v0, index v1 -> u64
+            v12 = add v10, u64 1
+            v13 = array_set v9, index v1, value v12
+            v15 = add v1, u32 1
+            store v13 at v4
+            v16 = add v1, u32 1
+            jmp b1(v16)
+          b2():
+            v8 = load v4 -> [u64; 6]
+            dec_rc v0
+            return v8
+        }}
+        "
+        );
+        Ssa::from_str(&src).unwrap()
+    }
}
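
One detail worth calling out in the helper above: because the SSA template goes through `format!`, literal braces in the function body have to be doubled as `{{` and `}}`, while `{num_iterations}` interpolates the parameter. A tiny demonstration of that escaping:

```rust
fn main() {
    let num_iterations = 3;
    // `{{` and `}}` render as literal braces; `{num_iterations}` is interpolated.
    let s = format!("b1 {{ lt v1, u32 {num_iterations} }}");
    assert_eq!(s, "b1 { lt v1, u32 3 }");
}
```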