@@ -85,29 +85,116 @@ void MSM<Curve>::transform_scalar_and_get_nonzero_scalar_indices(std::span<typen
     });
 }

+template <typename Curve>
+void MSM<Curve>::compute_scalar_slice_weights(std::span<const typename Curve::ScalarField> scalars,
+                                              std::span<const uint32_t> nonzero_indices,
+                                              uint32_t bits_per_slice,
+                                              std::vector<uint16_t>& weights) noexcept
+{
+    // weight = ceil(bit_length / bps) + FIXED_PER_SCALAR_WEIGHT. The fixed term approximates the
+    // O(num_rounds) per-scalar overhead in build_schedule, sort_schedule, and reduce_buckets that
+    // doesn't scale with bit_length. Without it, threads assigned many lightweight scalars end up
+    // with disproportionate build/sort/reduce work (empirically observed via per-phase profiling).
+    // Max is ceil(NUM_BITS_IN_FIELD / 1) + FIXED.
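+    // e.g. with bits_per_slice = 16, a 254-bit scalar gets weight ceil(254 / 16) + 4 = 20, while a
+    // 1-bit scalar gets weight 1 + 4 = 5.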
+    static constexpr uint16_t FIXED_PER_SCALAR_WEIGHT = 4;
+    static_assert(NUM_BITS_IN_FIELD + FIXED_PER_SCALAR_WEIGHT <= std::numeric_limits<uint16_t>::max(),
+                  "slice-count weight overflows uint16_t");
+    BB_ASSERT_GT(bits_per_slice, 0U);
+
+    const size_t n = nonzero_indices.size();
+    weights.resize(n);
+
+    parallel_for([&](const ThreadChunk& chunk) {
+        for (size_t k : chunk.range(n)) {
+            const auto& scalar = scalars[nonzero_indices[k]];
+            // Scalars were filtered for nonzero and are in non-Montgomery form, so get_msb()
+            // returns a valid bit index in [0, NUM_BITS_IN_FIELD).
+            const uint64_t msb = uint256_t{ scalar.data[0], scalar.data[1], scalar.data[2], scalar.data[3] }.get_msb();
+            const size_t bit_length = static_cast<size_t>(msb) + 1;
+            weights[k] =
+                static_cast<uint16_t>((bit_length + bits_per_slice - 1) / bits_per_slice) + FIXED_PER_SCALAR_WEIGHT;
+        }
+    });
+}
+
+template <typename Curve>
+std::vector<typename MSM<Curve>::ThreadWorkUnits> MSM<Curve>::partition_by_weight(
+    std::span<const std::vector<uint16_t>> msm_scalar_weights, size_t num_threads) noexcept
+{
+    BB_ASSERT_GT(num_threads, 0U);
+    std::vector<ThreadWorkUnits> work_units(num_threads);
+
+    size_t grand_total_weight = 0;
+    for (const auto& weights : msm_scalar_weights) {
+        for (uint16_t w : weights) {
+            grand_total_weight += w;
+        }
+    }
+    if (grand_total_weight == 0) {
+        return work_units;
+    }
+
+    const size_t weight_per_thread = numeric::ceil_div(grand_total_weight, num_threads);
+
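+    // Greedy partition: walk scalars in order and cut a new work unit whenever the running weight
+    // reaches the per-thread target; the last thread absorbs any remainder.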
+    size_t thread_accumulated_weight = 0;
+    size_t current_thread_idx = 0;
+    for (size_t i = 0; i < msm_scalar_weights.size(); ++i) {
+        const auto& weights = msm_scalar_weights[i];
+        const size_t n = weights.size();
+
+        size_t start = 0;
+        for (size_t k = 0; k < n; ++k) {
+            thread_accumulated_weight += weights[k];
+
+            if (current_thread_idx < num_threads - 1 && thread_accumulated_weight >= weight_per_thread) {
+                work_units[current_thread_idx].push_back(MSMWorkUnit{
+                    .batch_msm_index = i,
+                    .start_index = start,
+                    .size = k + 1 - start,
+                });
+                start = k + 1;
+                current_thread_idx++;
+                thread_accumulated_weight = 0;
+            }
+        }
+        if (start < n) {
+            work_units[current_thread_idx].push_back(MSMWorkUnit{
+                .batch_msm_index = i,
+                .start_index = start,
+                .size = n - start,
+            });
+        }
+    }
+    return work_units;
+}
+
 template <typename Curve>
 std::vector<typename MSM<Curve>::ThreadWorkUnits> MSM<Curve>::get_work_units(
     std::span<std::span<ScalarField>> scalars, std::vector<std::vector<uint32_t>>& msm_scalar_indices) noexcept
 {
     const size_t num_msms = scalars.size();
     msm_scalar_indices.resize(num_msms);
-    for (size_t i = 0; i < num_msms; ++i) {
-        transform_scalar_and_get_nonzero_scalar_indices(scalars[i], msm_scalar_indices[i]);
-    }

+    // Weight scalars by their Pippenger cost (slice count + fixed overhead, see
+    // compute_scalar_slice_weights) to improve thread balancing.
+    std::vector<std::vector<uint16_t>> msm_scalar_weights(num_msms);
     size_t total_work = 0;
-    for (const auto& indices : msm_scalar_indices) {
-        total_work += indices.size();
+    for (size_t i = 0; i < num_msms; ++i) {
+        transform_scalar_and_get_nonzero_scalar_indices(scalars[i], msm_scalar_indices[i]);
+        const size_t n = msm_scalar_indices[i].size();
+        total_work += n;
+        if (n == 0) {
+            continue;
+        }
+        const uint32_t bps = get_optimal_log_num_buckets(n);
+        compute_scalar_slice_weights(scalars[i], msm_scalar_indices[i], bps, msm_scalar_weights[i]);
     }

     const size_t num_threads = get_num_cpus();
-    std::vector<ThreadWorkUnits> work_units(num_threads);
-
-    const size_t work_per_thread = numeric::ceil_div(total_work, num_threads);
-    const size_t work_of_last_thread = total_work - (work_per_thread * (num_threads - 1));

     // Only use a single work unit if we don't have enough work for every thread
     if (num_threads > total_work) {
+        std::vector<ThreadWorkUnits> work_units(num_threads);
         for (size_t i = 0; i < num_msms; ++i) {
             work_units[0].push_back(MSMWorkUnit{
                 .batch_msm_index = i,
@@ -118,37 +205,7 @@ std::vector<typename MSM<Curve>::ThreadWorkUnits> MSM<Curve>::get_work_units(
         return work_units;
     }

-    size_t thread_accumulated_work = 0;
-    size_t current_thread_idx = 0;
-    for (size_t i = 0; i < num_msms; ++i) {
-        size_t msm_work_remaining = msm_scalar_indices[i].size();
-        const size_t initial_msm_work = msm_work_remaining;
-
-        while (msm_work_remaining > 0) {
-            BB_ASSERT_LT(current_thread_idx, work_units.size());
-
-            const size_t total_thread_work =
-                (current_thread_idx == num_threads - 1) ? work_of_last_thread : work_per_thread;
-            const size_t available_thread_work = total_thread_work - thread_accumulated_work;
-            const size_t work_to_assign = std::min(available_thread_work, msm_work_remaining);
-
-            work_units[current_thread_idx].push_back(MSMWorkUnit{
-                .batch_msm_index = i,
-                .start_index = initial_msm_work - msm_work_remaining,
-                .size = work_to_assign,
-            });
-
-            thread_accumulated_work += work_to_assign;
-            msm_work_remaining -= work_to_assign;
-
-            // Move to next thread if current thread is full
-            if (thread_accumulated_work >= total_thread_work) {
-                current_thread_idx++;
-                thread_accumulated_work = 0;
-            }
-        }
-    }
-    return work_units;
+    return partition_by_weight(msm_scalar_weights, num_threads);
 }

 /**
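For illustration only (not part of the commit): a minimal standalone sketch of the greedy weighted partitioning that partition_by_weight implements. WorkUnit and the local ceil_div here are simplified stand-ins for the library's MSMWorkUnit and numeric::ceil_div, and the weights in main() are a hypothetical example, not values produced by compute_scalar_slice_weights.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for MSMWorkUnit.
struct WorkUnit {
    size_t batch_msm_index;
    size_t start_index;
    size_t size;
};

// Stand-in for numeric::ceil_div.
size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

// Mirrors partition_by_weight above: walk scalars in order and cut a new work unit whenever the
// running weight reaches the per-thread target; the last thread absorbs any remainder.
std::vector<std::vector<WorkUnit>> partition_by_weight(const std::vector<std::vector<uint16_t>>& msm_weights,
                                                       size_t num_threads)
{
    std::vector<std::vector<WorkUnit>> work_units(num_threads);
    size_t total = 0;
    for (const auto& weights : msm_weights) {
        for (uint16_t w : weights) {
            total += w;
        }
    }
    if (total == 0) {
        return work_units;
    }
    const size_t target = ceil_div(total, num_threads);

    size_t acc = 0;
    size_t thread = 0;
    for (size_t i = 0; i < msm_weights.size(); ++i) {
        const auto& weights = msm_weights[i];
        size_t start = 0;
        for (size_t k = 0; k < weights.size(); ++k) {
            acc += weights[k];
            if (thread < num_threads - 1 && acc >= target) {
                work_units[thread].push_back({ i, start, k + 1 - start });
                start = k + 1;
                ++thread;
                acc = 0;
            }
        }
        if (start < weights.size()) {
            work_units[thread].push_back({ i, start, weights.size() - start });
        }
    }
    return work_units;
}

int main()
{
    // One MSM with two heavy scalars (weight 20 each) and six light ones (weight 5 each).
    // Total weight = 70, so with 2 threads the target is ceil(70 / 2) = 35: thread 0 takes the
    // two heavy scalars (weight 40), thread 1 the six light ones (weight 30).
    std::vector<std::vector<uint16_t>> weights = { { 20, 20, 5, 5, 5, 5, 5, 5 } };
    const auto units = partition_by_weight(weights, 2);
    for (size_t t = 0; t < units.size(); ++t) {
        for (const auto& unit : units[t]) {
            std::cout << "thread " << t << ": msm " << unit.batch_msm_index << " scalars [" << unit.start_index
                      << ", " << unit.start_index + unit.size << ")\n";
        }
    }
    return 0;
}

A count-based split of the same input (8 scalars over 2 threads, 4 each) would put one heavy and three light scalars on each thread only by luck of ordering; weighting by slice count is what lets the partition track actual Pippenger work rather than scalar count.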