1
- use std :: collections :: { HashMap , HashSet } ;
1
+ use fxhash :: { FxHashMap as HashMap , FxHashSet as HashSet } ;
2
2
3
3
use crate :: ssa:: {
4
4
ir:: {
@@ -29,16 +29,27 @@ impl Ssa {
29
29
}
30
30
}
31
31
32
- #[ derive( Default ) ]
33
- struct Context {
32
+ struct Context < ' f > {
33
+ function : & ' f Function ,
34
+
35
+ last_block : BasicBlockId ,
34
36
// All inc_rc instructions encountered without a corresponding dec_rc.
35
- // These are only searched for in the first block of a function.
37
+ // These are only searched for in the first and exit block of a function.
36
38
//
37
39
// The type of the array being operated on is recorded.
38
40
// If an array_set to that array type is encountered, that is also recorded.
39
41
inc_rcs : HashMap < Type , Vec < IncRc > > ,
40
42
}
41
43
44
+ impl < ' f > Context < ' f > {
45
+ fn new ( function : & ' f Function ) -> Self {
46
+ let last_block = Self :: find_last_block ( function) ;
47
+ // let all_block_params =
48
+ Context { function, last_block, inc_rcs : HashMap :: default ( ) }
49
+ }
50
+ }
51
+
52
+ #[ derive( Clone , Debug ) ]
42
53
struct IncRc {
43
54
id : InstructionId ,
44
55
array : ValueId ,
@@ -59,11 +70,11 @@ fn remove_paired_rc(function: &mut Function) {
59
70
return ;
60
71
}
61
72
62
- let mut context = Context :: default ( ) ;
73
+ let mut context = Context :: new ( function ) ;
63
74
64
- context. find_rcs_in_entry_block ( function ) ;
65
- context. scan_for_array_sets ( function ) ;
66
- let to_remove = context. find_rcs_to_remove ( function ) ;
75
+ context. find_rcs_in_entry_and_exit_block ( ) ;
76
+ context. scan_for_array_sets ( ) ;
77
+ let to_remove = context. find_rcs_to_remove ( ) ;
67
78
remove_instructions ( to_remove, function) ;
68
79
}
69
80
@@ -72,13 +83,17 @@ fn contains_array_parameter(function: &mut Function) -> bool {
72
83
parameters. any ( |parameter| function. dfg . type_of_value ( * parameter) . contains_an_array ( ) )
73
84
}
74
85
75
- impl Context {
76
- fn find_rcs_in_entry_block ( & mut self , function : & Function ) {
77
- let entry = function. entry_block ( ) ;
86
+ impl < ' f > Context < ' f > {
87
+ fn find_rcs_in_entry_and_exit_block ( & mut self ) {
88
+ let entry = self . function . entry_block ( ) ;
89
+ self . find_rcs_in_block ( entry) ;
90
+ self . find_rcs_in_block ( self . last_block ) ;
91
+ }
78
92
79
- for instruction in function. dfg [ entry] . instructions ( ) {
80
- if let Instruction :: IncrementRc { value } = & function. dfg [ * instruction] {
81
- let typ = function. dfg . type_of_value ( * value) ;
93
+ fn find_rcs_in_block ( & mut self , block_id : BasicBlockId ) {
94
+ for instruction in self . function . dfg [ block_id] . instructions ( ) {
95
+ if let Instruction :: IncrementRc { value } = & self . function . dfg [ * instruction] {
96
+ let typ = self . function . dfg . type_of_value ( * value) ;
82
97
83
98
// We assume arrays aren't mutated until we find an array_set
84
99
let inc_rc = IncRc { id : * instruction, array : * value, possibly_mutated : false } ;
@@ -89,14 +104,28 @@ impl Context {
89
104
90
105
/// Find each array_set instruction in the function and mark any arrays used
91
106
/// by the inc_rc instructions as possibly mutated if they're the same type.
92
- fn scan_for_array_sets ( & mut self , function : & Function ) {
93
- for block in function. reachable_blocks ( ) {
94
- for instruction in function. dfg [ block] . instructions ( ) {
95
- if let Instruction :: ArraySet { array, .. } = function. dfg [ * instruction] {
96
- let typ = function. dfg . type_of_value ( array) ;
107
+ fn scan_for_array_sets ( & mut self ) {
108
+ // Block parameters could be passed to from function parameters.
109
+ // Thus, any inc rcs from block parameters with matching array sets need to marked possibly mutated.
110
+ let mut per_func_block_params: HashSet < ValueId > = HashSet :: default ( ) ;
111
+
112
+ for block in self . function . reachable_blocks ( ) {
113
+ let block_params = self . function . dfg . block_parameters ( block) ;
114
+ per_func_block_params. extend ( block_params. iter ( ) ) ;
115
+ }
116
+
117
+ for block in self . function . reachable_blocks ( ) {
118
+ for instruction in self . function . dfg [ block] . instructions ( ) {
119
+ if let Instruction :: ArraySet { array, .. } = self . function . dfg [ * instruction] {
120
+ let typ = self . function . dfg . type_of_value ( array) ;
97
121
if let Some ( inc_rcs) = self . inc_rcs . get_mut ( & typ) {
98
122
for inc_rc in inc_rcs {
99
- inc_rc. possibly_mutated = true ;
123
+ if inc_rc. array == array
124
+ || self . function . parameters ( ) . contains ( & inc_rc. array )
125
+ || per_func_block_params. contains ( & inc_rc. array )
126
+ {
127
+ inc_rc. possibly_mutated = true ;
128
+ }
100
129
}
101
130
}
102
131
}
@@ -106,13 +135,12 @@ impl Context {
106
135
107
136
/// Find each dec_rc instruction and if the most recent inc_rc instruction for the same value
108
137
/// is not possibly mutated, then we can remove them both. Returns each such pair.
109
- fn find_rcs_to_remove ( & mut self , function : & Function ) -> HashSet < InstructionId > {
110
- let last_block = Self :: find_last_block ( function) ;
111
- let mut to_remove = HashSet :: new ( ) ;
138
+ fn find_rcs_to_remove ( & mut self ) -> HashSet < InstructionId > {
139
+ let mut to_remove = HashSet :: default ( ) ;
112
140
113
- for instruction in function. dfg [ last_block] . instructions ( ) {
114
- if let Instruction :: DecrementRc { value } = & function. dfg [ * instruction] {
115
- if let Some ( inc_rc) = self . pop_rc_for ( * value, function ) {
141
+ for instruction in self . function . dfg [ self . last_block ] . instructions ( ) {
142
+ if let Instruction :: DecrementRc { value } = & self . function . dfg [ * instruction] {
143
+ if let Some ( inc_rc) = self . pop_rc_for ( * value) {
116
144
if !inc_rc. possibly_mutated {
117
145
to_remove. insert ( inc_rc. id ) ;
118
146
to_remove. insert ( * instruction) ;
@@ -139,8 +167,8 @@ impl Context {
139
167
}
140
168
141
169
/// Finds and pops the IncRc for the given array value if possible.
142
- fn pop_rc_for ( & mut self , value : ValueId , function : & Function ) -> Option < IncRc > {
143
- let typ = function. dfg . type_of_value ( value) ;
170
+ fn pop_rc_for ( & mut self , value : ValueId ) -> Option < IncRc > {
171
+ let typ = self . function . dfg . type_of_value ( value) ;
144
172
145
173
let rcs = self . inc_rcs . get_mut ( & typ) ?;
146
174
let position = rcs. iter ( ) . position ( |inc_rc| inc_rc. array == value) ?;
@@ -265,6 +293,7 @@ mod test {
265
293
builder. terminate_with_return ( vec ! [ ] ) ;
266
294
267
295
let ssa = builder. finish ( ) . remove_paired_rc ( ) ;
296
+ println ! ( "{}" , ssa) ;
268
297
let main = ssa. main ( ) ;
269
298
let entry = main. entry_block ( ) ;
270
299
@@ -325,4 +354,160 @@ mod test {
325
354
assert_eq ! ( count_inc_rcs( entry, & main. dfg) , 1 ) ;
326
355
assert_eq ! ( count_dec_rcs( entry, & main. dfg) , 1 ) ;
327
356
}
357
+
358
+ #[ test]
359
+ fn separate_entry_and_exit_block_fn_return_array ( ) {
360
+ // brillig fn foo f0 {
361
+ // b0(v0: [Field; 2]):
362
+ // jmp b1(v0)
363
+ // b1():
364
+ // inc_rc v0
365
+ // inc_rc v0
366
+ // dec_rc v0
367
+ // return [v0]
368
+ // }
369
+ let main_id = Id :: test_new ( 0 ) ;
370
+ let mut builder = FunctionBuilder :: new ( "foo" . into ( ) , main_id) ;
371
+ builder. set_runtime ( RuntimeType :: Brillig ) ;
372
+
373
+ let inner_array_type = Type :: Array ( Arc :: new ( vec ! [ Type :: field( ) ] ) , 2 ) ;
374
+ let v0 = builder. add_parameter ( inner_array_type. clone ( ) ) ;
375
+
376
+ let b1 = builder. insert_block ( ) ;
377
+ builder. terminate_with_jmp ( b1, vec ! [ v0] ) ;
378
+
379
+ builder. switch_to_block ( b1) ;
380
+ builder. insert_inc_rc ( v0) ;
381
+ builder. insert_inc_rc ( v0) ;
382
+ builder. insert_dec_rc ( v0) ;
383
+
384
+ let outer_array_type = Type :: Array ( Arc :: new ( vec ! [ inner_array_type] ) , 1 ) ;
385
+ let array = builder. array_constant ( vec ! [ v0] . into ( ) , outer_array_type) ;
386
+ builder. terminate_with_return ( vec ! [ array] ) ;
387
+
388
+ // Expected result:
389
+ //
390
+ // brillig fn foo f0 {
391
+ // b0(v0: [Field; 2]):
392
+ // jmp b1(v0)
393
+ // b1():
394
+ // inc_rc v0
395
+ // return [v0]
396
+ // }
397
+ let ssa = builder. finish ( ) . remove_paired_rc ( ) ;
398
+ let main = ssa. main ( ) ;
399
+
400
+ assert_eq ! ( count_inc_rcs( b1, & main. dfg) , 1 ) ;
401
+ assert_eq ! ( count_dec_rcs( b1, & main. dfg) , 0 ) ;
402
+ }
403
+
404
+ #[ test]
405
+ fn exit_block_single_mutation ( ) {
406
+ // fn mutator(mut array: [Field; 2]) {
407
+ // array[0] = 5;
408
+ // }
409
+ //
410
+ // acir(inline) fn mutator f0 {
411
+ // b0(v0: [Field; 2]):
412
+ // jmp b1(v0)
413
+ // b1(v1: [Field; 2]):
414
+ // v2 = allocate
415
+ // store v1 at v2
416
+ // inc_rc v1
417
+ // v3 = load v2
418
+ // v6 = array_set v3, index u64 0, value Field 5
419
+ // store v6 at v2
420
+ // dec_rc v1
421
+ // return
422
+ // }
423
+ let main_id = Id :: test_new ( 0 ) ;
424
+ let mut builder = FunctionBuilder :: new ( "mutator" . into ( ) , main_id) ;
425
+
426
+ let array_type = Type :: Array ( Arc :: new ( vec ! [ Type :: field( ) ] ) , 2 ) ;
427
+ let v0 = builder. add_parameter ( array_type. clone ( ) ) ;
428
+
429
+ let b1 = builder. insert_block ( ) ;
430
+ builder. terminate_with_jmp ( b1, vec ! [ v0] ) ;
431
+
432
+ builder. switch_to_block ( b1) ;
433
+ // We want to make sure we go through the block parameter
434
+ let v1 = builder. add_block_parameter ( b1, array_type. clone ( ) ) ;
435
+
436
+ let v2 = builder. insert_allocate ( array_type. clone ( ) ) ;
437
+ builder. insert_store ( v2, v1) ;
438
+ builder. insert_inc_rc ( v1) ;
439
+ let v3 = builder. insert_load ( v2, array_type) ;
440
+
441
+ let zero = builder. numeric_constant ( 0u128 , Type :: unsigned ( 64 ) ) ;
442
+ let five = builder. field_constant ( 5u128 ) ;
443
+ let v8 = builder. insert_array_set ( v3, zero, five) ;
444
+
445
+ builder. insert_store ( v2, v8) ;
446
+ builder. insert_dec_rc ( v1) ;
447
+ builder. terminate_with_return ( vec ! [ ] ) ;
448
+
449
+ let ssa = builder. finish ( ) . remove_paired_rc ( ) ;
450
+ let main = ssa. main ( ) ;
451
+
452
+ // No changes, the array is possibly mutated
453
+ assert_eq ! ( count_inc_rcs( b1, & main. dfg) , 1 ) ;
454
+ assert_eq ! ( count_dec_rcs( b1, & main. dfg) , 1 ) ;
455
+ }
456
+
457
+ #[ test]
458
+ fn exit_block_mutation_through_reference ( ) {
459
+ // fn mutator2(array: &mut [Field; 2]) {
460
+ // array[0] = 5;
461
+ // }
462
+ // acir(inline) fn mutator2 f0 {
463
+ // b0(v0: &mut [Field; 2]):
464
+ // jmp b1(v0)
465
+ // b1(v1: &mut [Field; 2]):
466
+ // v2 = load v1
467
+ // inc_rc v1
468
+ // store v2 at v1
469
+ // v3 = load v2
470
+ // v6 = array_set v3, index u64 0, value Field 5
471
+ // store v6 at v1
472
+ // v7 = load v1
473
+ // dec_rc v7
474
+ // store v7 at v1
475
+ // return
476
+ // }
477
+ let main_id = Id :: test_new ( 0 ) ;
478
+ let mut builder = FunctionBuilder :: new ( "mutator2" . into ( ) , main_id) ;
479
+
480
+ let array_type = Type :: Array ( Arc :: new ( vec ! [ Type :: field( ) ] ) , 2 ) ;
481
+ let reference_type = Type :: Reference ( Arc :: new ( array_type. clone ( ) ) ) ;
482
+
483
+ let v0 = builder. add_parameter ( reference_type. clone ( ) ) ;
484
+
485
+ let b1 = builder. insert_block ( ) ;
486
+ builder. terminate_with_jmp ( b1, vec ! [ v0] ) ;
487
+
488
+ builder. switch_to_block ( b1) ;
489
+ let v1 = builder. add_block_parameter ( b1, reference_type) ;
490
+
491
+ let v2 = builder. insert_load ( v1, array_type. clone ( ) ) ;
492
+ builder. insert_inc_rc ( v1) ;
493
+ builder. insert_store ( v1, v2) ;
494
+
495
+ let v3 = builder. insert_load ( v2, array_type. clone ( ) ) ;
496
+ let zero = builder. numeric_constant ( 0u128 , Type :: unsigned ( 64 ) ) ;
497
+ let five = builder. field_constant ( 5u128 ) ;
498
+ let v6 = builder. insert_array_set ( v3, zero, five) ;
499
+
500
+ builder. insert_store ( v1, v6) ;
501
+ let v7 = builder. insert_load ( v1, array_type) ;
502
+ builder. insert_dec_rc ( v7) ;
503
+ builder. insert_store ( v1, v7) ;
504
+ builder. terminate_with_return ( vec ! [ ] ) ;
505
+
506
+ let ssa = builder. finish ( ) . remove_paired_rc ( ) ;
507
+ let main = ssa. main ( ) ;
508
+
509
+ // No changes, the array is possibly mutated
510
+ assert_eq ! ( count_inc_rcs( b1, & main. dfg) , 1 ) ;
511
+ assert_eq ! ( count_dec_rcs( b1, & main. dfg) , 1 ) ;
512
+ }
328
513
}
0 commit comments