
Commit 0a2a38c

feat: merge-train/barretenberg (#22888)
BEGIN_COMMIT_OVERRIDE chore: chunk scalars in pip to distribute work evenly (#22627) END_COMMIT_OVERRIDE
2 parents e3ce6bc + 0c188ca commit 0a2a38c

4 files changed

Lines changed: 403 additions & 41 deletions


Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
/**
 * @brief Pippenger thread-scaling benchmark for heterogeneous scalar distributions.
 *
 * MSM::batch_multi_scalar_mul partitions work across threads by cumulative per-scalar
 * weight (see get_work_units in scalar_multiplication.cpp), where each scalar's weight
 * is ceil(bit_length / bits_per_slice) -- i.e. the number of nonzero c-bit slices it
 * contributes to bucket accumulation. Small scalars weigh less because their high-order
 * slices are zero and get filtered by the zero-bucket pre-sort. This benchmark exercises
 * pathological and typical bit-size distributions to verify thread scaling stays uniform.
 *
 * Distributions contrasted here:
 * - Clustered: first half small (32-bit), second half full random -- stresses the
 *   weighted split; count-based partitioning would give half the threads
 *   ~all of the heavy work.
 * - UniformMixed: small/full randomly interleaved -- isolates heterogeneity alone.
 * - AllFull: all full random (z_perm-like baseline).
 *
 * Expected: all three scale comparably under the weighted partition.
 */
#include "barretenberg/common/thread.hpp"
#include "barretenberg/ecc/curves/bn254/bn254.hpp"
#include "barretenberg/ecc/scalar_multiplication/scalar_multiplication.hpp"
#include "barretenberg/numeric/random/engine.hpp"
#include "barretenberg/srs/global_crs.hpp"

#include <benchmark/benchmark.h>

#include "barretenberg/common/google_bb_bench.hpp"

using namespace benchmark;

using Curve = bb::curve::BN254;
using Fr = Curve::ScalarField;
using G1 = Curve::AffineElement;

namespace {

constexpr size_t MSM_SIZE = 1 << 20;

enum class Distribution { Clustered, UniformMixed, AllFull };

class ThreadScalingBench : public benchmark::Fixture {
  public:
    std::shared_ptr<bb::srs::factories::Crs<Curve>> srs;
    bb::numeric::RNG& engine = bb::numeric::get_debug_randomness();

    void SetUp([[maybe_unused]] const ::benchmark::State& state) override
    {
        if (srs) {
            return;
        }
        bb::srs::init_file_crs_factory(bb::srs::bb_crs_path());
        srs = bb::srs::get_crs_factory<Curve>()->get_crs(MSM_SIZE);
    }

    // 32-bit "small" value -- mimics witness indices, booleans, limbs.
    // On BN254 (254-bit field) with ~14 bits per Pippenger slice, only the lowest
    // ~2-3 rounds produce nonzero slices for these scalars; the rest get filtered.
    Fr small_scalar() { return Fr(static_cast<uint64_t>(engine.get_random_uint32())); }
    Fr full_scalar() { return Fr::random_element(&engine); }

    std::vector<Fr> build_scalars(Distribution dist)
    {
        std::vector<Fr> scalars(MSM_SIZE);
        switch (dist) {
        case Distribution::Clustered:
            for (size_t i = 0; i < MSM_SIZE / 2; ++i) {
                scalars[i] = small_scalar();
            }
            for (size_t i = MSM_SIZE / 2; i < MSM_SIZE; ++i) {
                scalars[i] = full_scalar();
            }
            break;
        case Distribution::UniformMixed:
            for (size_t i = 0; i < MSM_SIZE; ++i) {
                scalars[i] = (engine.get_random_uint32() & 1U) ? small_scalar() : full_scalar();
            }
            break;
        case Distribution::AllFull:
            for (size_t i = 0; i < MSM_SIZE; ++i) {
                scalars[i] = full_scalar();
            }
            break;
        }
        return scalars;
    }
};

static void run_msm(ThreadScalingBench& fx, benchmark::State& state, Distribution dist)
{
    const size_t num_threads = static_cast<size_t>(state.range(0));

    // Rebuild per-invocation of the bench is fine: scalars get mutated (Montgomery
    // round-trip) inside batch_multi_scalar_mul, and we want consistent input across iterations.
    std::vector<Fr> scalars = fx.build_scalars(dist);

    std::vector<std::span<Fr>> scalar_spans;
    std::vector<std::span<const G1>> point_spans;
    scalar_spans.emplace_back(scalars);
    point_spans.emplace_back(fx.srs->get_monomial_points().subspan(0, MSM_SIZE));

    const size_t original_concurrency = bb::get_num_cpus();
    bb::set_parallel_for_concurrency(num_threads);

    for (auto _ : state) {
        GOOGLE_BB_BENCH_REPORTER(state);
        bb::scalar_multiplication::MSM<Curve>::batch_multi_scalar_mul(point_spans, scalar_spans, false);
    }

    bb::set_parallel_for_concurrency(original_concurrency);
}

BENCHMARK_DEFINE_F(ThreadScalingBench, Clustered)(benchmark::State& state)
{
    run_msm(*this, state, Distribution::Clustered);
}
BENCHMARK_DEFINE_F(ThreadScalingBench, UniformMixed)(benchmark::State& state)
{
    run_msm(*this, state, Distribution::UniformMixed);
}
BENCHMARK_DEFINE_F(ThreadScalingBench, AllFull)(benchmark::State& state)
{
    run_msm(*this, state, Distribution::AllFull);
}

static void ThreadSweep(benchmark::internal::Benchmark* b)
{
    for (int64_t t : { 1, 2, 4, 8 }) {
        b->Arg(t);
    }
}

BENCHMARK_REGISTER_F(ThreadScalingBench, Clustered)->Unit(benchmark::kMillisecond)->Apply(ThreadSweep);
BENCHMARK_REGISTER_F(ThreadScalingBench, UniformMixed)->Unit(benchmark::kMillisecond)->Apply(ThreadSweep);
BENCHMARK_REGISTER_F(ThreadScalingBench, AllFull)->Unit(benchmark::kMillisecond)->Apply(ThreadSweep);

} // namespace

BENCHMARK_MAIN();
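
To put numbers on the Clustered case above: assuming roughly 14 bits per Pippenger slice for an MSM of this size (a representative value; the real bits_per_slice comes from get_optimal_log_num_buckets) and the fixed per-scalar weight of 4 used in scalar_multiplication.cpp, a 32-bit scalar weighs ceil(32/14) + 4 = 7 while a full-width scalar weighs ceil(254/14) + 4 = 23. The small half of the Clustered input therefore carries only about 23% of the total weight, which is why a count-based split would leave half the threads with nearly all of the heavy bucket work. A minimal, self-contained sketch of that arithmetic (c = 14 and the +4 term are assumed values mirroring this commit, not constants read from the library):

#include <cstddef>

// Illustrative only: the constants are assumptions, not library values.
constexpr std::size_t ceil_div(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

constexpr std::size_t c = 14;                               // assumed bits per Pippenger slice
constexpr std::size_t small_weight = ceil_div(32, c) + 4;   // 32-bit scalar   -> 3 + 4 = 7
constexpr std::size_t full_weight = ceil_div(254, c) + 4;   // ~254-bit scalar -> 19 + 4 = 23

static_assert(small_weight == 7);
static_assert(full_weight == 23);
// In the Clustered distribution, half the scalars are small and half are full, so the
// small half carries small_weight / (small_weight + full_weight) = 7/30 of the weight.
static_assert(small_weight * 100 / (small_weight + full_weight) == 23); // ~23%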

barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.cpp

Lines changed: 97 additions & 40 deletions
@@ -85,29 +85,116 @@ void MSM<Curve>::transform_scalar_and_get_nonzero_scalar_indices(std::span<typen
     });
 }

+template <typename Curve>
+void MSM<Curve>::compute_scalar_slice_weights(std::span<const typename Curve::ScalarField> scalars,
+                                              std::span<const uint32_t> nonzero_indices,
+                                              uint32_t bits_per_slice,
+                                              std::vector<uint16_t>& weights) noexcept
+{
+    // weight = ceil(bit_length / bps) + FIXED_PER_SCALAR_WEIGHT. The fixed term approximates the
+    // O(num_rounds) per-scalar overhead in build_schedule, sort_schedule, and reduce_buckets that
+    // doesn't scale with bit_length. Without it, threads assigned many lightweight scalars end up
+    // with disproportionate build/sort/reduce work (empirically observed via per-phase profiling).
+    // Max is ceil(NUM_BITS_IN_FIELD / 1) + FIXED.
+    static constexpr uint16_t FIXED_PER_SCALAR_WEIGHT = 4;
+    static_assert(NUM_BITS_IN_FIELD + FIXED_PER_SCALAR_WEIGHT <= std::numeric_limits<uint16_t>::max(),
+                  "slice-count weight overflows uint16_t");
+    BB_ASSERT_GT(bits_per_slice, 0U);
+
+    const size_t n = nonzero_indices.size();
+    weights.resize(n);
+
+    parallel_for([&](const ThreadChunk& chunk) {
+        for (size_t k : chunk.range(n)) {
+            const auto& scalar = scalars[nonzero_indices[k]];
+            // Scalars were filtered for nonzero and are in non-Montgomery form, so get_msb()
+            // returns a valid bit index in [0, NUM_BITS_IN_FIELD).
+            const uint64_t msb = uint256_t{ scalar.data[0], scalar.data[1], scalar.data[2], scalar.data[3] }.get_msb();
+            const size_t bit_length = static_cast<size_t>(msb) + 1;
+            weights[k] =
+                static_cast<uint16_t>((bit_length + bits_per_slice - 1) / bits_per_slice) + FIXED_PER_SCALAR_WEIGHT;
+        }
+    });
+}
+
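The FIXED_PER_SCALAR_WEIGHT term controls how strongly small scalars are discounted. With an assumed c = 14, a 32-bit scalar costs 3 slices versus 19 for a full-width scalar, so without the fixed term a thread could be handed roughly six small scalars for every full scalar on another thread and would pay six times the per-scalar schedule/sort/reduce overhead; adding 4 tightens the ratio to 23 : 7, about 3.3x. A self-contained sketch of the weight formula on plain integers (the real code derives bit_length from the scalar's uint256 limbs via get_msb(); the constants here are assumptions for illustration):

#include <cassert>
#include <cstddef>
#include <cstdint>

// Assumed values mirroring the comments above: 14-bit slices, fixed overhead of 4.
constexpr std::uint32_t bits_per_slice = 14;
constexpr std::uint16_t fixed_per_scalar_weight = 4;

constexpr std::uint16_t weight_from_bit_length(std::size_t bit_length)
{
    return static_cast<std::uint16_t>((bit_length + bits_per_slice - 1) / bits_per_slice) + fixed_per_scalar_weight;
}

int main()
{
    // bit_length = msb index + 1, e.g. the scalar 1 has msb 0 and bit length 1.
    assert(weight_from_bit_length(1) == 1 + 4);    // a single nonzero slice
    assert(weight_from_bit_length(32) == 3 + 4);   // 32-bit "small" scalar
    assert(weight_from_bit_length(254) == 19 + 4); // full-width BN254 scalar
    // Ratio small : full is 7 : 23 with the fixed term, versus 3 : 19 without it.
    return 0;
}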
+template <typename Curve>
+std::vector<typename MSM<Curve>::ThreadWorkUnits> MSM<Curve>::partition_by_weight(
+    std::span<const std::vector<uint16_t>> msm_scalar_weights, size_t num_threads) noexcept
+{
+    BB_ASSERT_GT(num_threads, 0U);
+    std::vector<ThreadWorkUnits> work_units(num_threads);
+
+    size_t grand_total_weight = 0;
+    for (const auto& weights : msm_scalar_weights) {
+        for (uint16_t w : weights) {
+            grand_total_weight += w;
+        }
+    }
+    if (grand_total_weight == 0) {
+        return work_units;
+    }
+
+    const size_t weight_per_thread = numeric::ceil_div(grand_total_weight, num_threads);
+
+    size_t thread_accumulated_weight = 0;
+    size_t current_thread_idx = 0;
+    for (size_t i = 0; i < msm_scalar_weights.size(); ++i) {
+        const auto& weights = msm_scalar_weights[i];
+        const size_t n = weights.size();
+
+        size_t start = 0;
+        for (size_t k = 0; k < n; ++k) {
+            thread_accumulated_weight += weights[k];
+
+            if (current_thread_idx < num_threads - 1 && thread_accumulated_weight >= weight_per_thread) {
+                work_units[current_thread_idx].push_back(MSMWorkUnit{
+                    .batch_msm_index = i,
+                    .start_index = start,
+                    .size = k + 1 - start,
+                });
+                start = k + 1;
+                current_thread_idx++;
+                thread_accumulated_weight = 0;
+            }
+        }
+        if (start < n) {
+            work_units[current_thread_idx].push_back(MSMWorkUnit{
+                .batch_msm_index = i,
+                .start_index = start,
+                .size = n - start,
+            });
+        }
+    }
+    return work_units;
+}
+
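A worked trace on a toy input: one MSM with eight nonzero scalars of weights { 7, 7, 7, 7, 23, 23, 23, 23 } (four small, four full, as in the Clustered benchmark) split across two threads. The total weight is 120, so the per-thread target is 60; the running total reaches 74 at the sixth scalar, so thread 0 gets the contiguous range [0, 6) (weight 74) and thread 1 gets [6, 8) (weight 46), whereas a count-based split at four elements would yield 28 versus 92. The sketch below is a standalone re-implementation of the walk, not the library function:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Greedy prefix walk for a single MSM: close a chunk whenever the running weight
// reaches the per-thread target, except on the last thread, which absorbs the remainder.
int main()
{
    const std::vector<std::uint16_t> weights = { 7, 7, 7, 7, 23, 23, 23, 23 };
    const std::size_t num_threads = 2;

    std::size_t total = 0;
    for (auto w : weights) {
        total += w;
    }
    const std::size_t target = (total + num_threads - 1) / num_threads; // 60

    std::vector<std::pair<std::size_t, std::size_t>> ranges; // (start_index, size) per thread
    std::size_t start = 0;
    std::size_t acc = 0;
    for (std::size_t k = 0; k < weights.size(); ++k) {
        acc += weights[k];
        if (ranges.size() + 1 < num_threads && acc >= target) {
            ranges.emplace_back(start, k + 1 - start);
            start = k + 1;
            acc = 0;
        }
    }
    ranges.emplace_back(start, weights.size() - start);

    for (std::size_t t = 0; t < ranges.size(); ++t) {
        std::cout << "thread " << t << ": start=" << ranges[t].first << " size=" << ranges[t].second << "\n";
    }
    // Prints: thread 0: start=0 size=6, thread 1: start=6 size=2.
    return 0;
}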
 template <typename Curve>
 std::vector<typename MSM<Curve>::ThreadWorkUnits> MSM<Curve>::get_work_units(
     std::span<std::span<ScalarField>> scalars, std::vector<std::vector<uint32_t>>& msm_scalar_indices) noexcept
 {
     const size_t num_msms = scalars.size();
     msm_scalar_indices.resize(num_msms);
-    for (size_t i = 0; i < num_msms; ++i) {
-        transform_scalar_and_get_nonzero_scalar_indices(scalars[i], msm_scalar_indices[i]);
-    }

+    // Weight scalars by their Pippenger cost (slice count + fixed overhead, see
+    // compute_scalar_slice_weights) to improve thread balancing.
+    std::vector<std::vector<uint16_t>> msm_scalar_weights(num_msms);
     size_t total_work = 0;
-    for (const auto& indices : msm_scalar_indices) {
-        total_work += indices.size();
+    for (size_t i = 0; i < num_msms; ++i) {
+        transform_scalar_and_get_nonzero_scalar_indices(scalars[i], msm_scalar_indices[i]);
+        const size_t n = msm_scalar_indices[i].size();
+        total_work += n;
+        if (n == 0) {
+            continue;
+        }
+        const uint32_t bps = get_optimal_log_num_buckets(n);
+        compute_scalar_slice_weights(scalars[i], msm_scalar_indices[i], bps, msm_scalar_weights[i]);
     }

     const size_t num_threads = get_num_cpus();
-    std::vector<ThreadWorkUnits> work_units(num_threads);
-
-    const size_t work_per_thread = numeric::ceil_div(total_work, num_threads);
-    const size_t work_of_last_thread = total_work - (work_per_thread * (num_threads - 1));

     // Only use a single work unit if we don't have enough work for every thread
     if (num_threads > total_work) {
+        std::vector<ThreadWorkUnits> work_units(num_threads);
         for (size_t i = 0; i < num_msms; ++i) {
             work_units[0].push_back(MSMWorkUnit{
                 .batch_msm_index = i,
@@ -118,37 +205,7 @@ std::vector<typename MSM<Curve>::ThreadWorkUnits> MSM<Curve>::get_work_units(
         return work_units;
     }

-    size_t thread_accumulated_work = 0;
-    size_t current_thread_idx = 0;
-    for (size_t i = 0; i < num_msms; ++i) {
-        size_t msm_work_remaining = msm_scalar_indices[i].size();
-        const size_t initial_msm_work = msm_work_remaining;
-
-        while (msm_work_remaining > 0) {
-            BB_ASSERT_LT(current_thread_idx, work_units.size());
-
-            const size_t total_thread_work =
-                (current_thread_idx == num_threads - 1) ? work_of_last_thread : work_per_thread;
-            const size_t available_thread_work = total_thread_work - thread_accumulated_work;
-            const size_t work_to_assign = std::min(available_thread_work, msm_work_remaining);
-
-            work_units[current_thread_idx].push_back(MSMWorkUnit{
-                .batch_msm_index = i,
-                .start_index = initial_msm_work - msm_work_remaining,
-                .size = work_to_assign,
-            });
-
-            thread_accumulated_work += work_to_assign;
-            msm_work_remaining -= work_to_assign;
-
-            // Move to next thread if current thread is full
-            if (thread_accumulated_work >= total_thread_work) {
-                current_thread_idx++;
-                thread_accumulated_work = 0;
-            }
-        }
-    }
-    return work_units;
+    return partition_by_weight(msm_scalar_weights, num_threads);
 }

 /**
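
One subtlety of the new get_work_units: bits_per_slice is chosen per MSM via get_optimal_log_num_buckets(n), so within one batch the same scalar can receive different weights in different-sized MSMs, and partition_by_weight then balances the pooled weights. The sketch below uses assumed c values (4 for a tiny MSM, 14 for a ~2^20-point MSM) purely for illustration; the actual mapping from MSM size to c is not part of this diff.

#include <cstddef>

// Illustration of per-MSM calibration; the c values are assumptions, not library output.
constexpr std::size_t weight(std::size_t bit_length, std::size_t c) { return (bit_length + c - 1) / c + 4; }

static_assert(weight(32, 4) == 8 + 4);  // tiny MSM: a 32-bit scalar spans 8 four-bit slices
static_assert(weight(32, 14) == 3 + 4); // large MSM: the same scalar spans only 3 slices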

barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.hpp

Lines changed: 22 additions & 1 deletion
@@ -240,6 +240,14 @@ template <typename Curve> class MSM {
     /** @brief Compute optimal bits per slice by minimizing cost over c in [1, MAX_SLICE_BITS) */
     static uint32_t get_optimal_log_num_buckets(size_t num_points) noexcept;

+    /** @brief Partition per-MSM scalar weights into num_threads work units of approximately
+     *         equal cumulative weight.
+     * @details Curve-independent and side-effect-free. The walk closes a work unit every time
+     *          the running weight crosses the per-thread target, except on the last thread
+     *          which absorbs any remainder so rounding drift doesn't leave work stranded. */
+    static std::vector<ThreadWorkUnits> partition_by_weight(std::span<const std::vector<uint16_t>> msm_scalar_weights,
+                                                            size_t num_threads) noexcept;
+
     /** @brief Process sorted point schedule into bucket accumulators using batched affine additions */
     static void batch_accumulate_points_into_buckets(std::span<const uint64_t> point_schedule,
                                                      std::span<const AffineElement> points,
@@ -288,7 +296,20 @@ template <typename Curve> class MSM {
     static void transform_scalar_and_get_nonzero_scalar_indices(std::span<ScalarField> scalars,
                                                                 std::vector<uint32_t>& nonzero_scalar_indices) noexcept;

-    /** @brief Distribute multiple MSMs across threads with balanced point counts */
+    /** @brief Compute per-scalar slice-count weights ceil(bit_length / bits_per_slice).
+     * @details Parallel over nonzero_indices. Scalars must be in non-Montgomery form (as left
+     *          by transform_scalar_and_get_nonzero_scalar_indices). Weights drive thread
+     *          partitioning in get_work_units. */
+    static void compute_scalar_slice_weights(std::span<const ScalarField> scalars,
+                                             std::span<const uint32_t> nonzero_indices,
+                                             uint32_t bits_per_slice,
+                                             std::vector<uint16_t>& weights) noexcept;
+
+    /** @brief Distribute multiple MSMs across threads with balanced bucket-accumulation work.
+     * @details Per-thread assignment is a contiguous range of each MSM's nonzero-scalar
+     *          indices, sized by cumulative slice-count weight ceil(bit_length / c). This is
+     *          the actual number of nonzero c-bit slices a scalar contributes — the quantity
+     *          that drives bucket-accumulation cost. */
     static std::vector<ThreadWorkUnits> get_work_units(std::span<std::span<ScalarField>> scalars,
                                                        std::vector<std::vector<uint32_t>>& msm_scalar_indices) noexcept;
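
Because partition_by_weight carries its running weight across MSM boundaries (the end-of-MSM flush pushes a work unit but does not advance the thread index or reset the accumulator), a thread whose quota straddles a boundary ends up with more than one MSMWorkUnit: the tail of one MSM and the head of the next, each a contiguous range into that MSM's nonzero-index list. A hand-derived illustration, not a call into the library, since get_work_units is internal to MSM:

#include <cstddef>
#include <vector>

// Batch: MSM 0 has four full-width scalars (weight 23 each), MSM 1 has two (weight 23 each).
// Grand total 138, two threads, per-thread target ceil(138 / 2) = 69.
struct WorkUnit {
    std::size_t batch_msm_index;
    std::size_t start_index;
    std::size_t size;
};

// Thread 0 closes after the third scalar of MSM 0 (running weight 69):
const std::vector<WorkUnit> thread0 = { { 0, 0, 3 } };
// Thread 1 picks up the tail of MSM 0 and then all of MSM 1 -- two contiguous ranges,
// one per MSM, since the walk carries its running weight across the MSM boundary:
const std::vector<WorkUnit> thread1 = { { 0, 3, 1 }, { 1, 0, 2 } };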