Skip to content

Commit b8bdb52

Browse files
authored
fix: AVM witgen track gas for nested calls and external halts (#10731)
Resolves #10033 Resolves #10374 This PR does the following: - Witgen handles out-of-gas errors for all opcodes - all halts (return/revert/exceptional) work as follows: - charge gas for the problematic instruction as always, adding a row to the gas trace - pop the parent/caller's latest gas from the stack - call a helper function on the gas trace to mutate that most recent gas row, returning to the parent's latest gas minus any consumed gas (all gas consumed on exceptional halt) - `GasTraceEntry` includes a field `is_halt_or_first_row_in_nested_call` which lets us break gas rules on a halt or when starting a nested call because in both cases gas will jump. - `constrain_gas` returns a bool `out_of_gas` so that opcode implementations can handle out of gas - `write_to_memory` now has an option to skip the "jump back to correct pc" which was problematic when halting because the `jump` wouldn't result in a next row with the right pc Explanation on how gas works for calls: - Parent snapshots its gas right before a nested call in `ctx.*_gas_left` - Nested call is given a `ctx.start_*_gas_left` and the gas trace is forced to that same value - throughout the nested call, the gas trace operates normally, charging per instruction - when any halt is encountered, the instruction that halted must have its gas charged normally, but then we call a helper function on the gas trace to mutate the most recent row, flagging it to eventually become a sort of "fake" row that skips some constraints - the mutation of the halting row resets the gas to the parents last gas before the call (minus however much gas was consumed by the nested call... if exceptional halt, that is _all_ allocated gas) Follow-up work - properly constrain gas for nested calls, returns, reverts and exceptional halts - if `jump` exceptionally halts (i.e. out of gas), it should be okay that the next row doesn't have the target pc - Handle the edge case when an error is encountered on return/revert/call, but after the stack has already been modified
1 parent 962a7a2 commit b8bdb52

File tree

16 files changed

+738
-231
lines changed

16 files changed

+738
-231
lines changed

barretenberg/cpp/pil/avm/gas.pil

+4-4
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ namespace main(256);
6464
is_fake_row * (1 - is_fake_row) = 0; // Temporary
6565

6666
//TODO(8945): clean up fake row related code
67-
#[L2_GAS_NO_DECREMENT_FAKE_ROW]
68-
is_fake_row * (l2_gas_remaining - l2_gas_remaining') = 0;
69-
#[DA_GAS_NO_DECREMENT_FAKE_ROW]
70-
is_fake_row * (da_gas_remaining - da_gas_remaining') = 0;
67+
//#[L2_GAS_NO_DECREMENT_FAKE_ROW]
68+
//is_fake_row * (l2_gas_remaining - l2_gas_remaining') = 0;
69+
//#[DA_GAS_NO_DECREMENT_FAKE_ROW]
70+
//is_fake_row * (da_gas_remaining - da_gas_remaining') = 0;
7171

7272
// Constrain that the gas decrements correctly per instruction
7373
#[L2_GAS_REMAINING_DECREMENT_NOT_CALL]

barretenberg/cpp/src/barretenberg/vm/avm/generated/relations/gas.hpp

+15-35
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ template <typename FF_> class gasImpl {
1010
public:
1111
using FF = FF_;
1212

13-
static constexpr std::array<size_t, 12> SUBRELATION_PARTIAL_LENGTHS = { 3, 3, 3, 3, 3, 3, 5, 5, 4, 4, 2, 2 };
13+
static constexpr std::array<size_t, 10> SUBRELATION_PARTIAL_LENGTHS = { 3, 3, 3, 3, 5, 5, 4, 4, 2, 2 };
1414

1515
template <typename ContainerOverSubrelations, typename AllEntities>
1616
void static accumulate(ContainerOverSubrelations& evals,
@@ -46,67 +46,53 @@ template <typename FF_> class gasImpl {
4646
}
4747
{
4848
using Accumulator = typename std::tuple_element_t<4, ContainerOverSubrelations>;
49-
auto tmp =
50-
(new_term.main_is_fake_row * (new_term.main_l2_gas_remaining - new_term.main_l2_gas_remaining_shift));
51-
tmp *= scaling_factor;
52-
std::get<4>(evals) += typename Accumulator::View(tmp);
53-
}
54-
{
55-
using Accumulator = typename std::tuple_element_t<5, ContainerOverSubrelations>;
56-
auto tmp =
57-
(new_term.main_is_fake_row * (new_term.main_da_gas_remaining - new_term.main_da_gas_remaining_shift));
58-
tmp *= scaling_factor;
59-
std::get<5>(evals) += typename Accumulator::View(tmp);
60-
}
61-
{
62-
using Accumulator = typename std::tuple_element_t<6, ContainerOverSubrelations>;
6349
auto tmp = ((new_term.main_is_gas_accounted *
6450
((FF(1) - new_term.main_sel_op_external_call) - new_term.main_sel_op_static_call)) *
6551
(((new_term.main_l2_gas_remaining_shift - new_term.main_l2_gas_remaining) +
6652
new_term.main_base_l2_gas_op_cost) +
6753
(new_term.main_dyn_l2_gas_op_cost * new_term.main_dyn_gas_multiplier)));
6854
tmp *= scaling_factor;
69-
std::get<6>(evals) += typename Accumulator::View(tmp);
55+
std::get<4>(evals) += typename Accumulator::View(tmp);
7056
}
7157
{
72-
using Accumulator = typename std::tuple_element_t<7, ContainerOverSubrelations>;
58+
using Accumulator = typename std::tuple_element_t<5, ContainerOverSubrelations>;
7359
auto tmp = ((new_term.main_is_gas_accounted *
7460
((FF(1) - new_term.main_sel_op_external_call) - new_term.main_sel_op_static_call)) *
7561
(((new_term.main_da_gas_remaining_shift - new_term.main_da_gas_remaining) +
7662
new_term.main_base_da_gas_op_cost) +
7763
(new_term.main_dyn_da_gas_op_cost * new_term.main_dyn_gas_multiplier)));
7864
tmp *= scaling_factor;
79-
std::get<7>(evals) += typename Accumulator::View(tmp);
65+
std::get<5>(evals) += typename Accumulator::View(tmp);
8066
}
8167
{
82-
using Accumulator = typename std::tuple_element_t<8, ContainerOverSubrelations>;
68+
using Accumulator = typename std::tuple_element_t<6, ContainerOverSubrelations>;
8369
auto tmp = (new_term.main_is_gas_accounted *
8470
(((FF(1) - (FF(2) * new_term.main_l2_out_of_gas)) * new_term.main_l2_gas_remaining_shift) -
8571
new_term.main_abs_l2_rem_gas));
8672
tmp *= scaling_factor;
87-
std::get<8>(evals) += typename Accumulator::View(tmp);
73+
std::get<6>(evals) += typename Accumulator::View(tmp);
8874
}
8975
{
90-
using Accumulator = typename std::tuple_element_t<9, ContainerOverSubrelations>;
76+
using Accumulator = typename std::tuple_element_t<7, ContainerOverSubrelations>;
9177
auto tmp = (new_term.main_is_gas_accounted *
9278
(((FF(1) - (FF(2) * new_term.main_da_out_of_gas)) * new_term.main_da_gas_remaining_shift) -
9379
new_term.main_abs_da_rem_gas));
9480
tmp *= scaling_factor;
95-
std::get<9>(evals) += typename Accumulator::View(tmp);
81+
std::get<7>(evals) += typename Accumulator::View(tmp);
9682
}
9783
{
98-
using Accumulator = typename std::tuple_element_t<10, ContainerOverSubrelations>;
84+
using Accumulator = typename std::tuple_element_t<8, ContainerOverSubrelations>;
9985
auto tmp = (new_term.main_abs_l2_rem_gas -
10086
(new_term.main_l2_gas_u16_r0 + (new_term.main_l2_gas_u16_r1 * FF(65536))));
10187
tmp *= scaling_factor;
102-
std::get<10>(evals) += typename Accumulator::View(tmp);
88+
std::get<8>(evals) += typename Accumulator::View(tmp);
10389
}
10490
{
105-
using Accumulator = typename std::tuple_element_t<11, ContainerOverSubrelations>;
91+
using Accumulator = typename std::tuple_element_t<9, ContainerOverSubrelations>;
10692
auto tmp = (new_term.main_abs_da_rem_gas -
10793
(new_term.main_da_gas_u16_r0 + (new_term.main_da_gas_u16_r1 * FF(65536))));
10894
tmp *= scaling_factor;
109-
std::get<11>(evals) += typename Accumulator::View(tmp);
95+
std::get<9>(evals) += typename Accumulator::View(tmp);
11096
}
11197
}
11298
};
@@ -121,23 +107,17 @@ template <typename FF> class gas : public Relation<gasImpl<FF>> {
121107
case 0:
122108
return "IS_GAS_ACCOUNTED";
123109
case 4:
124-
return "L2_GAS_NO_DECREMENT_FAKE_ROW";
125-
case 5:
126-
return "DA_GAS_NO_DECREMENT_FAKE_ROW";
127-
case 6:
128110
return "L2_GAS_REMAINING_DECREMENT_NOT_CALL";
129-
case 7:
111+
case 5:
130112
return "DA_GAS_REMAINING_DECREMENT_NOT_CALL";
131113
}
132114
return std::to_string(index);
133115
}
134116

135117
// Subrelation indices constants, to be used in tests.
136118
static constexpr size_t SR_IS_GAS_ACCOUNTED = 0;
137-
static constexpr size_t SR_L2_GAS_NO_DECREMENT_FAKE_ROW = 4;
138-
static constexpr size_t SR_DA_GAS_NO_DECREMENT_FAKE_ROW = 5;
139-
static constexpr size_t SR_L2_GAS_REMAINING_DECREMENT_NOT_CALL = 6;
140-
static constexpr size_t SR_DA_GAS_REMAINING_DECREMENT_NOT_CALL = 7;
119+
static constexpr size_t SR_L2_GAS_REMAINING_DECREMENT_NOT_CALL = 4;
120+
static constexpr size_t SR_DA_GAS_REMAINING_DECREMENT_NOT_CALL = 5;
141121
};
142122

143123
} // namespace bb::avm

barretenberg/cpp/src/barretenberg/vm/avm/tests/arithmetic.test.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,10 @@ class AvmArithmeticTests : public ::testing::Test {
227227
.nested_returndata = {},
228228
.last_pc = 0,
229229
.success_offset = 0,
230-
.l2_gas = 0,
231-
.da_gas = 0,
230+
.start_l2_gas_left = 0,
231+
.start_da_gas_left = 0,
232+
.l2_gas_left = 0,
233+
.da_gas_left = 0,
232234
.internal_return_ptr_stack = {} });
233235
trace_builder.current_ext_call_ctx = ext_call_ctx;
234236
}

barretenberg/cpp/src/barretenberg/vm/avm/tests/cast.test.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,10 @@ TEST_F(AvmCastTests, truncationFFToU16ModMinus1)
193193
.nested_returndata = {},
194194
.last_pc = 0,
195195
.success_offset = 0,
196-
.l2_gas = 0,
197-
.da_gas = 0,
196+
.start_l2_gas_left = 0,
197+
.start_da_gas_left = 0,
198+
.l2_gas_left = 0,
199+
.da_gas_left = 0,
198200
.internal_return_ptr_stack = {} });
199201
trace_builder.current_ext_call_ctx = ext_call_ctx;
200202
trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32);
@@ -222,8 +224,10 @@ TEST_F(AvmCastTests, truncationFFToU16ModMinus2)
222224
.nested_returndata = {},
223225
.last_pc = 0,
224226
.success_offset = 0,
225-
.l2_gas = 0,
226-
.da_gas = 0,
227+
.start_l2_gas_left = 0,
228+
.start_da_gas_left = 0,
229+
.l2_gas_left = 0,
230+
.da_gas_left = 0,
227231
.internal_return_ptr_stack = {} });
228232
trace_builder.current_ext_call_ctx = ext_call_ctx;
229233

barretenberg/cpp/src/barretenberg/vm/avm/tests/slice.test.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ class AvmSliceTests : public ::testing::Test {
3737
.nested_returndata = {},
3838
.last_pc = 0,
3939
.success_offset = 0,
40-
.l2_gas = 0,
41-
.da_gas = 0,
40+
.start_l2_gas_left = 0,
41+
.start_da_gas_left = 0,
42+
.l2_gas_left = 0,
43+
.da_gas_left = 0,
4244
.internal_return_ptr_stack = {} });
4345
trace_builder.current_ext_call_ctx = ext_call_ctx;
4446
this->calldata = calldata;

barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ enum class AvmError : uint32_t {
2020
RADIX_OUT_OF_BOUNDS,
2121
DUPLICATE_NULLIFIER,
2222
SIDE_EFFECT_LIMIT_REACHED,
23+
OUT_OF_GAS,
2324
};
2425

2526
} // namespace bb::avm_trace

barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp

+64-33
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,8 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
428428
// These hints help us to set up first call ctx
429429
uint32_t clk = trace_builder.get_clk();
430430
auto context_id = static_cast<uint8_t>(clk);
431+
uint32_t l2_gas_allocated_to_enqueued_call = trace_builder.get_l2_gas_left();
432+
uint32_t da_gas_allocated_to_enqueued_call = trace_builder.get_da_gas_left();
431433
trace_builder.current_ext_call_ctx = AvmTraceBuilder::ExtCallCtx{
432434
.context_id = context_id,
433435
.parent_id = 0,
@@ -436,10 +438,13 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
436438
.nested_returndata = {},
437439
.last_pc = 0,
438440
.success_offset = 0,
439-
.l2_gas = 0,
440-
.da_gas = 0,
441+
.start_l2_gas_left = l2_gas_allocated_to_enqueued_call,
442+
.start_da_gas_left = da_gas_allocated_to_enqueued_call,
443+
.l2_gas_left = l2_gas_allocated_to_enqueued_call,
444+
.da_gas_left = da_gas_allocated_to_enqueued_call,
441445
.internal_return_ptr_stack = {},
442446
};
447+
trace_builder.allocate_gas_for_call(l2_gas_allocated_to_enqueued_call, da_gas_allocated_to_enqueued_call);
443448
// Find the bytecode based on contract address of the public call request
444449
std::vector<uint8_t> bytecode =
445450
trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address, check_bytecode_membership);
@@ -451,11 +456,13 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
451456
std::stack<uint32_t> debug_counter_stack;
452457
uint32_t counter = 0;
453458
trace_builder.set_call_ptr(context_id);
454-
while (is_ok(error) && (pc = trace_builder.get_pc()) < bytecode.size()) {
459+
while ((pc = trace_builder.get_pc()) < bytecode.size()) {
455460
auto [inst, parse_error] = Deserialization::parse(bytecode, pc);
456-
error = parse_error;
457461

462+
// FIXME: properly handle case when an instruction fails parsing
463+
// especially first instruction in bytecode
458464
if (!is_ok(error)) {
465+
error = parse_error;
459466
break;
460467
}
461468

@@ -848,9 +855,10 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
848855
std::get<uint16_t>(inst.operands.at(3)),
849856
std::get<uint16_t>(inst.operands.at(4)),
850857
std::get<uint16_t>(inst.operands.at(5)));
858+
// TODO: what if an error is encountered on return or call which have already modified stack?
851859
// We hack it in here the logic to change contract address that we are processing
852860
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
853-
check_bytecode_membership);
861+
/*check_membership=*/false);
854862
debug_counter_stack.push(counter);
855863
counter = 0;
856864
break;
@@ -864,7 +872,7 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
864872
std::get<uint16_t>(inst.operands.at(5)));
865873
// We hack it in here the logic to change contract address that we are processing
866874
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
867-
check_bytecode_membership);
875+
/*check_membership=*/false);
868876
debug_counter_stack.push(counter);
869877
counter = 0;
870878
break;
@@ -873,53 +881,58 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
873881
auto ret = trace_builder.op_return(std::get<uint8_t>(inst.operands.at(0)),
874882
std::get<uint16_t>(inst.operands.at(1)),
875883
std::get<uint16_t>(inst.operands.at(2)));
876-
// We hack it in here the logic to change contract address that we are processing
884+
// did the return opcode hit an exceptional halt?
885+
error = ret.error;
877886
if (ret.is_top_level) {
878-
error = ret.error;
879887
returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end());
880-
881-
} else {
888+
} else if (is_ok(error)) {
889+
// switch back to caller's bytecode
882890
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
883-
check_bytecode_membership);
891+
/*check_membership=*/false);
884892
counter = debug_counter_stack.top();
885893
debug_counter_stack.pop();
886894
}
895+
// on error/exceptional-halt, jumping back to parent code is handled at bottom of execution loop
887896
break;
888897
}
889898
case OpCode::REVERT_8: {
890899
info("HIT REVERT_8 ", "[PC=" + std::to_string(pc) + "] " + inst.to_string());
891900
auto ret = trace_builder.op_revert(std::get<uint8_t>(inst.operands.at(0)),
892901
std::get<uint8_t>(inst.operands.at(1)),
893902
std::get<uint8_t>(inst.operands.at(2)));
903+
// error is only set here if the revert opcode hit an exceptional halt
904+
// revert itself does not trigger "error"
905+
error = ret.error;
894906
if (ret.is_top_level) {
895-
error = ret.error;
896907
returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end());
897-
} else {
898-
// change to the current ext call ctx
908+
} else if (is_ok(error)) {
909+
// switch back to caller's bytecode
899910
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
900-
check_bytecode_membership);
911+
/*check_membership=*/false);
901912
counter = debug_counter_stack.top();
902913
debug_counter_stack.pop();
903914
}
904-
915+
// on error/exceptional-halt, jumping back to parent code is handled at bottom of execution loop
905916
break;
906917
}
907918
case OpCode::REVERT_16: {
908919
info("HIT REVERT_16 ", "[PC=" + std::to_string(pc) + "] " + inst.to_string());
909920
auto ret = trace_builder.op_revert(std::get<uint8_t>(inst.operands.at(0)),
910921
std::get<uint16_t>(inst.operands.at(1)),
911922
std::get<uint16_t>(inst.operands.at(2)));
923+
// error is only set here if the revert opcode hit an exceptional halt
924+
// revert itself does not trigger "error"
925+
error = ret.error;
912926
if (ret.is_top_level) {
913-
error = ret.error;
914927
returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end());
915-
} else {
916-
// change to the current ext call ctx
928+
} else if (is_ok(error)) {
929+
// switch back to caller's bytecode
917930
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
918-
check_bytecode_membership);
931+
/*check_membership=*/false);
919932
counter = debug_counter_stack.top();
920933
debug_counter_stack.pop();
921934
}
922-
935+
// on error/exceptional-halt, jumping back to parent code is handled at bottom of execution loop
923936
break;
924937
}
925938

@@ -987,18 +1000,36 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
9871000
".");
9881001
break;
9891002
}
990-
}
991-
if (!is_ok(error)) {
992-
auto const error_ic = counter - 1; // Need adjustement as counter increment occurs in loop body
993-
std::string reason_prefix = exceptionally_halted(error) ? "exceptional halt" : "REVERT opcode";
994-
info("AVM enqueued call halted due to ",
995-
reason_prefix,
996-
". Error: ",
997-
to_name(error),
998-
" at PC: ",
999-
pc,
1000-
" IC: ",
1001-
error_ic);
1003+
1004+
if (!is_ok(error)) {
1005+
const bool is_top_level = trace_builder.current_ext_call_ctx.context_id == 0;
1006+
1007+
auto const error_ic = counter - 1; // Need adjustement as counter increment occurs in loop body
1008+
std::string call_type = is_top_level ? "enqueued" : "nested";
1009+
info("AVM ",
1010+
call_type,
1011+
" call exceptionally halted. Error: ",
1012+
to_name(error),
1013+
" at PC: ",
1014+
pc,
1015+
" IC: ",
1016+
error_ic);
1017+
1018+
trace_builder.handle_exceptional_halt();
1019+
1020+
if (is_top_level) {
1021+
break;
1022+
}
1023+
// otherwise, handle exceptional halt and proceed with execution in caller/parent
1024+
// We hack it in here the logic to change contract address that we are processing
1025+
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
1026+
/*check_membership=*/false);
1027+
counter = debug_counter_stack.top();
1028+
debug_counter_stack.pop();
1029+
1030+
// reset error as we've now returned to caller
1031+
error = AvmError::NO_ERROR;
1032+
}
10021033
}
10031034
return error;
10041035
}

0 commit comments

Comments
 (0)