Skip to content

Commit ed6069d

Browse files
authored
fix: pippenger edge case (#22256)
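Fixes a rare Pippenger edge case: the recursive radix sort passed `keys` instead of `top_level_keys` when recursing, which could overwrite the zero-entry count when the sort uses 3+ recursion levels.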
1 parent 0d1b322 commit ed6069d

File tree

4 files changed

+197
-4
lines changed

4 files changed

+197
-4
lines changed

barretenberg/cpp/src/barretenberg/benchmark/goblin_bench/eccvm.bench.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#include <benchmark/benchmark.h>
22

3+
#include "barretenberg/commitment_schemes/ipa/ipa.hpp"
4+
#include "barretenberg/ecc/curves/bn254/fq.hpp"
35
#include "barretenberg/eccvm/eccvm_circuit_builder.hpp"
46
#include "barretenberg/eccvm/eccvm_prover.hpp"
57
#include "barretenberg/eccvm/eccvm_verifier.hpp"
8+
#include "barretenberg/srs/global_crs.hpp"
69

710
using namespace benchmark;
811
using namespace bb;
@@ -40,6 +43,9 @@ Builder generate_trace(size_t target_num_gates)
4043
op_queue->merge();
4144
}
4245

46+
using Fq = curve::BN254::BaseField;
47+
op_queue->append_hiding_op(Fq::random_element(), Fq::random_element());
48+
4349
Builder builder{ op_queue };
4450
return builder;
4551
}
@@ -63,12 +69,35 @@ void eccvm_prove(State& state) noexcept
6369
std::shared_ptr<Transcript> prover_transcript = std::make_shared<Transcript>();
6470
ECCVMProver prover(builder, prover_transcript);
6571
for (auto _ : state) {
66-
auto [proof, ipa_claim] = prover.construct_proof();
72+
auto [proof, opening_claim] = prover.construct_proof();
73+
auto ipa_transcript = std::make_shared<Transcript>();
74+
IPA<Flavor::Curve>::compute_opening_proof(prover.key->commitment_key, opening_claim, ipa_transcript);
75+
};
76+
}
77+
78+
void eccvm_ipa(State& state) noexcept
79+
{
80+
size_t target_num_gates = 1 << static_cast<size_t>(state.range(0));
81+
Builder builder = generate_trace(target_num_gates);
82+
std::shared_ptr<Transcript> prover_transcript = std::make_shared<Transcript>();
83+
ECCVMProver prover(builder, prover_transcript);
84+
auto [proof, opening_claim] = prover.construct_proof();
85+
for (auto _ : state) {
86+
auto ipa_transcript = std::make_shared<Transcript>();
87+
IPA<Flavor::Curve>::compute_opening_proof(prover.key->commitment_key, opening_claim, ipa_transcript);
6788
};
6889
}
6990

7091
BENCHMARK(eccvm_generate_prover)->Unit(kMillisecond)->DenseRange(12, CONST_ECCVM_LOG_N);
7192
BENCHMARK(eccvm_prove)->Unit(kMillisecond)->DenseRange(12, CONST_ECCVM_LOG_N);
93+
BENCHMARK(eccvm_ipa)->Unit(kMillisecond)->DenseRange(12, CONST_ECCVM_LOG_N);
7294
} // namespace
7395

74-
BENCHMARK_MAIN();
96+
int main(int argc, char** argv)
97+
{
98+
bb::srs::init_file_crs_factory(bb::srs::bb_crs_path());
99+
benchmark::Initialize(&argc, argv);
100+
benchmark::RunSpecifiedBenchmarks();
101+
benchmark::Shutdown();
102+
return 0;
103+
}

barretenberg/cpp/src/barretenberg/commitment_schemes/commitment_key.test.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
#include "barretenberg/commitment_schemes/commitment_key.hpp"
3+
#include "barretenberg/common/thread.hpp"
34
#include "barretenberg/srs/global_crs.hpp"
45

56
#include <gtest/gtest.h>
@@ -125,6 +126,89 @@ template <typename Curve> class CommitmentKeyTest : public ::testing::Test {
125126
Commitment expected = commit_naive(ck, poly);
126127
EXPECT_EQ(expected, commitment);
127128
}
129+
130+
// Regression test for a zero-counting bug in Pippenger's MSD radix sort
131+
// (sort_point_schedule_and_count_zero_buckets in process_buckets.cpp).
132+
//
133+
// The bug: the recursive radix sort passed `keys` instead of `top_level_keys` when recursing,
134+
// causing the zero-entry counter to be overwritten by non-zero-bucket counts when the sort
135+
// uses 3+ recursion levels. The inflated count makes the MSM skip valid point contributions.
136+
//
137+
// When does 3-level recursion occur?
138+
// - Pippenger chooses bits_per_slice via a cost model (get_optimal_log_num_buckets).
139+
// - bits_per_slice > 16 pads to 24 bits -> initial_shift=16 -> 3 levels (shift 16->8->0).
140+
// - For BN254 (254-bit scalars), bits_per_slice=17 at ~4.6M+ points per work unit.
141+
// - Multi-threading splits MSM across cores, so each work unit is total_points/num_threads.
142+
// On a 32-core machine, a single work unit reaches 4.6M at ~150M total points.
143+
// - Single-threaded execution (WASM, resource-constrained environments) hits the threshold
144+
// at 4.6M points directly.
145+
//
146+
// Polynomial design (deterministic, all coefficients non-zero):
147+
// get_scalar_slice extracts bits MSB-first. With bits_per_slice=17 and 15 rounds for BN254,
148+
// round 13 extracts bits [16:33) of each scalar. We choose scalar values so that round 13
149+
// has the bucket distribution needed to trigger the overwrite:
150+
//
151+
// 100 coefficients = Fr(1) -> bits [16:33) = 0 -> bucket_index = 0
152+
// 10 coefficients = Fr(2^16) -> bits [16:33) = 1 -> bucket_index = 1 [DROPPED]
153+
// ~5M coefficients = Fr(2^32) -> bits [16:33) = 2^16 -> bucket_index = 65536 [OVERWRITES]
154+
//
155+
// Fr(1) entries must be non-zero (zero scalars are filtered before the MSM) but still
156+
// land in bucket 0 for round 13. They ensure point_schedule[0] has bucket_index=0 after
157+
// sorting, bypassing the post-sort safety check in sort_point_schedule_and_count_zero_buckets.
158+
//
159+
// The bug overwrites num_zero_entries from 100 (correct) to ~5M (count at bucket 65536).
160+
// The MSM span then starts ~5M entries into the sorted schedule, skipping all 10 target
161+
// entries with bucket_index=1 and silently dropping their contributions.
162+
//
163+
// This layout is chosen for efficiency (~1.5s) and full determinism (no random scalars).
164+
// The reference commitment is computed by chunking into 1M-point sub-MSMs, each using
165+
// bits_per_slice <= 15 (2-level sort, bug-free).
166+
void test_pippenger_zero_count_regression()
167+
{
168+
constexpr size_t n = 5000000;
169+
CK ck(n);
170+
171+
Polynomial poly(n);
172+
173+
constexpr size_t num_fake_zeros = 100;
174+
for (size_t i = 0; i < num_fake_zeros; ++i) {
175+
poly.at(i) = Fr(1);
176+
}
177+
178+
constexpr size_t num_targets = 10;
179+
for (size_t i = num_fake_zeros; i < num_fake_zeros + num_targets; ++i) {
180+
poly.at(i) = Fr(65536);
181+
}
182+
183+
for (size_t i = num_fake_zeros + num_targets; i < n; ++i) {
184+
poly.at(i) = Fr(uint256_t(1) << 32);
185+
}
186+
187+
// Commit single-threaded to keep the full point set in one work unit
188+
size_t original_concurrency = get_num_cpus();
189+
set_parallel_for_concurrency(1);
190+
Commitment actual_commitment = ck.commit(poly);
191+
set_parallel_for_concurrency(original_concurrency);
192+
193+
// Reference: sum of chunked sub-MSMs (each chunk uses bits_per_slice <= 15, bug-free)
194+
constexpr size_t chunk_size = 1UL << 20;
195+
auto srs_points = ck.get_monomial_points();
196+
GroupElement correct_sum;
197+
correct_sum.self_set_infinity();
198+
199+
for (size_t offset = 0; offset < n; offset += chunk_size) {
200+
size_t this_chunk = std::min(chunk_size, n - offset);
201+
std::span<const Fr> chunk_coeffs(&poly[offset], this_chunk);
202+
PolynomialSpan<const Fr> chunk_span(0, chunk_coeffs);
203+
std::span<const Commitment> chunk_points = srs_points.subspan(offset, this_chunk);
204+
205+
auto chunk_result = scalar_multiplication::pippenger_unsafe<Curve>(chunk_span, chunk_points);
206+
correct_sum += chunk_result;
207+
}
208+
Commitment correct_commitment(correct_sum);
209+
210+
EXPECT_EQ(actual_commitment, correct_commitment);
211+
}
128212
};
129213

130214
using Curves = ::testing::Types<curve::BN254, curve::Grumpkin>;
@@ -154,5 +238,16 @@ TYPED_TEST(CommitmentKeyTest, CommitWithStartIndex)
154238
{
155239
TestFixture::test_commit_with_start_index();
156240
}
241+
TYPED_TEST(CommitmentKeyTest, DISABLED_PippengerZeroCountRegression)
242+
{
243+
if constexpr (!std::is_same_v<TypeParam, curve::BN254>) {
244+
GTEST_SKIP() << "BN254 only: Grumpkin CRS has insufficient points for the 5M threshold";
245+
}
246+
#ifndef NDEBUG
247+
GTEST_SKIP() << "Too slow in debug builds";
248+
#else
249+
TestFixture::test_pippenger_zero_count_regression();
250+
#endif
251+
}
157252

158253
} // namespace bb

barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/process_buckets.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,12 @@ void radix_sort_count_zero_entries(uint64_t* keys,
6969
for (size_t i = 0; i < NUM_RADIX_BUCKETS; ++i) {
7070
const size_t bucket_size = offsets_copy[i + 1] - offsets_copy[i];
7171
if (bucket_size > 1) {
72-
radix_sort_count_zero_entries(
73-
&keys[offsets_copy[i]], bucket_size, shift - RADIX_BITS, num_zero_entries, bucket_index_bits, keys);
72+
radix_sort_count_zero_entries(&keys[offsets_copy[i]],
73+
bucket_size,
74+
shift - RADIX_BITS,
75+
num_zero_entries,
76+
bucket_index_bits,
77+
top_level_keys);
7478
}
7579
}
7680
}

barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.test.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,67 @@ template <class Curve> class ScalarMultiplicationTest : public ::testing::Test {
224224
}
225225
}
226226

227+
// Regression test: radix sort zero-counting bug for bucket_index_bits > 16 (3+ recursion levels).
228+
// The recursive call used to pass `keys` instead of `top_level_keys`, causing num_zero_entries to be
229+
// overwritten by non-zero-bucket counts when the MSD radix sort recurses 3+ levels deep.
230+
void test_radix_sort_count_zero_entries_wide_buckets()
231+
{
232+
// Use bucket_index_bits = 17, which pads to 24 bits → 3 recursion levels (shift: 16→8→0).
233+
// At the 3rd level, the top_level_keys bug causes zero-counting to fire for every
234+
// level-0 bucket's sub-bucket-0, not just the bucket-0 chain.
235+
constexpr uint32_t bucket_index_bits = 17;
236+
constexpr size_t num_entries = 1000;
237+
238+
std::vector<uint64_t> schedule(num_entries);
239+
240+
// Place some entries with bucket_index = 0 (true zero-bucket entries)
241+
const size_t num_true_zeros = 10;
242+
for (size_t i = 0; i < num_true_zeros; ++i) {
243+
schedule[i] = static_cast<uint64_t>(i) << 32; // point_index=i, bucket_index=0
244+
}
245+
246+
// Place entries with bucket_index = 65536 (= 1 << 16). Their low 16 bits, bits [0:16), are all zero,
247+
// so at the final recursion level the buggy code treats them as zero-bucket entries and
248+
// overwrites num_zero_entries with the count from the level-0 bucket 1 path.
249+
const size_t num_false_zeros = 20;
250+
for (size_t i = 0; i < num_false_zeros; ++i) {
251+
size_t idx = num_true_zeros + i;
252+
schedule[idx] = (static_cast<uint64_t>(idx) << 32) | 65536ULL;
253+
}
254+
255+
// Fill remaining entries with random non-zero bucket indices that won't confuse the count
256+
for (size_t i = num_true_zeros + num_false_zeros; i < num_entries; ++i) {
257+
uint32_t bucket = (engine.get_random_uint32() % ((1U << bucket_index_bits) - 1)) + 1;
258+
// Avoid bucket_index values with all lower 16 bits zero (i.e., multiples of 65536)
259+
if ((bucket & 0xFFFF) == 0) {
260+
bucket |= 1;
261+
}
262+
schedule[i] = (static_cast<uint64_t>(i) << 32) | static_cast<uint64_t>(bucket);
263+
}
264+
265+
size_t result = scalar_multiplication::sort_point_schedule_and_count_zero_buckets(
266+
schedule.data(), num_entries, bucket_index_bits);
267+
268+
// Count actual zero-bucket entries after sort
269+
size_t expected = 0;
270+
for (size_t i = 0; i < num_entries; ++i) {
271+
if ((schedule[i] & scalar_multiplication::BUCKET_INDEX_MASK) == 0) {
272+
expected++;
273+
}
274+
}
275+
276+
EXPECT_EQ(result, expected) << "Zero-bucket count is wrong for bucket_index_bits=" << bucket_index_bits
277+
<< ". Got " << result << ", expected " << expected
278+
<< " (likely overwritten by count from a non-zero bucket)";
279+
280+
// Also verify the array is sorted
281+
for (size_t i = 1; i < num_entries; ++i) {
282+
uint32_t prev = static_cast<uint32_t>(schedule[i - 1]);
283+
uint32_t curr = static_cast<uint32_t>(schedule[i]);
284+
EXPECT_LE(prev, curr) << "Array not sorted at index " << i;
285+
}
286+
}
287+
227288
void test_pippenger_low_memory()
228289
{
229290
std::span<ScalarField> test_scalars(&scalars[0], num_points);
@@ -571,6 +632,10 @@ TYPED_TEST(ScalarMultiplicationTest, RadixSortCountZeroEntries)
571632
{
572633
this->test_radix_sort_count_zero_entries();
573634
}
635+
TYPED_TEST(ScalarMultiplicationTest, RadixSortCountZeroEntriesWideBuckets)
636+
{
637+
this->test_radix_sort_count_zero_entries_wide_buckets();
638+
}
574639
TYPED_TEST(ScalarMultiplicationTest, PippengerLowMemory)
575640
{
576641
this->test_pippenger_low_memory();

0 commit comments

Comments
 (0)