@@ -626,10 +626,7 @@ impl<'f> LoopIteration<'f> {
626
626
627
627
#[ cfg( test) ]
628
628
mod tests {
629
- use crate :: ssa:: {
630
- function_builder:: FunctionBuilder ,
631
- ir:: { instruction:: BinaryOp , map:: Id , types:: Type } ,
632
- } ;
629
+ use crate :: ssa:: { opt:: assert_normalized_ssa_equals, Ssa } ;
633
630
634
631
#[ test]
635
632
fn unroll_nested_loops ( ) {
@@ -640,162 +637,89 @@ mod tests {
640
637
// }
641
638
// }
642
639
// }
643
- //
644
- // fn main f0 {
645
- // b0():
646
- // jmp b1(Field 0)
647
- // b1(v0: Field): // header of outer loop
648
- // v1 = lt v0, Field 3
649
- // jmpif v1, then: b2, else: b3
650
- // b2():
651
- // jmp b4(Field 0)
652
- // b4(v2: Field): // header of inner loop
653
- // v3 = lt v2, Field 4
654
- // jmpif v3, then: b5, else: b6
655
- // b5():
656
- // v4 = add v0, v2
657
- // v5 = lt Field 10, v4
658
- // constrain v5
659
- // v6 = add v2, Field 1
660
- // jmp b4(v6)
661
- // b6(): // end of inner loop
662
- // v7 = add v0, Field 1
663
- // jmp b1(v7)
664
- // b3(): // end of outer loop
665
- // return Field 0
666
- // }
667
- let main_id = Id :: test_new ( 0 ) ;
668
-
669
- // Compiling main
670
- let mut builder = FunctionBuilder :: new ( "main" . into ( ) , main_id) ;
671
-
672
- let b1 = builder. insert_block ( ) ;
673
- let b2 = builder. insert_block ( ) ;
674
- let b3 = builder. insert_block ( ) ;
675
- let b4 = builder. insert_block ( ) ;
676
- let b5 = builder. insert_block ( ) ;
677
- let b6 = builder. insert_block ( ) ;
678
-
679
- let v0 = builder. add_block_parameter ( b1, Type :: field ( ) ) ;
680
- let v2 = builder. add_block_parameter ( b4, Type :: field ( ) ) ;
681
-
682
- let zero = builder. field_constant ( 0u128 ) ;
683
- let one = builder. field_constant ( 1u128 ) ;
684
- let three = builder. field_constant ( 3u128 ) ;
685
- let four = builder. field_constant ( 4u128 ) ;
686
- let ten = builder. field_constant ( 10u128 ) ;
687
-
688
- builder. terminate_with_jmp ( b1, vec ! [ zero] ) ;
689
-
690
- // b1
691
- builder. switch_to_block ( b1) ;
692
- let v1 = builder. insert_binary ( v0, BinaryOp :: Lt , three) ;
693
- builder. terminate_with_jmpif ( v1, b2, b3) ;
694
-
695
- // b2
696
- builder. switch_to_block ( b2) ;
697
- builder. terminate_with_jmp ( b4, vec ! [ zero] ) ;
698
-
699
- // b3
700
- builder. switch_to_block ( b3) ;
701
- builder. terminate_with_return ( vec ! [ zero] ) ;
702
-
703
- // b4
704
- builder. switch_to_block ( b4) ;
705
- let v3 = builder. insert_binary ( v2, BinaryOp :: Lt , four) ;
706
- builder. terminate_with_jmpif ( v3, b5, b6) ;
707
-
708
- // b5
709
- builder. switch_to_block ( b5) ;
710
- let v4 = builder. insert_binary ( v0, BinaryOp :: Add , v2) ;
711
- let v5 = builder. insert_binary ( ten, BinaryOp :: Lt , v4) ;
712
- builder. insert_constrain ( v5, one, None ) ;
713
- let v6 = builder. insert_binary ( v2, BinaryOp :: Add , one) ;
714
- builder. terminate_with_jmp ( b4, vec ! [ v6] ) ;
715
-
716
- // b6
717
- builder. switch_to_block ( b6) ;
718
- let v7 = builder. insert_binary ( v0, BinaryOp :: Add , one) ;
719
- builder. terminate_with_jmp ( b1, vec ! [ v7] ) ;
720
-
721
- let ssa = builder. finish ( ) ;
722
- assert_eq ! ( ssa. main( ) . reachable_blocks( ) . len( ) , 7 ) ;
723
-
724
- // Expected output:
725
- //
726
- // fn main f0 {
727
- // b0():
728
- // constrain Field 0
729
- // constrain Field 0
730
- // constrain Field 0
731
- // constrain Field 0
732
- // jmp b23()
733
- // b23():
734
- // constrain Field 0
735
- // constrain Field 0
736
- // constrain Field 0
737
- // constrain Field 0
738
- // jmp b27()
739
- // b27():
740
- // constrain Field 0
741
- // constrain Field 0
742
- // constrain Field 0
743
- // constrain Field 0
744
- // jmp b31()
745
- // b31():
746
- // jmp b3()
747
- // b3():
748
- // return Field 0
749
- // }
640
+ let src = "
641
+ acir(inline) fn main f0 {
642
+ b0():
643
+ jmp b1(Field 0)
644
+ b1(v0: Field): // header of outer loop
645
+ v1 = lt v0, Field 3
646
+ jmpif v1 then: b2, else: b3
647
+ b2():
648
+ jmp b4(Field 0)
649
+ b4(v2: Field): // header of inner loop
650
+ v3 = lt v2, Field 4
651
+ jmpif v3 then: b5, else: b6
652
+ b5():
653
+ v4 = add v0, v2
654
+ v5 = lt Field 10, v4
655
+ constrain v5 == Field 1
656
+ v6 = add v2, Field 1
657
+ jmp b4(v6)
658
+ b6(): // end of inner loop
659
+ v7 = add v0, Field 1
660
+ jmp b1(v7)
661
+ b3(): // end of outer loop
662
+ return Field 0
663
+ }
664
+ " ;
665
+ let ssa = Ssa :: from_str ( src) . unwrap ( ) ;
666
+
667
+ let expected = "
668
+ acir(inline) fn main f0 {
669
+ b0():
670
+ constrain u1 0 == Field 1
671
+ constrain u1 0 == Field 1
672
+ constrain u1 0 == Field 1
673
+ constrain u1 0 == Field 1
674
+ jmp b1()
675
+ b1():
676
+ constrain u1 0 == Field 1
677
+ constrain u1 0 == Field 1
678
+ constrain u1 0 == Field 1
679
+ constrain u1 0 == Field 1
680
+ jmp b2()
681
+ b2():
682
+ constrain u1 0 == Field 1
683
+ constrain u1 0 == Field 1
684
+ constrain u1 0 == Field 1
685
+ constrain u1 0 == Field 1
686
+ jmp b3()
687
+ b3():
688
+ jmp b4()
689
+ b4():
690
+ return Field 0
691
+ }
692
+ " ;
693
+
750
694
// The final block count is not 1 because unrolling creates some unnecessary jmps.
751
695
// If a simplify cfg pass is ran afterward, the expected block count will be 1.
752
696
let ( ssa, errors) = ssa. try_to_unroll_loops ( ) ;
753
697
assert_eq ! ( errors. len( ) , 0 , "All loops should be unrolled" ) ;
754
698
assert_eq ! ( ssa. main( ) . reachable_blocks( ) . len( ) , 5 ) ;
699
+
700
+ assert_normalized_ssa_equals ( ssa, expected) ;
755
701
}
756
702
757
703
// Test that the pass can still be run on loops which fail to unroll properly
758
704
#[ test]
759
705
fn fail_to_unroll_loop ( ) {
760
- // fn main f0 {
761
- // b0(v0: Field):
762
- // jmp b1(v0)
763
- // b1(v1: Field):
764
- // v2 = lt v1, 5
765
- // jmpif v2, then: b2, else: b3
766
- // b2():
767
- // v3 = add v1, Field 1
768
- // jmp b1(v3)
769
- // b3():
770
- // return Field 0
771
- // }
772
- let main_id = Id :: test_new ( 0 ) ;
773
- let mut builder = FunctionBuilder :: new ( "main" . into ( ) , main_id) ;
774
-
775
- let b1 = builder. insert_block ( ) ;
776
- let b2 = builder. insert_block ( ) ;
777
- let b3 = builder. insert_block ( ) ;
778
-
779
- let v0 = builder. add_parameter ( Type :: field ( ) ) ;
780
- let v1 = builder. add_block_parameter ( b1, Type :: field ( ) ) ;
781
-
782
- builder. terminate_with_jmp ( b1, vec ! [ v0] ) ;
783
-
784
- builder. switch_to_block ( b1) ;
785
- let five = builder. field_constant ( 5u128 ) ;
786
- let v2 = builder. insert_binary ( v1, BinaryOp :: Lt , five) ;
787
- builder. terminate_with_jmpif ( v2, b2, b3) ;
788
-
789
- builder. switch_to_block ( b2) ;
790
- let one = builder. field_constant ( 1u128 ) ;
791
- let v3 = builder. insert_binary ( v1, BinaryOp :: Add , one) ;
792
- builder. terminate_with_jmp ( b1, vec ! [ v3] ) ;
793
-
794
- builder. switch_to_block ( b3) ;
795
- let zero = builder. field_constant ( 0u128 ) ;
796
- builder. terminate_with_return ( vec ! [ zero] ) ;
706
+ let src = "
707
+ acir(inline) fn main f0 {
708
+ b0(v0: Field):
709
+ jmp b1(v0)
710
+ b1(v1: Field):
711
+ v2 = lt v1, Field 5
712
+ jmpif v2 then: b2, else: b3
713
+ b2():
714
+ v3 = add v1, Field 1
715
+ jmp b1(v3)
716
+ b3():
717
+ return Field 0
718
+ }
719
+ " ;
720
+ let ssa = Ssa :: from_str ( src) . unwrap ( ) ;
797
721
798
- let ssa = builder . finish ( ) ;
722
+ // Sanity check
799
723
assert_eq ! ( ssa. main( ) . reachable_blocks( ) . len( ) , 4 ) ;
800
724
801
725
// Expected that we failed to unroll the loop
0 commit comments