Skip to content

Commit f5b4076

Browse files
committed
add unit tests
1 parent 88faf2c commit f5b4076

File tree

1 file changed

+142
-2
lines changed

1 file changed

+142
-2
lines changed

compiler/noirc_evaluator/src/ssa/opt/basic_conditional.rs

+142-2
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ fn block_cost(block: BasicBlockId, dfg: &DataFlowGraph) -> u32 {
209209
1
210210
},
211211
// if less than 10 elements, it is translated into a store for each element
212-
// if more than 10, it is a loop, so 10 should be a good estimate
213-
Instruction::MakeArray { .. } => 10,
212+
// if more than 10, it is a loop, so 20 should be a good estimate, worst case being 10 stores and ~10 index increments
213+
Instruction::MakeArray { .. } => 20,
214214

215215
Instruction::Allocate
216216
| Instruction::EnableSideEffectsIf { .. }
@@ -380,3 +380,143 @@ impl<'f> Context<'f> {
380380
}
381381
}
382382
}
383+
384+
#[cfg(test)]
385+
mod test {
386+
use crate::ssa::{opt::assert_normalized_ssa_equals, Ssa};
387+
388+
#[test]
389+
fn basic_jmpif() {
390+
let src = "
391+
brillig(inline) fn foo f0 {
392+
b0(v0: u32):
393+
v3 = eq v0, u32 0
394+
jmpif v3 then: b2, else: b1
395+
b1():
396+
jmp b3(u32 5)
397+
b2():
398+
jmp b3(u32 3)
399+
b3(v1: u32):
400+
return v1
401+
}
402+
";
403+
let ssa = Ssa::from_str(src).unwrap();
404+
assert_eq!(ssa.main().reachable_blocks().len(), 4);
405+
406+
let expected = "
407+
brillig(inline) fn foo f0 {
408+
b0(v0: u32):
409+
v2 = eq v0, u32 0
410+
v3 = not v2
411+
v4 = cast v2 as u32
412+
v5 = cast v3 as u32
413+
v7 = unchecked_mul v4, u32 3
414+
v9 = unchecked_mul v5, u32 5
415+
v10 = unchecked_add v7, v9
416+
return v10
417+
}
418+
";
419+
420+
let ssa = ssa.flatten_basic_conditionals();
421+
assert_normalized_ssa_equals(ssa, expected);
422+
}
423+
424+
#[test]
425+
fn array_jmpif() {
426+
let src = r#"
427+
brillig(inline) fn foo f0 {
428+
b0(v0: u32):
429+
v3 = eq v0, u32 5
430+
jmpif v3 then: b2, else: b1
431+
b1():
432+
v6 = make_array b"foo"
433+
jmp b3(v6)
434+
b2():
435+
v10 = make_array b"bar"
436+
jmp b3(v10)
437+
b3(v1: [u8; 3]):
438+
return v1
439+
}
440+
"#;
441+
let ssa = Ssa::from_str(src).unwrap();
442+
assert_eq!(ssa.main().reachable_blocks().len(), 4);
443+
let ssa = ssa.flatten_basic_conditionals();
444+
// make_array is not simplified
445+
assert_normalized_ssa_equals(ssa, src);
446+
}
447+
448+
#[test]
449+
fn nested_jmpifs() {
450+
let src = "
451+
brillig(inline) fn foo f0 {
452+
b0(v0: u32):
453+
v5 = eq v0, u32 5
454+
v6 = not v5
455+
jmpif v5 then: b5, else: b1
456+
b1():
457+
v8 = lt v0, u32 3
458+
jmpif v8 then: b3, else: b2
459+
b2():
460+
v9 = truncate v0 to 2 bits, max_bit_size: 32
461+
jmp b4(v9)
462+
b3():
463+
v10 = truncate v0 to 1 bits, max_bit_size: 32
464+
jmp b4(v10)
465+
b4(v1: u32):
466+
jmp b9(v1)
467+
b5():
468+
v12 = lt u32 2, v0
469+
jmpif v12 then: b7, else: b6
470+
b6():
471+
v13 = truncate v0 to 3 bits, max_bit_size: 32
472+
jmp b8(v13)
473+
b7():
474+
v14 = and v0, u32 2
475+
jmp b8(v14)
476+
b8(v2: u32):
477+
jmp b9(v2)
478+
b9(v3: u32):
479+
return v3
480+
}
481+
";
482+
let ssa = Ssa::from_str(src).unwrap();
483+
assert_eq!(ssa.main().reachable_blocks().len(), 10);
484+
485+
let expected = "
486+
brillig(inline) fn foo f0 {
487+
b0(v0: u32):
488+
v3 = eq v0, u32 5
489+
v4 = not v3
490+
jmpif v3 then: b2, else: b1
491+
b1():
492+
v6 = lt v0, u32 3
493+
v7 = truncate v0 to 1 bits, max_bit_size: 32
494+
v8 = not v6
495+
v9 = truncate v0 to 2 bits, max_bit_size: 32
496+
v10 = cast v6 as u32
497+
v11 = cast v8 as u32
498+
v12 = unchecked_mul v10, v7
499+
v13 = unchecked_mul v11, v9
500+
v14 = unchecked_add v12, v13
501+
jmp b3(v14)
502+
b2():
503+
v16 = lt u32 2, v0
504+
v17 = and v0, u32 2
505+
v18 = not v16
506+
v19 = truncate v0 to 3 bits, max_bit_size: 32
507+
v20 = cast v16 as u32
508+
v21 = cast v18 as u32
509+
v22 = unchecked_mul v20, v17
510+
v23 = unchecked_mul v21, v19
511+
v24 = unchecked_add v22, v23
512+
jmp b3(v24)
513+
b3(v1: u32):
514+
return v1
515+
}
516+
";
517+
518+
let ssa = ssa.flatten_basic_conditionals();
519+
assert_eq!(ssa.main().reachable_blocks().len(), 4);
520+
assert_normalized_ssa_equals(ssa, expected);
521+
}
522+
}

0 commit comments

Comments
 (0)