Skip to content

Commit d8db656

Browse files
authored
chore(avm): bugfixing witness generation for add, sub, mul for FF (#9938)
1 parent b36c137 commit d8db656

File tree

3 files changed

+68
-38
lines changed

3 files changed

+68
-38
lines changed

barretenberg/cpp/pil/avm/alu.pil

+1
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ namespace alu(256);
147147

148148
// This holds the product over the integers
149149
// (u1 multiplication only cares about a_lo and b_lo)
150+
// TODO(9937): The following is not well constrained as this expression overflows the field.
150151
pol PRODUCT = a_lo * b_lo + (1 - u1_tag) * (LIMB_BITS_POW * partial_prod_lo + MAX_BITS_POW * (partial_prod_hi + a_hi * b_hi));
151152

152153
// =============== ADDITION/SUBTRACTION Operation Constraints =================================================

barretenberg/cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp

+24-15
Original file line numberDiff line numberDiff line change
@@ -403,74 +403,83 @@ std::vector<std::array<FF, 3>> positive_op_div_test_values = { {
403403
// Test on basic addition over finite field type.
404404
TEST_F(AvmArithmeticTestsFF, addition)
405405
{
406-
std::vector<FF> const calldata = { 37, 4, 11 };
406+
const FF a = FF::modulus - 19;
407+
const FF b = FF::modulus - 5;
408+
const FF c = FF::modulus - 24; // c = a + b
409+
std::vector<FF> const calldata = { a, b, 4 };
407410
gen_trace_builder(calldata);
408411
trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32);
409412
trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32);
410413
trace_builder.op_calldata_copy(0, 0, 1, 0);
411414

412-
// Memory layout: [37,4,11,0,0,0,....]
413-
trace_builder.op_add(0, 0, 1, 4); // [37,4,11,0,41,0,....]
415+
// Memory layout: [a,b,4,0,0,....]
416+
trace_builder.op_add(0, 0, 1, 4); // [a,b,4,0,c,0,....]
414417
trace_builder.op_set(0, 5, 100, AvmMemoryTag::U32);
415418
trace_builder.op_return(0, 0, 100);
416419
auto trace = trace_builder.finalize();
417420

418-
auto alu_row = common_validate_add(trace, FF(37), FF(4), FF(41), FF(0), FF(1), FF(4), AvmMemoryTag::FF);
421+
auto alu_row = common_validate_add(trace, a, b, c, FF(0), FF(1), FF(4), AvmMemoryTag::FF);
419422

420423
EXPECT_EQ(alu_row.alu_ff_tag, FF(1));
421424
EXPECT_EQ(alu_row.alu_cf, FF(0));
422425

423-
std::vector<FF> const returndata = { 37, 4, 11, 0, 41 };
426+
std::vector<FF> const returndata = { a, b, 4, 0, c };
424427

425428
validate_trace(std::move(trace), public_inputs, calldata, returndata);
426429
}
427430

428431
// Test on basic subtraction over finite field type.
429432
TEST_F(AvmArithmeticTestsFF, subtraction)
430433
{
431-
std::vector<FF> const calldata = { 8, 4, 17 };
434+
const FF a = 8;
435+
const FF b = FF::modulus - 5;
436+
const FF c = 13; // c = a - b
437+
std::vector<FF> const calldata = { b, 4, a };
432438
gen_trace_builder(calldata);
433439
trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32);
434440
trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32);
435441
trace_builder.op_calldata_copy(0, 0, 1, 0);
436442

437-
// Memory layout: [8,4,17,0,0,0,....]
438-
trace_builder.op_sub(0, 2, 0, 1); // [8,9,17,0,0,0....]
443+
// Memory layout: [b,4,a,0,0,0,....]
444+
trace_builder.op_sub(0, 2, 0, 1); // [b,c,a,0,0,0....]
439445
trace_builder.op_set(0, 3, 100, AvmMemoryTag::U32);
440446
trace_builder.op_return(0, 0, 100);
441447
auto trace = trace_builder.finalize();
442448

443-
auto alu_row = common_validate_sub(trace, FF(17), FF(8), FF(9), FF(2), FF(0), FF(1), AvmMemoryTag::FF);
449+
auto alu_row = common_validate_sub(trace, a, b, c, FF(2), FF(0), FF(1), AvmMemoryTag::FF);
444450

445451
EXPECT_EQ(alu_row.alu_ff_tag, FF(1));
446452
EXPECT_EQ(alu_row.alu_cf, FF(0));
447453

448-
std::vector<FF> const returndata = { 8, 9, 17 };
454+
std::vector<FF> const returndata = { b, c, a };
449455
validate_trace(std::move(trace), public_inputs, calldata, returndata);
450456
}
451457

452458
// Test on basic multiplication over finite field type.
453459
TEST_F(AvmArithmeticTestsFF, multiplication)
454460
{
455-
std::vector<FF> const calldata = { 5, 0, 20 };
461+
const FF a = FF::modulus - 1;
462+
const FF b = 278;
463+
const FF c = FF::modulus - 278;
464+
std::vector<FF> const calldata = { b, 0, a };
456465
gen_trace_builder(calldata);
457466
trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32);
458467
trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32);
459468
trace_builder.op_calldata_copy(0, 0, 1, 0);
460469

461-
// Memory layout: [5,0,20,0,0,0,....]
462-
trace_builder.op_mul(0, 2, 0, 1); // [5,100,20,0,0,0....]
470+
// Memory layout: [b,0,a,0,0,0,....]
471+
trace_builder.op_mul(0, 2, 0, 1); // [b,c,a,0,0,0....]
463472
trace_builder.op_set(0, 3, 100, AvmMemoryTag::U32);
464473
trace_builder.op_return(0, 0, 100);
465474
auto trace = trace_builder.finalize();
466475

467-
auto alu_row_index = common_validate_mul(trace, FF(20), FF(5), FF(100), FF(2), FF(0), FF(1), AvmMemoryTag::FF);
476+
auto alu_row_index = common_validate_mul(trace, a, b, c, FF(2), FF(0), FF(1), AvmMemoryTag::FF);
468477
auto alu_row = trace.at(alu_row_index);
469478

470479
EXPECT_EQ(alu_row.alu_ff_tag, FF(1));
471480
EXPECT_EQ(alu_row.alu_cf, FF(0));
472481

473-
std::vector<FF> const returndata = { 5, 100, 20 };
482+
std::vector<FF> const returndata = { b, c, a };
474483
validate_trace(std::move(trace), public_inputs, calldata, returndata);
475484
}
476485

barretenberg/cpp/src/barretenberg/vm/avm/trace/alu_trace.cpp

+43-23
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,14 @@ void AvmAluTraceBuilder::reset()
106106
FF AvmAluTraceBuilder::op_add(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk)
107107
{
108108
bool carry = false;
109-
uint256_t c_u256 = uint256_t(a) + uint256_t(b);
110-
FF c = cast_to_mem_tag(c_u256, in_tag);
109+
FF c;
110+
111+
if (in_tag == AvmMemoryTag::FF) {
112+
c = a + b;
113+
} else {
114+
uint256_t c_u256 = uint256_t(a) + uint256_t(b);
115+
c = cast_to_mem_tag(c_u256, in_tag);
111116

112-
if (in_tag != AvmMemoryTag::FF) {
113117
// a_u128 + b_u128 >= 2^128 <==> c_u128 < a_u128
114118
if (uint128_t(c) < uint128_t(a)) {
115119
carry = true;
@@ -150,10 +154,14 @@ FF AvmAluTraceBuilder::op_add(FF const& a, FF const& b, AvmMemoryTag in_tag, uin
150154
FF AvmAluTraceBuilder::op_sub(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk)
151155
{
152156
bool carry = false;
153-
uint256_t c_u256 = uint256_t(a) - uint256_t(b);
154-
FF c = cast_to_mem_tag(c_u256, in_tag);
157+
FF c;
158+
159+
if (in_tag == AvmMemoryTag::FF) {
160+
c = a - b;
161+
} else {
162+
uint256_t c_u256 = uint256_t(a) - uint256_t(b);
163+
c = cast_to_mem_tag(c_u256, in_tag);
155164

156-
if (in_tag != AvmMemoryTag::FF) {
157165
// Underflow when a_u128 < b_u128
158166
if (uint128_t(a) < uint128_t(b)) {
159167
carry = true;
@@ -189,29 +197,41 @@ FF AvmAluTraceBuilder::op_sub(FF const& a, FF const& b, AvmMemoryTag in_tag, uin
189197
*/
190198
FF AvmAluTraceBuilder::op_mul(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk)
191199
{
192-
uint256_t a_u256{ a };
193-
uint256_t b_u256{ b };
194-
uint256_t c_u256 = a_u256 * b_u256; // Multiplication over the integers (not mod. 2^128)
200+
FF c = 0;
201+
uint256_t alu_a_lo = 0;
202+
uint256_t alu_a_hi = 0;
203+
uint256_t alu_b_lo = 0;
204+
uint256_t alu_b_hi = 0;
205+
uint256_t c_hi = 0;
206+
uint256_t partial_prod_lo = 0;
207+
uint256_t partial_prod_hi = 0;
195208

196-
FF c = cast_to_mem_tag(c_u256, in_tag);
209+
if (in_tag == AvmMemoryTag::FF) {
210+
c = a * b;
211+
} else {
197212

198-
uint8_t bits = mem_tag_bits(in_tag);
199-
// limbs are size 1 for u1
200-
uint8_t limb_bits = bits == 1 ? 1 : bits / 2;
201-
uint8_t num_bits = bits;
213+
uint256_t a_u256{ a };
214+
uint256_t b_u256{ b };
215+
uint256_t c_u256 = a_u256 * b_u256; // Multiplication over the integers (not mod. 2^128)
202216

203-
// Decompose a
204-
auto [alu_a_lo, alu_a_hi] = decompose(a_u256, limb_bits);
205-
// Decompose b
206-
auto [alu_b_lo, alu_b_hi] = decompose(b_u256, limb_bits);
217+
c = cast_to_mem_tag(c_u256, in_tag);
207218

208-
uint256_t partial_prod = alu_a_lo * alu_b_hi + alu_a_hi * alu_b_lo;
209-
// Decompose the partial product
210-
auto [partial_prod_lo, partial_prod_hi] = decompose(partial_prod, limb_bits);
219+
uint8_t bits = mem_tag_bits(in_tag);
220+
// limbs are size 1 for u1
221+
uint8_t limb_bits = bits == 1 ? 1 : bits / 2;
222+
uint8_t num_bits = bits;
211223

212-
auto c_hi = c_u256 >> num_bits;
224+
// Decompose a
225+
std::tie(alu_a_lo, alu_a_hi) = decompose(a_u256, limb_bits);
226+
// Decompose b
227+
std::tie(alu_b_lo, alu_b_hi) = decompose(b_u256, limb_bits);
228+
229+
uint256_t partial_prod = alu_a_lo * alu_b_hi + alu_a_hi * alu_b_lo;
230+
// Decompose the partial product
231+
std::tie(partial_prod_lo, partial_prod_hi) = decompose(partial_prod, limb_bits);
232+
233+
c_hi = c_u256 >> num_bits;
213234

214-
if (in_tag != AvmMemoryTag::FF) {
215235
cmp_builder.range_check_builder.assert_range(uint128_t(c), mem_tag_bits(in_tag), EventEmitter::ALU, clk);
216236
}
217237

0 commit comments

Comments
 (0)