@@ -167,6 +167,7 @@ impl Context<'_> {
167
167
let lhs_typ = self . function . dfg . type_of_value ( lhs) . unwrap_numeric ( ) ;
168
168
let base = self . field_constant ( FieldElement :: from ( 2_u128 ) ) ;
169
169
let pow = self . pow ( base, rhs) ;
170
+ let pow = self . pow_or_max_for_bit_size ( pow, rhs, bit_size, lhs_typ) ;
170
171
let pow = self . insert_cast ( pow, lhs_typ) ;
171
172
if lhs_typ. is_unsigned ( ) {
172
173
// unsigned right bit shift is just a normal division
@@ -205,6 +206,53 @@ impl Context<'_> {
205
206
}
206
207
}
207
208
209
+ /// Returns `pow` or the maximum value allowed for `typ` if 2^rhs is guaranteed to exceed that maximum.
210
+ fn pow_or_max_for_bit_size (
211
+ & mut self ,
212
+ pow : ValueId ,
213
+ rhs : ValueId ,
214
+ bit_size : u32 ,
215
+ typ : NumericType ,
216
+ ) -> ValueId {
217
+ let max = if typ. is_unsigned ( ) {
218
+ if bit_size == 128 { u128:: MAX } else { ( 1_u128 << bit_size) - 1 }
219
+ } else {
220
+ 1_u128 << ( bit_size - 1 )
221
+ } ;
222
+ let max = self . field_constant ( FieldElement :: from ( max) ) ;
223
+
224
+ // Here we check whether rhs is less than the bit_size: if it's not then it will overflow.
225
+ // Then we do:
226
+ //
227
+ // rhs_is_less_than_bit_size = lt rhs, bit_size
228
+ // rhs_is_not_less_than_bit_size = not rhs_is_less_than_bit_size
229
+ // pow_when_is_less_than_bit_size = rhs_is_less_than_bit_size * pow
230
+ // pow_when_is_not_less_than_bit_size = rhs_is_not_less_than_bit_size * max
231
+ // pow = add pow_when_is_less_than_bit_size, pow_when_is_not_less_than_bit_size
232
+ //
233
+ // All operations here are unchecked because they work on field types.
234
+ let rhs_typ = self . function . dfg . type_of_value ( rhs) . unwrap_numeric ( ) ;
235
+ let bit_size = self . numeric_constant ( bit_size as u128 , rhs_typ) ;
236
+ let rhs_is_less_than_bit_size = self . insert_binary ( rhs, BinaryOp :: Lt , bit_size) ;
237
+ let rhs_is_not_less_than_bit_size = self . insert_not ( rhs_is_less_than_bit_size) ;
238
+ let rhs_is_less_than_bit_size =
239
+ self . insert_cast ( rhs_is_less_than_bit_size, NumericType :: NativeField ) ;
240
+ let rhs_is_not_less_than_bit_size =
241
+ self . insert_cast ( rhs_is_not_less_than_bit_size, NumericType :: NativeField ) ;
242
+ let pow_when_is_less_than_bit_size =
243
+ self . insert_binary ( rhs_is_less_than_bit_size, BinaryOp :: Mul { unchecked : true } , pow) ;
244
+ let pow_when_is_not_less_than_bit_size = self . insert_binary (
245
+ rhs_is_not_less_than_bit_size,
246
+ BinaryOp :: Mul { unchecked : true } ,
247
+ max,
248
+ ) ;
249
+ self . insert_binary (
250
+ pow_when_is_less_than_bit_size,
251
+ BinaryOp :: Add { unchecked : true } ,
252
+ pow_when_is_not_less_than_bit_size,
253
+ )
254
+ }
255
+
208
256
/// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs
209
257
/// Pseudo-code of the computation:
210
258
/// let mut r = 1;
0 commit comments