Skip to content

Commit 2323cd5

Browse files
authored
refactor(bb): use std::span in pippenger for scalars (#8269)
Refactoring stepping stone. Behaves identically Next step would be to use this to allow accessing power of 2 quantities above the std::span size() (with a different wrapper class) so that non-powers-of-2 can be passed directly to pippenger We recently anted to save memory on polynomials. The idea is that instead of rounding up to a power of 2 to make pippenger fast (at cost of memory), we will make a wrapper class that happily pretends it has T{} (i.e. zeroes) anywhere form 0 to nearest rounded up power of 2. For starters this just introduces a std::span, which should behave identically
1 parent 2b8af9e commit 2323cd5

File tree

10 files changed

+97
-75
lines changed

10 files changed

+97
-75
lines changed

barretenberg/cpp/src/barretenberg/benchmark/pippenger_bench/main.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ int pippenger()
7272
scalar_multiplication::pippenger_runtime_state<curve::BN254> state(NUM_POINTS);
7373
std::chrono::steady_clock::time_point time_start = std::chrono::steady_clock::now();
7474
g1::element result = scalar_multiplication::pippenger_unsafe<curve::BN254>(
75-
&scalars[0], reference_string->get_monomial_points(), NUM_POINTS, state);
75+
{ &scalars[0], /*size*/ NUM_POINTS }, reference_string->get_monomial_points(), NUM_POINTS, state);
7676
std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now();
7777
std::chrono::microseconds diff = std::chrono::duration_cast<std::chrono::microseconds>(time_end - time_start);
7878
std::cout << "run time: " << diff.count() << "us" << std::endl;

barretenberg/cpp/src/barretenberg/commitment_schemes/commitment_key.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ template <class Curve> class CommitmentKey {
7979
ASSERT(false);
8080
}
8181
return scalar_multiplication::pippenger_unsafe<Curve>(
82-
const_cast<Fr*>(polynomial.data()), srs->get_monomial_points(), degree, pippenger_runtime_state);
82+
polynomial, srs->get_monomial_points(), degree, pippenger_runtime_state);
8383
};
8484

8585
/**
@@ -146,7 +146,7 @@ template <class Curve> class CommitmentKey {
146146

147147
// Call the version of pippenger which assumes all points are distinct
148148
return scalar_multiplication::pippenger_unsafe<Curve>(
149-
scalars.data(), points.data(), scalars.size(), pippenger_runtime_state);
149+
scalars, points.data(), scalars.size(), pippenger_runtime_state);
150150
}
151151
};
152152

barretenberg/cpp/src/barretenberg/commitment_schemes/ipa/ipa.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,13 @@ template <typename Curve_> class IPA {
215215
// Step 6.a (using letters, because doxygen automaticall converts the sublist counters to letters :( )
216216
// L_i = < a_vec_lo, G_vec_hi > + inner_prod_L * aux_generator
217217
L_i = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
218-
&a_vec[0], &G_vec_local[round_size], round_size, ck->pippenger_runtime_state);
218+
{&a_vec[0], /*size*/ round_size}, &G_vec_local[round_size], round_size, ck->pippenger_runtime_state);
219219
L_i += aux_generator * inner_prod_L;
220220

221221
// Step 6.b
222222
// R_i = < a_vec_hi, G_vec_lo > + inner_prod_R * aux_generator
223223
R_i = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
224-
&a_vec[round_size], &G_vec_local[0], round_size, ck->pippenger_runtime_state);
224+
{&a_vec[round_size], /*size*/ round_size}, &G_vec_local[0], round_size, ck->pippenger_runtime_state);
225225
R_i += aux_generator * inner_prod_R;
226226

227227
// Step 6.c
@@ -345,7 +345,7 @@ template <typename Curve_> class IPA {
345345
// Step 5.
346346
// Compute C₀ = C' + ∑_{j ∈ [k]} u_j^{-1}L_j + ∑_{j ∈ [k]} u_jR_j
347347
GroupElement LR_sums = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
348-
&msm_scalars[0], &msm_elements[0], pippenger_size, vk->pippenger_runtime_state);
348+
{&msm_scalars[0], /*size*/ pippenger_size}, &msm_elements[0], pippenger_size, vk->pippenger_runtime_state);
349349
GroupElement C_zero = C_prime + LR_sums;
350350

351351
// Step 6.
@@ -394,7 +394,7 @@ template <typename Curve_> class IPA {
394394
// Step 8.
395395
// Compute G₀
396396
Commitment G_zero = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
397-
&s_vec[0], &G_vec_local[0], poly_length, vk->pippenger_runtime_state);
397+
{&s_vec[0], /*size*/ poly_length}, &G_vec_local[0], poly_length, vk->pippenger_runtime_state);
398398

399399
// Step 9.
400400
// Receive a₀ from the prover

barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.cpp

+27-25
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ template <typename Curve>
199199
void compute_wnaf_states(uint64_t* point_schedule,
200200
bool* input_skew_table,
201201
uint64_t* round_counts,
202-
const typename Curve::ScalarField* scalars,
202+
const std::span<const typename Curve::ScalarField> scalars,
203203
const size_t num_initial_points)
204204
{
205205
using Fr = typename Curve::ScalarField;
@@ -857,7 +857,7 @@ typename Curve::Element evaluate_pippenger_rounds(pippenger_runtime_state<Curve>
857857

858858
template <typename Curve>
859859
typename Curve::Element pippenger_internal(typename Curve::AffineElement* points,
860-
typename Curve::ScalarField* scalars,
860+
std::span<const typename Curve::ScalarField> scalars,
861861
const size_t num_initial_points,
862862
pippenger_runtime_state<Curve>& state,
863863
bool handle_edge_cases)
@@ -871,7 +871,7 @@ typename Curve::Element pippenger_internal(typename Curve::AffineElement* points
871871
}
872872

873873
template <typename Curve>
874-
typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
874+
typename Curve::Element pippenger(std::span<const typename Curve::ScalarField> scalars,
875875
typename Curve::AffineElement* points,
876876
const size_t num_initial_points,
877877
pippenger_runtime_state<Curve>& state,
@@ -910,10 +910,9 @@ typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
910910
const auto num_slice_points = static_cast<size_t>(1ULL << slice_bits);
911911

912912
Element result = pippenger_internal(points, scalars, num_slice_points, state, handle_edge_cases);
913-
914913
if (num_slice_points != num_initial_points) {
915914
const uint64_t leftover_points = num_initial_points - num_slice_points;
916-
return result + pippenger(scalars + num_slice_points,
915+
return result + pippenger(scalars.subspan(num_slice_points),
917916
points + static_cast<size_t>(num_slice_points * 2),
918917
static_cast<size_t>(leftover_points),
919918
state,
@@ -938,7 +937,7 @@ typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
938937
*
939938
**/
940939
template <typename Curve>
941-
typename Curve::Element pippenger_unsafe(typename Curve::ScalarField* scalars,
940+
typename Curve::Element pippenger_unsafe(std::span<const typename Curve::ScalarField> scalars,
942941
typename Curve::AffineElement* points,
943942
const size_t num_initial_points,
944943
pippenger_runtime_state<Curve>& state)
@@ -947,10 +946,11 @@ typename Curve::Element pippenger_unsafe(typename Curve::ScalarField* scalars,
947946
}
948947

949948
template <typename Curve>
950-
typename Curve::Element pippenger_without_endomorphism_basis_points(typename Curve::ScalarField* scalars,
951-
typename Curve::AffineElement* points,
952-
const size_t num_initial_points,
953-
pippenger_runtime_state<Curve>& state)
949+
typename Curve::Element pippenger_without_endomorphism_basis_points(
950+
std::span<const typename Curve::ScalarField> scalars,
951+
typename Curve::AffineElement* points,
952+
const size_t num_initial_points,
953+
pippenger_runtime_state<Curve>& state)
954954
{
955955
std::vector<typename Curve::AffineElement> G_mod(num_initial_points * 2);
956956
bb::scalar_multiplication::generate_pippenger_point_table<Curve>(points, &G_mod[0], num_initial_points);
@@ -978,7 +978,7 @@ template void evaluate_addition_chains<curve::BN254>(affine_product_runtime_stat
978978
const size_t max_bucket_bits,
979979
bool handle_edge_cases);
980980
template curve::BN254::Element pippenger_internal<curve::BN254>(curve::BN254::AffineElement* points,
981-
curve::BN254::ScalarField* scalars,
981+
std::span<const curve::BN254::ScalarField> scalars,
982982
const size_t num_initial_points,
983983
pippenger_runtime_state<curve::BN254>& state,
984984
bool handle_edge_cases);
@@ -992,19 +992,19 @@ template curve::BN254::AffineElement* reduce_buckets<curve::BN254>(affine_produc
992992
bool first_round = true,
993993
bool handle_edge_cases = false);
994994

995-
template curve::BN254::Element pippenger<curve::BN254>(curve::BN254::ScalarField* scalars,
995+
template curve::BN254::Element pippenger<curve::BN254>(std::span<const curve::BN254::ScalarField> scalars,
996996
curve::BN254::AffineElement* points,
997997
const size_t num_points,
998998
pippenger_runtime_state<curve::BN254>& state,
999999
bool handle_edge_cases = true);
10001000

1001-
template curve::BN254::Element pippenger_unsafe<curve::BN254>(curve::BN254::ScalarField* scalars,
1001+
template curve::BN254::Element pippenger_unsafe<curve::BN254>(std::span<const curve::BN254::ScalarField> scalars,
10021002
curve::BN254::AffineElement* points,
10031003
const size_t num_initial_points,
10041004
pippenger_runtime_state<curve::BN254>& state);
10051005

10061006
template curve::BN254::Element pippenger_without_endomorphism_basis_points<curve::BN254>(
1007-
curve::BN254::ScalarField* scalars,
1007+
std::span<const curve::BN254::ScalarField> scalars,
10081008
curve::BN254::AffineElement* points,
10091009
const size_t num_initial_points,
10101010
pippenger_runtime_state<curve::BN254>& state);
@@ -1028,11 +1028,12 @@ template void add_affine_points_with_edge_cases<curve::Grumpkin>(curve::Grumpkin
10281028
template void evaluate_addition_chains<curve::Grumpkin>(affine_product_runtime_state<curve::Grumpkin>& state,
10291029
const size_t max_bucket_bits,
10301030
bool handle_edge_cases);
1031-
template curve::Grumpkin::Element pippenger_internal<curve::Grumpkin>(curve::Grumpkin::AffineElement* points,
1032-
curve::Grumpkin::ScalarField* scalars,
1033-
const size_t num_initial_points,
1034-
pippenger_runtime_state<curve::Grumpkin>& state,
1035-
bool handle_edge_cases);
1031+
template curve::Grumpkin::Element pippenger_internal<curve::Grumpkin>(
1032+
curve::Grumpkin::AffineElement* points,
1033+
std::span<const curve::Grumpkin::ScalarField> scalars,
1034+
const size_t num_initial_points,
1035+
pippenger_runtime_state<curve::Grumpkin>& state,
1036+
bool handle_edge_cases);
10361037

10371038
template curve::Grumpkin::Element evaluate_pippenger_rounds<curve::Grumpkin>(
10381039
pippenger_runtime_state<curve::Grumpkin>& state,
@@ -1043,19 +1044,20 @@ template curve::Grumpkin::Element evaluate_pippenger_rounds<curve::Grumpkin>(
10431044
template curve::Grumpkin::AffineElement* reduce_buckets<curve::Grumpkin>(
10441045
affine_product_runtime_state<curve::Grumpkin>& state, bool first_round = true, bool handle_edge_cases = false);
10451046

1046-
template curve::Grumpkin::Element pippenger<curve::Grumpkin>(curve::Grumpkin::ScalarField* scalars,
1047+
template curve::Grumpkin::Element pippenger<curve::Grumpkin>(std::span<const curve::Grumpkin::ScalarField> scalars,
10471048
curve::Grumpkin::AffineElement* points,
10481049
const size_t num_points,
10491050
pippenger_runtime_state<curve::Grumpkin>& state,
10501051
bool handle_edge_cases = true);
10511052

1052-
template curve::Grumpkin::Element pippenger_unsafe<curve::Grumpkin>(curve::Grumpkin::ScalarField* scalars,
1053-
curve::Grumpkin::AffineElement* points,
1054-
const size_t num_initial_points,
1055-
pippenger_runtime_state<curve::Grumpkin>& state);
1053+
template curve::Grumpkin::Element pippenger_unsafe<curve::Grumpkin>(
1054+
std::span<const curve::Grumpkin::ScalarField> scalars,
1055+
curve::Grumpkin::AffineElement* points,
1056+
const size_t num_initial_points,
1057+
pippenger_runtime_state<curve::Grumpkin>& state);
10561058

10571059
template curve::Grumpkin::Element pippenger_without_endomorphism_basis_points<curve::Grumpkin>(
1058-
curve::Grumpkin::ScalarField* scalars,
1060+
std::span<const curve::Grumpkin::ScalarField> scalars,
10591061
curve::Grumpkin::AffineElement* points,
10601062
const size_t num_initial_points,
10611063
pippenger_runtime_state<curve::Grumpkin>& state);

barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.hpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ template <typename Curve>
8989
void compute_wnaf_states(uint64_t* point_schedule,
9090
bool* input_skew_table,
9191
uint64_t* round_counts,
92-
const typename Curve::ScalarField* scalars,
92+
std::span<const typename Curve::ScalarField> scalars,
9393
size_t num_initial_points);
9494

9595
template <typename Curve>
@@ -135,7 +135,7 @@ void evaluate_addition_chains(affine_product_runtime_state<Curve>& state,
135135
bool handle_edge_cases);
136136
template <typename Curve>
137137
typename Curve::Element pippenger_internal(typename Curve::AffineElement* points,
138-
typename Curve::ScalarField* scalars,
138+
std::span<const typename Curve::ScalarField> scalars,
139139
size_t num_initial_points,
140140
pippenger_runtime_state<Curve>& state,
141141
bool handle_edge_cases);
@@ -152,23 +152,24 @@ typename Curve::AffineElement* reduce_buckets(affine_product_runtime_state<Curve
152152
bool handle_edge_cases = false);
153153

154154
template <typename Curve>
155-
typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
155+
typename Curve::Element pippenger(std::span<const typename Curve::ScalarField> scalars,
156156
typename Curve::AffineElement* points,
157157
size_t num_initial_points,
158158
pippenger_runtime_state<Curve>& state,
159159
bool handle_edge_cases = true);
160160

161161
template <typename Curve>
162-
typename Curve::Element pippenger_unsafe(typename Curve::ScalarField* scalars,
162+
typename Curve::Element pippenger_unsafe(std::span<const typename Curve::ScalarField> scalars,
163163
typename Curve::AffineElement* points,
164164
size_t num_initial_points,
165165
pippenger_runtime_state<Curve>& state);
166166

167167
template <typename Curve>
168-
typename Curve::Element pippenger_without_endomorphism_basis_points(typename Curve::ScalarField* scalars,
169-
typename Curve::AffineElement* points,
170-
size_t num_initial_points,
171-
pippenger_runtime_state<Curve>& state);
168+
typename Curve::Element pippenger_without_endomorphism_basis_points(
169+
std::span<const typename Curve::ScalarField> scalars,
170+
typename Curve::AffineElement* points,
171+
size_t num_initial_points,
172+
pippenger_runtime_state<Curve>& state);
172173

173174
// Explicit instantiation
174175
// BN254

barretenberg/cpp/src/barretenberg/plonk/proof_system/verifier/verifier.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ template <typename program_settings> bool VerifierBase<program_settings>::verify
182182

183183
g1::element P[2];
184184

185-
P[0] = bb::scalar_multiplication::pippenger<curve::BN254>(&scalars[0], &elements[0], num_elements, state);
185+
P[0] = bb::scalar_multiplication::pippenger<curve::BN254>(
186+
{ &scalars[0], num_elements }, &elements[0], num_elements, state);
186187
P[1] = -(g1::element(PI_Z_OMEGA) * separator_challenge + PI_Z);
187188

188189
if (key->contains_recursive_proof) {

barretenberg/cpp/src/barretenberg/plonk/proof_system/verifier/verifier.test.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ plonk::Verifier generate_verifier(std::shared_ptr<proving_key> circuit_proving_k
3333
commitments.resize(8);
3434

3535
for (size_t i = 0; i < 8; ++i) {
36-
commitments[i] = g1::affine_element(
37-
scalar_multiplication::pippenger<curve::BN254>(poly_coefficients[i].get(),
38-
circuit_proving_key->reference_string->get_monomial_points(),
39-
circuit_proving_key->circuit_size,
40-
state));
36+
commitments[i] = g1::affine_element(scalar_multiplication::pippenger<curve::BN254>(
37+
{ poly_coefficients[i].get(), circuit_proving_key->circuit_size },
38+
circuit_proving_key->reference_string->get_monomial_points(),
39+
circuit_proving_key->circuit_size,
40+
state));
4141
}
4242

4343
auto crs = std::make_shared<bb::srs::factories::FileVerifierCrs<curve::BN254>>("../srs_db/ignition");

barretenberg/cpp/src/barretenberg/plonk/work_queue/work_queue.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ void work_queue::process_queue()
214214
// Run pippenger multi-scalar multiplication.
215215
auto runtime_state = bb::scalar_multiplication::pippenger_runtime_state<curve::BN254>(msm_size);
216216
bb::g1::affine_element result(bb::scalar_multiplication::pippenger_unsafe<curve::BN254>(
217-
item.mul_scalars.get(), srs_points, msm_size, runtime_state));
217+
{ item.mul_scalars.get(), msm_size }, srs_points, msm_size, runtime_state));
218218

219219
transcript->add_element(item.tag, result.to_buffer());
220220

0 commit comments

Comments
 (0)