@@ -181,7 +181,12 @@ impl<'f> PerFunctionContext<'f> {
181
181
self . last_loads . get ( store_address) . is_none ( )
182
182
} ;
183
183
184
- if remove_load && !is_reference_param {
184
+ let is_reference_alias = block
185
+ . expressions
186
+ . get ( store_address)
187
+ . map_or ( false , |expression| matches ! ( expression, Expression :: Dereference ( _) ) ) ;
188
+
189
+ if remove_load && !is_reference_param && !is_reference_alias {
185
190
self . instructions_to_remove . insert ( * store_instruction) ;
186
191
}
187
192
}
@@ -286,19 +291,19 @@ impl<'f> PerFunctionContext<'f> {
286
291
} else {
287
292
references. mark_value_used ( address, self . inserter . function ) ;
288
293
289
- let expression = if let Some ( expression) = references. expressions . get ( & result) {
290
- expression. clone ( )
291
- } else {
292
- references. expressions . insert ( result, Expression :: Other ( result) ) ;
293
- Expression :: Other ( result)
294
- } ;
295
- if let Some ( aliases) = references. aliases . get_mut ( & expression) {
294
+ let expression =
295
+ references. expressions . entry ( result) . or_insert ( Expression :: Other ( result) ) ;
296
+ // Make sure this load result is marked an alias to itself
297
+ if let Some ( aliases) = references. aliases . get_mut ( expression) {
298
+ // If we have an alias set, add to the set
296
299
aliases. insert ( result) ;
297
300
} else {
301
+ // Otherwise, create a new alias set containing just the load result
298
302
references
299
303
. aliases
300
304
. insert ( Expression :: Other ( result) , AliasSet :: known ( result) ) ;
301
305
}
306
+ // Mark that we know a load result is equivalent to the address of a load.
302
307
references. set_known_value ( result, address) ;
303
308
304
309
self . last_loads . insert ( address, ( instruction, block_id) ) ;
@@ -789,4 +794,98 @@ mod tests {
789
794
// We expect the last eq to be optimized out
790
795
assert_eq ! ( b1_instructions. len( ) , 0 ) ;
791
796
}
797
+
798
+ #[ test]
799
+ fn keep_store_to_alias_in_loop_block ( ) {
800
+ // This test makes sure the instruction `store Field 2 at v5` in b2 remains after mem2reg.
801
+ // Although the only instruction on v5 is a lone store without any loads,
802
+ // v5 is an alias of the reference v0 which is stored in v2.
803
+ // This test makes sure that we are not inadvertently removing stores to aliases across blocks.
804
+ //
805
+ // acir(inline) fn main f0 {
806
+ // b0():
807
+ // v0 = allocate
808
+ // store Field 0 at v0
809
+ // v2 = allocate
810
+ // store v0 at v2
811
+ // jmp b1(Field 0)
812
+ // b1(v3: Field):
813
+ // v4 = eq v3, Field 0
814
+ // jmpif v4 then: b2, else: b3
815
+ // b2():
816
+ // v5 = load v2
817
+ // store Field 2 at v5
818
+ // v8 = add v3, Field 1
819
+ // jmp b1(v8)
820
+ // b3():
821
+ // v9 = load v0
822
+ // v10 = eq v9, Field 2
823
+ // constrain v9 == Field 2
824
+ // v11 = load v2
825
+ // v12 = load v10
826
+ // v13 = eq v12, Field 2
827
+ // constrain v11 == Field 2
828
+ // return
829
+ // }
830
+ let main_id = Id :: test_new ( 0 ) ;
831
+ let mut builder = FunctionBuilder :: new ( "main" . into ( ) , main_id) ;
832
+
833
+ let v0 = builder. insert_allocate ( Type :: field ( ) ) ;
834
+ let zero = builder. numeric_constant ( 0u128 , Type :: field ( ) ) ;
835
+ builder. insert_store ( v0, zero) ;
836
+
837
+ let v2 = builder. insert_allocate ( Type :: field ( ) ) ;
838
+ // Construct alias
839
+ builder. insert_store ( v2, v0) ;
840
+ let v2_type = builder. current_function . dfg . type_of_value ( v2) ;
841
+ assert ! ( builder. current_function. dfg. value_is_reference( v2) ) ;
842
+
843
+ let b1 = builder. insert_block ( ) ;
844
+ builder. terminate_with_jmp ( b1, vec ! [ zero] ) ;
845
+
846
+ // Loop header
847
+ builder. switch_to_block ( b1) ;
848
+ let v3 = builder. add_block_parameter ( b1, Type :: field ( ) ) ;
849
+ let is_zero = builder. insert_binary ( v3, BinaryOp :: Eq , zero) ;
850
+
851
+ let b2 = builder. insert_block ( ) ;
852
+ let b3 = builder. insert_block ( ) ;
853
+ builder. terminate_with_jmpif ( is_zero, b2, b3) ;
854
+
855
+ // Loop body
856
+ builder. switch_to_block ( b2) ;
857
+ let v5 = builder. insert_load ( v2, v2_type. clone ( ) ) ;
858
+ let two = builder. numeric_constant ( 2u128 , Type :: field ( ) ) ;
859
+ builder. insert_store ( v5, two) ;
860
+ let one = builder. numeric_constant ( 1u128 , Type :: field ( ) ) ;
861
+ let v3_plus_one = builder. insert_binary ( v3, BinaryOp :: Add , one) ;
862
+ builder. terminate_with_jmp ( b1, vec ! [ v3_plus_one] ) ;
863
+
864
+ builder. switch_to_block ( b3) ;
865
+ let v9 = builder. insert_load ( v0, Type :: field ( ) ) ;
866
+ let _ = builder. insert_binary ( v9, BinaryOp :: Eq , two) ;
867
+
868
+ builder. insert_constrain ( v9, two, None ) ;
869
+ let v11 = builder. insert_load ( v2, v2_type) ;
870
+ let v12 = builder. insert_load ( v11, Type :: field ( ) ) ;
871
+ let _ = builder. insert_binary ( v12, BinaryOp :: Eq , two) ;
872
+
873
+ builder. insert_constrain ( v11, two, None ) ;
874
+ builder. terminate_with_return ( vec ! [ ] ) ;
875
+
876
+ let ssa = builder. finish ( ) ;
877
+
878
+ // We expect the same result as above.
879
+ let ssa = ssa. mem2reg ( ) ;
880
+
881
+ let main = ssa. main ( ) ;
882
+ assert_eq ! ( main. reachable_blocks( ) . len( ) , 4 ) ;
883
+
884
+ // The store from the original SSA should remain
885
+ assert_eq ! ( count_stores( main. entry_block( ) , & main. dfg) , 2 ) ;
886
+ assert_eq ! ( count_stores( b2, & main. dfg) , 1 ) ;
887
+
888
+ assert_eq ! ( count_loads( b2, & main. dfg) , 1 ) ;
889
+ assert_eq ! ( count_loads( b3, & main. dfg) , 3 ) ;
890
+ }
792
891
}
0 commit comments