7
7
//! - Already marked as loop invariants
8
8
//!
9
9
//! We also check that we are not hoisting instructions with side effects.
10
- use fxhash:: FxHashSet as HashSet ;
10
+ use acvm:: { acir:: AcirField , FieldElement } ;
11
+ use fxhash:: { FxHashMap as HashMap , FxHashSet as HashSet } ;
11
12
12
13
use crate :: ssa:: {
13
14
ir:: {
14
15
basic_block:: BasicBlockId ,
15
16
function:: { Function , RuntimeType } ,
16
17
function_inserter:: FunctionInserter ,
17
- instruction:: InstructionId ,
18
+ instruction:: { Instruction , InstructionId } ,
19
+ types:: Type ,
18
20
value:: ValueId ,
19
21
} ,
20
22
Ssa ,
@@ -45,25 +47,51 @@ impl Function {
45
47
}
46
48
47
49
impl Loops {
48
- fn hoist_loop_invariants ( self , function : & mut Function ) {
50
+ fn hoist_loop_invariants ( mut self , function : & mut Function ) {
49
51
let mut context = LoopInvariantContext :: new ( function) ;
50
52
51
- for loop_ in self . yet_to_unroll . iter ( ) {
53
+ // The loops should be sorted by the number of blocks.
54
+ // We want to access outer nested loops first, which we do by popping
55
+ // from the top of the list.
56
+ while let Some ( loop_) = self . yet_to_unroll . pop ( ) {
52
57
let Ok ( pre_header) = loop_. get_pre_header ( context. inserter . function , & self . cfg ) else {
53
58
// If the loop does not have a preheader we skip hoisting loop invariants for this loop
54
59
continue ;
55
60
} ;
56
- context. hoist_loop_invariants ( loop_, pre_header) ;
61
+
62
+ context. hoist_loop_invariants ( & loop_, pre_header) ;
57
63
}
58
64
59
65
context. map_dependent_instructions ( ) ;
60
66
}
61
67
}
62
68
69
+ impl Loop {
70
+ /// Find the value that controls whether to perform a loop iteration.
71
+ /// This is going to be the block parameter of the loop header.
72
+ ///
73
+ /// Consider the following example of a `for i in 0..4` loop:
74
+ /// ```text
75
+ /// brillig(inline) fn main f0 {
76
+ /// b0(v0: u32):
77
+ /// ...
78
+ /// jmp b1(u32 0)
79
+ /// b1(v1: u32): // Loop header
80
+ /// v5 = lt v1, u32 4 // Upper bound
81
+ /// jmpif v5 then: b3, else: b2
82
+ /// ```
83
+ /// In the example above, `v1` is the induction variable
84
+ fn get_induction_variable ( & self , function : & Function ) -> ValueId {
85
+ function. dfg . block_parameters ( self . header ) [ 0 ]
86
+ }
87
+ }
88
+
63
89
struct LoopInvariantContext < ' f > {
64
90
inserter : FunctionInserter < ' f > ,
65
91
defined_in_loop : HashSet < ValueId > ,
66
92
loop_invariants : HashSet < ValueId > ,
93
+ // Maps induction variable -> fixed upper loop bound
94
+ outer_induction_variables : HashMap < ValueId , FieldElement > ,
67
95
}
68
96
69
97
impl < ' f > LoopInvariantContext < ' f > {
@@ -72,6 +100,7 @@ impl<'f> LoopInvariantContext<'f> {
72
100
inserter : FunctionInserter :: new ( function) ,
73
101
defined_in_loop : HashSet :: default ( ) ,
74
102
loop_invariants : HashSet :: default ( ) ,
103
+ outer_induction_variables : HashMap :: default ( ) ,
75
104
}
76
105
}
77
106
@@ -88,13 +117,29 @@ impl<'f> LoopInvariantContext<'f> {
88
117
self . inserter . push_instruction ( instruction_id, * block) ;
89
118
}
90
119
91
- self . update_values_defined_in_loop_and_invariants ( instruction_id, hoist_invariant) ;
120
+ self . extend_values_defined_in_loop_and_invariants ( instruction_id, hoist_invariant) ;
92
121
}
93
122
}
123
+
124
+ // Keep track of a loop induction variable and respective upper bound.
125
+ // This will be used by later loops to determine whether they have operations
126
+ // reliant upon the maximum induction variable.
127
+ let upper_bound = loop_. get_const_upper_bound ( self . inserter . function ) ;
128
+ if let Some ( upper_bound) = upper_bound {
129
+ let induction_variable = loop_. get_induction_variable ( self . inserter . function ) ;
130
+ let induction_variable = self . inserter . resolve ( induction_variable) ;
131
+ self . outer_induction_variables . insert ( induction_variable, upper_bound) ;
132
+ }
94
133
}
95
134
96
135
/// Gather the variables declared within the loop
97
136
fn set_values_defined_in_loop ( & mut self , loop_ : & Loop ) {
137
+ // Clear any values that may be defined in previous loops, as the context is per function.
138
+ self . defined_in_loop . clear ( ) ;
139
+ // These are safe to keep per function, but we want to be clear that these values
140
+ // are used per loop.
141
+ self . loop_invariants . clear ( ) ;
142
+
98
143
for block in loop_. blocks . iter ( ) {
99
144
let params = self . inserter . function . dfg . block_parameters ( * block) ;
100
145
self . defined_in_loop . extend ( params) ;
@@ -107,7 +152,7 @@ impl<'f> LoopInvariantContext<'f> {
107
152
108
153
/// Update any values defined in the loop and loop invariants after a
109
154
/// analyzing and re-inserting a loop's instruction.
110
- fn update_values_defined_in_loop_and_invariants (
155
+ fn extend_values_defined_in_loop_and_invariants (
111
156
& mut self ,
112
157
instruction_id : InstructionId ,
113
158
hoist_invariant : bool ,
@@ -143,9 +188,45 @@ impl<'f> LoopInvariantContext<'f> {
143
188
is_loop_invariant &=
144
189
!self . defined_in_loop . contains ( & value) || self . loop_invariants . contains ( & value) ;
145
190
} ) ;
146
- is_loop_invariant && instruction. can_be_deduplicated ( & self . inserter . function . dfg , false )
191
+
192
+ let can_be_deduplicated = instruction
193
+ . can_be_deduplicated ( & self . inserter . function . dfg , false )
194
+ || self . can_be_deduplicated_from_upper_bound ( & instruction) ;
195
+
196
+ is_loop_invariant && can_be_deduplicated
197
+ }
198
+
199
+ /// Certain instructions can take advantage of that our induction variable has a fixed maximum.
200
+ ///
201
+ /// For example, an array access can usually only be safely deduplicated when we have a constant
202
+ /// index that is below the length of the array.
203
+ /// Checking an array get where the index is the loop's induction variable on its own
204
+ /// would determine that the instruction is not safe for hoisting.
205
+ /// However, if we know that the induction variable's upper bound will always be in bounds of the array
206
+ /// we can safely hoist the array access.
207
+ fn can_be_deduplicated_from_upper_bound ( & self , instruction : & Instruction ) -> bool {
208
+ match instruction {
209
+ Instruction :: ArrayGet { array, index } => {
210
+ let array_typ = self . inserter . function . dfg . type_of_value ( * array) ;
211
+ let upper_bound = self . outer_induction_variables . get ( index) ;
212
+ if let ( Type :: Array ( _, len) , Some ( upper_bound) ) = ( array_typ, upper_bound) {
213
+ upper_bound. to_u128 ( ) as usize <= len
214
+ } else {
215
+ false
216
+ }
217
+ }
218
+ _ => false ,
219
+ }
147
220
}
148
221
222
+ /// Loop invariant hoisting only operates over loop instructions.
223
+ /// The `FunctionInserter` is used for mapping old values to new values after
224
+ /// re-inserting loop invariant instructions.
225
+ /// However, there may be instructions which are not within loops that are
226
+ /// still reliant upon the instruction results altered during the pass.
227
+ /// This method re-inserts all instructions so that all instructions have
228
+ /// correct new value IDs based upon the `FunctionInserter` internal map.
229
+ /// Leaving out this mapping could lead to instructions with values that do not exist.
149
230
fn map_dependent_instructions ( & mut self ) {
150
231
let blocks = self . inserter . function . reachable_blocks ( ) ;
151
232
for block in blocks {
@@ -375,4 +456,108 @@ mod test {
375
456
// The code should be unchanged
376
457
assert_normalized_ssa_equals ( ssa, src) ;
377
458
}
459
+
460
+ #[ test]
461
+ fn hoist_array_gets_using_induction_variable_with_const_bound ( ) {
462
+ // SSA for the following program:
463
+ //
464
+ // fn triple_loop(x: u32) {
465
+ // let arr = [2; 5];
466
+ // for i in 0..4 {
467
+ // for j in 0..4 {
468
+ // for _ in 0..4 {
469
+ // assert_eq(arr[i], x);
470
+ // assert_eq(arr[j], x);
471
+ // }
472
+ // }
473
+ // }
474
+ // }
475
+ //
476
+ // `arr[i]` and `arr[j]` are safe to hoist as we know the maximum possible index
477
+ // to be used for both array accesses.
478
+ // We want to make sure `arr[i]` is hoisted to the outermost loop body and that
479
+ // `arr[j]` is hoisted to the second outermost loop body.
480
+ let src = "
481
+ brillig(inline) fn main f0 {
482
+ b0(v0: u32, v1: u32):
483
+ v6 = make_array [u32 2, u32 2, u32 2, u32 2, u32 2] : [u32; 5]
484
+ inc_rc v6
485
+ jmp b1(u32 0)
486
+ b1(v2: u32):
487
+ v9 = lt v2, u32 4
488
+ jmpif v9 then: b3, else: b2
489
+ b3():
490
+ jmp b4(u32 0)
491
+ b4(v3: u32):
492
+ v10 = lt v3, u32 4
493
+ jmpif v10 then: b6, else: b5
494
+ b6():
495
+ jmp b7(u32 0)
496
+ b7(v4: u32):
497
+ v13 = lt v4, u32 4
498
+ jmpif v13 then: b9, else: b8
499
+ b9():
500
+ v15 = array_get v6, index v2 -> u32
501
+ v16 = eq v15, v0
502
+ constrain v15 == v0
503
+ v17 = array_get v6, index v3 -> u32
504
+ v18 = eq v17, v0
505
+ constrain v17 == v0
506
+ v19 = add v4, u32 1
507
+ jmp b7(v19)
508
+ b8():
509
+ v14 = add v3, u32 1
510
+ jmp b4(v14)
511
+ b5():
512
+ v12 = add v2, u32 1
513
+ jmp b1(v12)
514
+ b2():
515
+ return
516
+ }
517
+ " ;
518
+
519
+ let ssa = Ssa :: from_str ( src) . unwrap ( ) ;
520
+
521
+ let expected = "
522
+ brillig(inline) fn main f0 {
523
+ b0(v0: u32, v1: u32):
524
+ v6 = make_array [u32 2, u32 2, u32 2, u32 2, u32 2] : [u32; 5]
525
+ inc_rc v6
526
+ jmp b1(u32 0)
527
+ b1(v2: u32):
528
+ v9 = lt v2, u32 4
529
+ jmpif v9 then: b3, else: b2
530
+ b3():
531
+ v10 = array_get v6, index v2 -> u32
532
+ v11 = eq v10, v0
533
+ jmp b4(u32 0)
534
+ b4(v3: u32):
535
+ v12 = lt v3, u32 4
536
+ jmpif v12 then: b6, else: b5
537
+ b6():
538
+ v15 = array_get v6, index v3 -> u32
539
+ v16 = eq v15, v0
540
+ jmp b7(u32 0)
541
+ b7(v4: u32):
542
+ v17 = lt v4, u32 4
543
+ jmpif v17 then: b9, else: b8
544
+ b9():
545
+ constrain v10 == v0
546
+ constrain v15 == v0
547
+ v19 = add v4, u32 1
548
+ jmp b7(v19)
549
+ b8():
550
+ v18 = add v3, u32 1
551
+ jmp b4(v18)
552
+ b5():
553
+ v14 = add v2, u32 1
554
+ jmp b1(v14)
555
+ b2():
556
+ return
557
+ }
558
+ " ;
559
+
560
+ let ssa = ssa. loop_invariant_code_motion ( ) ;
561
+ assert_normalized_ssa_equals ( ssa, expected) ;
562
+ }
378
563
}
0 commit comments