Skip to content

Commit ca21820

Browse files
authored
fix: shift right overflow in ACIR with unknown var now returns zero (#7509)
1 parent ebaff44 commit ca21820

File tree

4 files changed

+59
-0
lines changed

4 files changed

+59
-0
lines changed

compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs

+48
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ impl Context<'_> {
167167
let lhs_typ = self.function.dfg.type_of_value(lhs).unwrap_numeric();
168168
let base = self.field_constant(FieldElement::from(2_u128));
169169
let pow = self.pow(base, rhs);
170+
let pow = self.pow_or_max_for_bit_size(pow, rhs, bit_size, lhs_typ);
170171
let pow = self.insert_cast(pow, lhs_typ);
171172
if lhs_typ.is_unsigned() {
172173
// unsigned right bit shift is just a normal division
@@ -205,6 +206,53 @@ impl Context<'_> {
205206
}
206207
}
207208

209+
/// Returns `pow` or the maximum value allowed for `typ` if 2^rhs is guaranteed to exceed that maximum.
210+
fn pow_or_max_for_bit_size(
211+
&mut self,
212+
pow: ValueId,
213+
rhs: ValueId,
214+
bit_size: u32,
215+
typ: NumericType,
216+
) -> ValueId {
217+
let max = if typ.is_unsigned() {
218+
if bit_size == 128 { u128::MAX } else { (1_u128 << bit_size) - 1 }
219+
} else {
220+
1_u128 << (bit_size - 1)
221+
};
222+
let max = self.field_constant(FieldElement::from(max));
223+
224+
// Here we check whether rhs is less than the bit_size: if it's not then it will overflow.
225+
// Then we do:
226+
//
227+
// rhs_is_less_than_bit_size = lt rhs, bit_size
228+
// rhs_is_not_less_than_bit_size = not rhs_is_less_than_bit_size
229+
// pow_when_is_less_than_bit_size = rhs_is_less_than_bit_size * pow
230+
// pow_when_is_not_less_than_bit_size = rhs_is_not_less_than_bit_size * max
231+
// pow = add pow_when_is_less_than_bit_size, pow_when_is_not_less_than_bit_size
232+
//
233+
// All operations here are unchecked because they work on field types.
234+
let rhs_typ = self.function.dfg.type_of_value(rhs).unwrap_numeric();
235+
let bit_size = self.numeric_constant(bit_size as u128, rhs_typ);
236+
let rhs_is_less_than_bit_size = self.insert_binary(rhs, BinaryOp::Lt, bit_size);
237+
let rhs_is_not_less_than_bit_size = self.insert_not(rhs_is_less_than_bit_size);
238+
let rhs_is_less_than_bit_size =
239+
self.insert_cast(rhs_is_less_than_bit_size, NumericType::NativeField);
240+
let rhs_is_not_less_than_bit_size =
241+
self.insert_cast(rhs_is_not_less_than_bit_size, NumericType::NativeField);
242+
let pow_when_is_less_than_bit_size =
243+
self.insert_binary(rhs_is_less_than_bit_size, BinaryOp::Mul { unchecked: true }, pow);
244+
let pow_when_is_not_less_than_bit_size = self.insert_binary(
245+
rhs_is_not_less_than_bit_size,
246+
BinaryOp::Mul { unchecked: true },
247+
max,
248+
);
249+
self.insert_binary(
250+
pow_when_is_less_than_bit_size,
251+
BinaryOp::Add { unchecked: true },
252+
pow_when_is_not_less_than_bit_size,
253+
)
254+
}
255+
208256
/// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs
209257
/// Pseudo-code of the computation:
210258
/// let mut r = 1;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[package]
2+
name = "shift_right_overflow"
3+
type = "bin"
4+
authors = [""]
5+
[dependencies]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
x = 9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
fn main(x: u8) {
2+
// This would previously overflow in ACIR. Now it returns zero.
3+
let value = 1 >> x;
4+
assert_eq(value, 0);
5+
}

0 commit comments

Comments
 (0)