Skip to content

Commit 01c70d7

Browse files
authored
Use O(log n/eps) budget for CPU distributed sketch (#12150)
1 parent: 325eecc; commit: 01c70d7

3 files changed

Lines changed: 90 additions & 18 deletions

File tree

src/common/quantile.cc

Lines changed: 49 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -88,11 +88,14 @@ template <typename T>
8888
}
8989

9090
// Serialization payload for distributed numerical sketch merging over AllreduceV.
91-
// Encodes per-feature entry counts plus contiguous sketch entries.
91+
// Encodes per-feature represented element counts, per-feature entry counts, and contiguous
92+
// sketch entries.
9293
struct SketchReducePayload {
9394
[[nodiscard]] static std::vector<std::byte> SerializeFromSummaries(
9495
Span<bst_feature_t const> numeric_features,
95-
std::vector<WQuantileSketch::SummaryContainer> const &reduced) {
96+
std::vector<WQuantileSketch::SummaryContainer> const &reduced,
97+
std::vector<std::size_t> const &num_elements) {
98+
CHECK_EQ(reduced.size(), num_elements.size());
9699
std::size_t total_entries = 0;
97100
for (auto fidx : numeric_features) {
98101
total_entries += reduced.at(fidx).Size();
@@ -103,23 +106,33 @@ struct SketchReducePayload {
103106

104107
for (std::size_t i = 0; i < numeric_features.size(); ++i) {
105108
auto fidx = numeric_features[i];
109+
SetNumElements(&bytes, numeric_features.size(), i, num_elements.at(fidx));
106110
auto out_entries = reduced.at(fidx).Entries();
107-
AppendEntries(&bytes, i, out_entries);
111+
AppendEntries(&bytes, numeric_features.size(), i, out_entries);
108112
}
109113
auto header_bytes = HeaderBytes(numeric_features.size());
110114
CHECK_EQ((bytes.size() - header_bytes) / sizeof(WQuantileSketch::Entry), total_entries);
111115
return bytes;
112116
}
113117

114118
[[nodiscard]] static std::size_t HeaderBytes(std::size_t n_features) {
115-
return sizeof(std::uint64_t) + n_features * sizeof(std::uint64_t);
119+
return sizeof(std::uint64_t) + 2 * n_features * sizeof(std::uint64_t);
116120
}
117121

118-
static void AppendEntries(std::vector<std::byte> *bytes, std::size_t i,
119-
Span<WQuantileSketch::Entry const> entries) {
122+
static void SetNumElements(std::vector<std::byte> *bytes, std::size_t n_features, std::size_t i,
123+
std::size_t num_elements) {
120124
CHECK(bytes);
121125
auto count_offset = sizeof(std::uint64_t) + i * sizeof(std::uint64_t);
122-
CHECK_LE(count_offset + sizeof(std::uint64_t), bytes->size());
126+
CHECK_LE(count_offset + sizeof(std::uint64_t), HeaderBytes(n_features));
127+
WritePODAt<std::uint64_t>(bytes, count_offset, static_cast<std::uint64_t>(num_elements));
128+
}
129+
130+
static void AppendEntries(std::vector<std::byte> *bytes, std::size_t n_features, std::size_t i,
131+
Span<WQuantileSketch::Entry const> entries) {
132+
CHECK(bytes);
133+
auto count_offset =
134+
sizeof(std::uint64_t) + n_features * sizeof(std::uint64_t) + i * sizeof(std::uint64_t);
135+
CHECK_LE(count_offset + sizeof(std::uint64_t), HeaderBytes(n_features));
123136
WritePODAt<std::uint64_t>(bytes, count_offset, static_cast<std::uint64_t>(entries.size()));
124137
if (entries.empty()) {
125138
return;
@@ -143,6 +156,11 @@ struct SketchReducePayload {
143156
std::size_t cursor = 0;
144157
auto n_features = ReadPOD<std::uint64_t>(bytes, &cursor);
145158

159+
std::vector<std::size_t> num_elements(n_features, 0);
160+
for (std::size_t i = 0; i < n_features; ++i) {
161+
num_elements[i] = static_cast<std::size_t>(ReadPOD<std::uint64_t>(bytes, &cursor));
162+
}
163+
146164
std::vector<std::size_t> offsets(n_features + 1, 0);
147165
for (std::size_t i = 0; i < n_features; ++i) {
148166
auto n_i = static_cast<std::size_t>(ReadPOD<std::uint64_t>(bytes, &cursor));
@@ -161,11 +179,13 @@ struct SketchReducePayload {
161179
entries = reinterpret_cast<WQuantileSketch::Entry *>(ptr);
162180
}
163181

164-
return {std::move(offsets), Span<WQuantileSketch::Entry>{entries, n_entries}};
182+
return {std::move(offsets), std::move(num_elements),
183+
Span<WQuantileSketch::Entry>{entries, n_entries}};
165184
}
166185

167186
[[nodiscard]] std::size_t NumFeatures() const { return offsets_.size() - 1; }
168187
[[nodiscard]] std::size_t TotalEntries() const { return entries_.size(); }
188+
[[nodiscard]] std::size_t NumElements(std::size_t idx) const { return num_elements_.at(idx); }
169189

170190
[[nodiscard]] Span<WQuantileSketch::Entry> Entries(std::size_t idx) const {
171191
auto beg = offsets_.at(idx);
@@ -183,10 +203,12 @@ struct SketchReducePayload {
183203
}
184204

185205
private:
186-
SketchReducePayload(std::vector<std::size_t> offsets, Span<WQuantileSketch::Entry> entries)
187-
: offsets_{std::move(offsets)}, entries_{entries} {}
206+
SketchReducePayload(std::vector<std::size_t> offsets, std::vector<std::size_t> num_elements,
207+
Span<WQuantileSketch::Entry> entries)
208+
: offsets_{std::move(offsets)}, num_elements_{std::move(num_elements)}, entries_{entries} {}
188209

189210
std::vector<std::size_t> offsets_;
211+
std::vector<std::size_t> num_elements_;
190212
Span<WQuantileSketch::Entry> entries_;
191213
};
192214

@@ -428,12 +450,14 @@ auto HostSketchContainer::AllReduce(Context const *ctx, MetaInfo const &info,
428450
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";
429451

430452
std::vector<WQSketch::SummaryContainer> reduced(sketches_.size());
453+
std::vector<std::size_t> num_elements(sketches_.size(), 0);
431454

432-
// Cap the per-feature summary size during local and distributed merge.
433-
auto const max_cut_target = static_cast<std::size_t>(max_bins_ * WQSketch::kFactor);
455+
// Size local summaries with the same O(log n / eps) budget as the single-machine sketch.
434456
ParallelFor(numeric_features.size(), n_threads_, [&](size_t idx) {
435457
auto fidx = numeric_features[idx];
436-
reduced[fidx] = sketches_[fidx].GetSummary(max_cut_target);
458+
num_elements[fidx] = sketches_[fidx].NumElements();
459+
auto cut_target = SketchSummaryBudget(max_bins_, num_elements[fidx]);
460+
reduced[fidx] = sketches_[fidx].GetSummary(cut_target);
437461
});
438462

439463
// Early exit: no allreduce needed when one worker, column-split, or no numeric features.
@@ -444,9 +468,8 @@ auto HostSketchContainer::AllReduce(Context const *ctx, MetaInfo const &info,
444468

445469
// Serialize local sketches to a byte array for allreduce
446470
auto merged = SketchReducePayload::SerializeFromSummaries(
447-
Span<bst_feature_t const>{numeric_features}, reduced);
471+
Span<bst_feature_t const>{numeric_features}, reduced, num_elements);
448472
WQSketch::SummaryContainer tmp;
449-
tmp.Reserve(max_cut_target * 2); // workspace for merging sketches during allreduce
450473
auto reduce_rc = collective::AllreduceV(
451474
ctx, &merged,
452475
[&](common::Span<std::byte const> a, common::Span<std::byte const> b,
@@ -459,19 +482,27 @@ auto HostSketchContainer::AllReduce(Context const *ctx, MetaInfo const &info,
459482
CHECK_EQ(b_payload.NumFeatures(), numeric_features.size());
460483

461484
auto max_entries = a_payload.TotalEntries() + b_payload.TotalEntries();
462-
auto max_pruned_entries = max_cut_target * numeric_features.size();
485+
std::size_t max_pruned_entries{0};
486+
for (std::size_t i = 0; i < numeric_features.size(); ++i) {
487+
auto num_elements = a_payload.NumElements(i) + b_payload.NumElements(i);
488+
max_pruned_entries += SketchSummaryBudget(max_bins_, num_elements);
489+
}
463490
max_entries = std::min(max_entries, max_pruned_entries);
464491
SketchReducePayload::InitHeader(out, numeric_features.size(), max_entries);
465492

466493
for (std::size_t i = 0; i < numeric_features.size(); ++i) {
467494
auto a_summary = a_payload.SummaryAt(i);
468495
auto b_summary = b_payload.SummaryAt(i);
496+
auto num_elements = a_payload.NumElements(i) + b_payload.NumElements(i);
497+
auto cut_target = SketchSummaryBudget(max_bins_, num_elements);
498+
tmp.Reserve(a_summary.Size() + b_summary.Size());
469499
tmp.CopyFrom(a_summary);
470500
tmp.SetCombine(b_summary);
471-
tmp.SetPrune(max_cut_target);
501+
tmp.SetPrune(cut_target);
472502

503+
SketchReducePayload::SetNumElements(out, numeric_features.size(), i, num_elements);
473504
auto pruned_entries = tmp.Entries();
474-
SketchReducePayload::AppendEntries(out, i, pruned_entries);
505+
SketchReducePayload::AppendEntries(out, numeric_features.size(), i, pruned_entries);
475506
}
476507
});
477508
collective::SafeColl(reduce_rc);

src/common/quantile.h

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -575,6 +575,8 @@ class WQuantileSketch {
575575
level_.clear();
576576
}
577577

578+
[[nodiscard]] size_t NumElements() const { return num_elements_; }
579+
578580
static size_t LimitSizeLevel(size_t maxn, double eps) {
579581
if (maxn == 0) {
580582
// Empty columns can appear in distributed column-split settings.
@@ -604,6 +606,7 @@ class WQuantileSketch {
604606
*/
605607
void Push(bst_float x, bst_float w = 1) {
606608
if (w == static_cast<bst_float>(0)) return;
609+
++num_elements_;
607610
if (!inqueue_.Push(x, w)) {
608611
inqueue_.PopSummary(&temp_);
609612
this->PushSummary(&temp_);
@@ -621,6 +624,14 @@ class WQuantileSketch {
621624
void PushSorted(common::Span<::xgboost::Entry const> column, std::vector<float> const &weights,
622625
size_t num_retained_items) {
623626
CHECK_GE(num_retained_items, 1);
627+
if (weights.empty()) {
628+
num_elements_ += column.size();
629+
} else {
630+
num_elements_ +=
631+
std::count_if(column.cbegin(), column.cend(), [&](::xgboost::Entry const &entry) {
632+
return weights[entry.index] != static_cast<float>(0);
633+
});
634+
}
624635
auto const max_size = num_retained_items;
625636
this->temp_.Reserve(max_size + 1);
626637
this->temp_.SetPruneSorted(column, weights, max_size);
@@ -715,6 +726,8 @@ class WQuantileSketch {
715726
WQSummaryContainer temp_;
716727
// reusable workspace for combine-prune operations
717728
std::vector<Entry> combine_workspace_;
729+
// Number of source elements represented by this sketch.
730+
size_t num_elements_{0};
718731
};
719732

720733
[[nodiscard]] inline double SketchEpsilon(bst_bin_t max_bins, std::size_t num_elements) {

tests/cpp/common/test_quantile.cc

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,34 @@ TEST(Quantile, InitWithEmptyColumn) {
3737
ASSERT_EQ(out.Size(), 0);
3838
}
3939

40+
TEST(Quantile, TrackSketchElements) {
41+
WQuantileSketch sketch{16, 0.1};
42+
ASSERT_EQ(sketch.NumElements(), 0);
43+
44+
sketch.Push(0.1f);
45+
sketch.Push(0.2f, 0.0f);
46+
sketch.Push(0.3f, 2.0f);
47+
sketch.Push(0.9f);
48+
49+
ASSERT_EQ(sketch.NumElements(), 3);
50+
auto out = sketch.GetSummary(4);
51+
ASSERT_GT(out.Size(), 0);
52+
ASSERT_EQ(sketch.NumElements(), 3);
53+
}
54+
55+
TEST(Quantile, TrackSketchElementsSorted) {
56+
WQuantileSketch sketch{16, 0.1};
57+
std::vector<::xgboost::Entry> column{{0, 0.1f}, {1, 0.2f}, {2, 0.8f}, {3, 0.9f}};
58+
std::vector<float> weights{1.0f, 0.0f, 1.0f, 1.0f};
59+
60+
sketch.PushSorted(Span<::xgboost::Entry const>{column.data(), column.size()}, weights, 2);
61+
62+
ASSERT_EQ(sketch.NumElements(), 3);
63+
auto out = sketch.GetSummary(4);
64+
ASSERT_GT(out.Size(), 0);
65+
ASSERT_EQ(sketch.NumElements(), 3);
66+
}
67+
4068
TEST(Quantile, SetPruneInplace) {
4169
using Summary = WQSummary<>;
4270
using Entry = Summary::Entry;

0 commit comments

Comments
 (0)