@@ -209,8 +209,8 @@ fn block_cost(block: BasicBlockId, dfg: &DataFlowGraph) -> u32 {
209
209
1
210
210
} ,
211
211
// if less than 10 elements, it is translated into a store for each element
212
- // if more than 10, it is a loop, so 10 should be a good estimate
213
- Instruction :: MakeArray { .. } => 10 ,
212
+ // if more than 10, it is a loop, so 20 should be a good estimate, worst case being 10 stores and ~10 index increments
213
+ Instruction :: MakeArray { .. } => 20 ,
214
214
215
215
Instruction :: Allocate
216
216
| Instruction :: EnableSideEffectsIf { .. }
@@ -380,3 +380,143 @@ impl<'f> Context<'f> {
380
380
}
381
381
}
382
382
}
383
+
384
+ #[ cfg( test) ]
385
+ mod test {
386
+ use crate :: ssa:: { opt:: assert_normalized_ssa_equals, Ssa } ;
387
+
388
+ #[ test]
389
+ fn basic_jmpif ( ) {
390
+ let src = "
391
+ brillig(inline) fn foo f0 {
392
+ b0(v0: u32):
393
+ v3 = eq v0, u32 0
394
+ jmpif v3 then: b2, else: b1
395
+ b1():
396
+ jmp b3(u32 5)
397
+ b2():
398
+ jmp b3(u32 3)
399
+ b3(v1: u32):
400
+ return v1
401
+ }
402
+ " ;
403
+ let ssa = Ssa :: from_str ( src) . unwrap ( ) ;
404
+ assert_eq ! ( ssa. main( ) . reachable_blocks( ) . len( ) , 4 ) ;
405
+
406
+ let expected = "
407
+ brillig(inline) fn foo f0 {
408
+ b0(v0: u32):
409
+ v2 = eq v0, u32 0
410
+ v3 = not v2
411
+ v4 = cast v2 as u32
412
+ v5 = cast v3 as u32
413
+ v7 = unchecked_mul v4, u32 3
414
+ v9 = unchecked_mul v5, u32 5
415
+ v10 = unchecked_add v7, v9
416
+ return v10
417
+ }
418
+ " ;
419
+
420
+ let ssa = ssa. flatten_basic_conditionals ( ) ;
421
+ assert_normalized_ssa_equals ( ssa, expected) ;
422
+ }
423
+
424
+ #[ test]
425
+ fn array_jmpif ( ) {
426
+ let src = r#"
427
+ brillig(inline) fn foo f0 {
428
+ b0(v0: u32):
429
+ v3 = eq v0, u32 5
430
+ jmpif v3 then: b2, else: b1
431
+ b1():
432
+ v6 = make_array b"foo"
433
+ jmp b3(v6)
434
+ b2():
435
+ v10 = make_array b"bar"
436
+ jmp b3(v10)
437
+ b3(v1: [u8; 3]):
438
+ return v1
439
+ }
440
+ "# ;
441
+ let ssa = Ssa :: from_str ( src) . unwrap ( ) ;
442
+ assert_eq ! ( ssa. main( ) . reachable_blocks( ) . len( ) , 4 ) ;
443
+ let ssa = ssa. flatten_basic_conditionals ( ) ;
444
+ // make_array is not simplified
445
+ assert_normalized_ssa_equals ( ssa, src) ;
446
+ }
447
+
448
+ #[ test]
449
+ fn nested_jmpifs ( ) {
450
+ let src = "
451
+ brillig(inline) fn foo f0 {
452
+ b0(v0: u32):
453
+ v5 = eq v0, u32 5
454
+ v6 = not v5
455
+ jmpif v5 then: b5, else: b1
456
+ b1():
457
+ v8 = lt v0, u32 3
458
+ jmpif v8 then: b3, else: b2
459
+ b2():
460
+ v9 = truncate v0 to 2 bits, max_bit_size: 32
461
+ jmp b4(v9)
462
+ b3():
463
+ v10 = truncate v0 to 1 bits, max_bit_size: 32
464
+ jmp b4(v10)
465
+ b4(v1: u32):
466
+ jmp b9(v1)
467
+ b5():
468
+ v12 = lt u32 2, v0
469
+ jmpif v12 then: b7, else: b6
470
+ b6():
471
+ v13 = truncate v0 to 3 bits, max_bit_size: 32
472
+ jmp b8(v13)
473
+ b7():
474
+ v14 = and v0, u32 2
475
+ jmp b8(v14)
476
+ b8(v2: u32):
477
+ jmp b9(v2)
478
+ b9(v3: u32):
479
+ return v3
480
+ }
481
+ " ;
482
+ let ssa = Ssa :: from_str ( src) . unwrap ( ) ;
483
+ assert_eq ! ( ssa. main( ) . reachable_blocks( ) . len( ) , 10 ) ;
484
+
485
+ let expected = "
486
+ brillig(inline) fn foo f0 {
487
+ b0(v0: u32):
488
+ v3 = eq v0, u32 5
489
+ v4 = not v3
490
+ jmpif v3 then: b2, else: b1
491
+ b1():
492
+ v6 = lt v0, u32 3
493
+ v7 = truncate v0 to 1 bits, max_bit_size: 32
494
+ v8 = not v6
495
+ v9 = truncate v0 to 2 bits, max_bit_size: 32
496
+ v10 = cast v6 as u32
497
+ v11 = cast v8 as u32
498
+ v12 = unchecked_mul v10, v7
499
+ v13 = unchecked_mul v11, v9
500
+ v14 = unchecked_add v12, v13
501
+ jmp b3(v14)
502
+ b2():
503
+ v16 = lt u32 2, v0
504
+ v17 = and v0, u32 2
505
+ v18 = not v16
506
+ v19 = truncate v0 to 3 bits, max_bit_size: 32
507
+ v20 = cast v16 as u32
508
+ v21 = cast v18 as u32
509
+ v22 = unchecked_mul v20, v17
510
+ v23 = unchecked_mul v21, v19
511
+ v24 = unchecked_add v22, v23
512
+ jmp b3(v24)
513
+ b3(v1: u32):
514
+ return v1
515
+ }
516
+ " ;
517
+
518
+ let ssa = ssa. flatten_basic_conditionals ( ) ;
519
+ assert_eq ! ( ssa. main( ) . reachable_blocks( ) . len( ) , 4 ) ;
520
+ assert_normalized_ssa_equals ( ssa, expected) ;
521
+ }
522
+ }
0 commit comments