@@ -88,11 +88,14 @@ template <typename T>
8888}
8989
9090// Serialization payload for distributed numerical sketch merging over AllreduceV.
91- // Encodes per-feature entry counts plus contiguous sketch entries.
91+ // Encodes per-feature represented element counts, per-feature entry counts, and contiguous
92+ // sketch entries.
9293struct SketchReducePayload {
9394 [[nodiscard]] static std::vector<std::byte> SerializeFromSummaries (
9495 Span<bst_feature_t const > numeric_features,
95- std::vector<WQuantileSketch::SummaryContainer> const &reduced) {
96+ std::vector<WQuantileSketch::SummaryContainer> const &reduced,
97+ std::vector<std::size_t > const &num_elements) {
98+ CHECK_EQ (reduced.size (), num_elements.size ());
9699 std::size_t total_entries = 0 ;
97100 for (auto fidx : numeric_features) {
98101 total_entries += reduced.at (fidx).Size ();
@@ -103,23 +106,33 @@ struct SketchReducePayload {
103106
104107 for (std::size_t i = 0 ; i < numeric_features.size (); ++i) {
105108 auto fidx = numeric_features[i];
109+ SetNumElements (&bytes, numeric_features.size (), i, num_elements.at (fidx));
106110 auto out_entries = reduced.at (fidx).Entries ();
107- AppendEntries (&bytes, i, out_entries);
111+ AppendEntries (&bytes, numeric_features. size (), i, out_entries);
108112 }
109113 auto header_bytes = HeaderBytes (numeric_features.size ());
110114 CHECK_EQ ((bytes.size () - header_bytes) / sizeof (WQuantileSketch::Entry), total_entries);
111115 return bytes;
112116 }
113117
114118 [[nodiscard]] static std::size_t HeaderBytes (std::size_t n_features) {
115- return sizeof (std::uint64_t ) + n_features * sizeof (std::uint64_t );
119+ return sizeof (std::uint64_t ) + 2 * n_features * sizeof (std::uint64_t );
116120 }
117121
118- static void AppendEntries (std::vector<std::byte> *bytes, std::size_t i,
119- Span<WQuantileSketch::Entry const > entries ) {
122+ static void SetNumElements (std::vector<std::byte> *bytes, std:: size_t n_features , std::size_t i,
123+ std:: size_t num_elements ) {
120124 CHECK (bytes);
121125 auto count_offset = sizeof (std::uint64_t ) + i * sizeof (std::uint64_t );
122- CHECK_LE (count_offset + sizeof (std::uint64_t ), bytes->size ());
126+ CHECK_LE (count_offset + sizeof (std::uint64_t ), HeaderBytes (n_features));
127+ WritePODAt<std::uint64_t >(bytes, count_offset, static_cast <std::uint64_t >(num_elements));
128+ }
129+
130+ static void AppendEntries (std::vector<std::byte> *bytes, std::size_t n_features, std::size_t i,
131+ Span<WQuantileSketch::Entry const > entries) {
132+ CHECK (bytes);
133+ auto count_offset =
134+ sizeof (std::uint64_t ) + n_features * sizeof (std::uint64_t ) + i * sizeof (std::uint64_t );
135+ CHECK_LE (count_offset + sizeof (std::uint64_t ), HeaderBytes (n_features));
123136 WritePODAt<std::uint64_t >(bytes, count_offset, static_cast <std::uint64_t >(entries.size ()));
124137 if (entries.empty ()) {
125138 return ;
@@ -143,6 +156,11 @@ struct SketchReducePayload {
143156 std::size_t cursor = 0 ;
144157 auto n_features = ReadPOD<std::uint64_t >(bytes, &cursor);
145158
159+ std::vector<std::size_t > num_elements (n_features, 0 );
160+ for (std::size_t i = 0 ; i < n_features; ++i) {
161+ num_elements[i] = static_cast <std::size_t >(ReadPOD<std::uint64_t >(bytes, &cursor));
162+ }
163+
146164 std::vector<std::size_t > offsets (n_features + 1 , 0 );
147165 for (std::size_t i = 0 ; i < n_features; ++i) {
148166 auto n_i = static_cast <std::size_t >(ReadPOD<std::uint64_t >(bytes, &cursor));
@@ -161,11 +179,13 @@ struct SketchReducePayload {
161179 entries = reinterpret_cast <WQuantileSketch::Entry *>(ptr);
162180 }
163181
164- return {std::move (offsets), Span<WQuantileSketch::Entry>{entries, n_entries}};
182+ return {std::move (offsets), std::move (num_elements),
183+ Span<WQuantileSketch::Entry>{entries, n_entries}};
165184 }
166185
167186 [[nodiscard]] std::size_t NumFeatures () const { return offsets_.size () - 1 ; }
168187 [[nodiscard]] std::size_t TotalEntries () const { return entries_.size (); }
188+ [[nodiscard]] std::size_t NumElements (std::size_t idx) const { return num_elements_.at (idx); }
169189
170190 [[nodiscard]] Span<WQuantileSketch::Entry> Entries (std::size_t idx) const {
171191 auto beg = offsets_.at (idx);
@@ -183,10 +203,12 @@ struct SketchReducePayload {
183203 }
184204
185205 private:
186- SketchReducePayload (std::vector<std::size_t > offsets, Span<WQuantileSketch::Entry> entries)
187- : offsets_{std::move (offsets)}, entries_{entries} {}
206+ SketchReducePayload (std::vector<std::size_t > offsets, std::vector<std::size_t > num_elements,
207+ Span<WQuantileSketch::Entry> entries)
208+ : offsets_{std::move (offsets)}, num_elements_{std::move (num_elements)}, entries_{entries} {}
188209
189210 std::vector<std::size_t > offsets_;
211+ std::vector<std::size_t > num_elements_;
190212 Span<WQuantileSketch::Entry> entries_;
191213};
192214
@@ -428,12 +450,14 @@ auto HostSketchContainer::AllReduce(Context const *ctx, MetaInfo const &info,
428450 CHECK_EQ (n_columns, sketches_.size ()) << " Number of columns differs across workers" ;
429451
430452 std::vector<WQSketch::SummaryContainer> reduced (sketches_.size ());
453+ std::vector<std::size_t > num_elements (sketches_.size (), 0 );
431454
432- // Cap the per-feature summary size during local and distributed merge.
433- auto const max_cut_target = static_cast <std::size_t >(max_bins_ * WQSketch::kFactor );
455+ // Size local summaries with the same O(log n / eps) budget as the single-machine sketch.
434456 ParallelFor (numeric_features.size (), n_threads_, [&](size_t idx) {
435457 auto fidx = numeric_features[idx];
436- reduced[fidx] = sketches_[fidx].GetSummary (max_cut_target);
458+ num_elements[fidx] = sketches_[fidx].NumElements ();
459+ auto cut_target = SketchSummaryBudget (max_bins_, num_elements[fidx]);
460+ reduced[fidx] = sketches_[fidx].GetSummary (cut_target);
437461 });
438462
439463 // Early exit: no allreduce needed when one worker, column-split, or no numeric features.
@@ -444,9 +468,8 @@ auto HostSketchContainer::AllReduce(Context const *ctx, MetaInfo const &info,
444468
445469 // Serialize local sketches to a byte array for allreduce
446470 auto merged = SketchReducePayload::SerializeFromSummaries (
447- Span<bst_feature_t const >{numeric_features}, reduced);
471+ Span<bst_feature_t const >{numeric_features}, reduced, num_elements );
448472 WQSketch::SummaryContainer tmp;
449- tmp.Reserve (max_cut_target * 2 ); // workspace for merging sketches during allreduce
450473 auto reduce_rc = collective::AllreduceV (
451474 ctx, &merged,
452475 [&](common::Span<std::byte const > a, common::Span<std::byte const > b,
@@ -459,19 +482,27 @@ auto HostSketchContainer::AllReduce(Context const *ctx, MetaInfo const &info,
459482 CHECK_EQ (b_payload.NumFeatures (), numeric_features.size ());
460483
461484 auto max_entries = a_payload.TotalEntries () + b_payload.TotalEntries ();
462- auto max_pruned_entries = max_cut_target * numeric_features.size ();
485+ std::size_t max_pruned_entries{0 };
486+ for (std::size_t i = 0 ; i < numeric_features.size (); ++i) {
487+ auto num_elements = a_payload.NumElements (i) + b_payload.NumElements (i);
488+ max_pruned_entries += SketchSummaryBudget (max_bins_, num_elements);
489+ }
463490 max_entries = std::min (max_entries, max_pruned_entries);
464491 SketchReducePayload::InitHeader (out, numeric_features.size (), max_entries);
465492
466493 for (std::size_t i = 0 ; i < numeric_features.size (); ++i) {
467494 auto a_summary = a_payload.SummaryAt (i);
468495 auto b_summary = b_payload.SummaryAt (i);
496+ auto num_elements = a_payload.NumElements (i) + b_payload.NumElements (i);
497+ auto cut_target = SketchSummaryBudget (max_bins_, num_elements);
498+ tmp.Reserve (a_summary.Size () + b_summary.Size ());
469499 tmp.CopyFrom (a_summary);
470500 tmp.SetCombine (b_summary);
471- tmp.SetPrune (max_cut_target );
501+ tmp.SetPrune (cut_target );
472502
503+ SketchReducePayload::SetNumElements (out, numeric_features.size (), i, num_elements);
473504 auto pruned_entries = tmp.Entries ();
474- SketchReducePayload::AppendEntries (out, i, pruned_entries);
505+ SketchReducePayload::AppendEntries (out, numeric_features. size (), i, pruned_entries);
475506 }
476507 });
477508 collective::SafeColl (reduce_rc);
0 commit comments