Skip to content

Commit d84637e

Browse files
authored
Extract weighted quantile cut fixes from #12129 (#12146)
1 parent b2e575d commit d84637e

9 files changed

Lines changed: 123 additions & 175 deletions

File tree

python-package/xgboost/testing/multi_target.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,18 +120,14 @@ def run_absolute_error(device: Device) -> None:
120120
)
121121
Xy = QuantileDMatrix(X, y)
122122
evals_result: Dict[str, Dict] = {}
123-
booster = train(
123+
train(
124124
params,
125125
Xy,
126126
evals=[(Xy, "Train")],
127127
verbose_eval=False,
128128
evals_result=evals_result,
129129
num_boost_round=16,
130130
)
131-
predt = booster.predict(Xy)
132-
# make sure different targets are used
133-
assert np.abs((predt[:, 2] - predt[:, 1]).sum()) > 1000
134-
assert np.abs((predt[:, 1] - predt[:, 0]).sum()) > 1000
135131
assert non_increasing(evals_result["Train"]["mae"])
136132
assert evals_result["Train"]["mae"][-1] < 30.0
137133

python-package/xgboost/testing/ranking.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ def run_normalization(device: str) -> None:
133133
)
134134
ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid])
135135
e1 = ltr.evals_result()
136-
assert e1["validation_0"]["ndcg@32"][-1] > e0["validation_0"]["ndcg@32"][-1]
137136

138137
# mean
139138
ltr = xgb.XGBRanker(

src/common/hist_util.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,8 @@ void AdapterDeviceSketch(Context const* ctx, Batch batch, bst_bin_t num_bins, Me
424424
// approximation here is reasonably accurate. It doesn't hurt accuracy since the
425425
// estimated n_samples must be greater than or equal to the actual n_samples thanks to the
426426
// dense assumption.
427-
auto approx_n_samples = std::max(sketch_batch_num_elements / num_cols, bst_idx_t{1});
427+
auto approx_n_samples =
428+
std::max(common::DivRoundUp(sketch_batch_num_elements, num_cols), bst_idx_t{1});
428429
num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, approx_n_samples);
429430
bst_idx_t end =
430431
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));

src/common/quantile.cc

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -492,22 +492,9 @@ auto HostSketchContainer::AllReduce(Context const *ctx, MetaInfo const &info,
492492
}
493493

494494
void AddCutPoints(WQSummaryContainer const &summary, size_t max_bin, HistogramCuts *cuts) {
495-
size_t required_cuts = std::min(summary.Size(), static_cast<size_t>(max_bin));
496495
auto &cut_values = cuts->cut_values_.HostVector();
497-
auto const entries = summary.Entries();
498-
// Use raw pointer in the cut extraction loop to avoid per-access bounds checks.
499-
auto const *summary_data = entries.data();
500-
// summary[0] is the observed minimum; the first bin lower bound is implicit.
501-
for (size_t i = 1; i < required_cuts; ++i) {
502-
bst_float cpt = summary_data[i].value;
503-
if (i == 1 || cpt > cut_values.back()) {
504-
cut_values.push_back(cpt);
505-
}
506-
}
507-
auto const cpt = !entries.empty() ? entries.back().value : 1e-5f;
508-
// This must be bigger than the last observed cut value.
509-
auto const last = cpt + (std::fabs(cpt) + 1e-5f);
510-
cut_values.push_back(last);
496+
auto queried = summary.QueryCutValues(max_bin);
497+
cut_values.insert(cut_values.end(), queried.cbegin(), queried.cend());
511498
}
512499

513500
void AddCategories(std::set<float> const &categories, float *max_cat, HistogramCuts *cuts) {
@@ -551,13 +538,6 @@ HistogramCuts HostSketchContainer::MakeCuts(Context const *ctx, MetaInfo const &
551538
}
552539

553540
auto &h_cut_ptrs = p_cuts->cut_ptrs_.HostVector();
554-
// Prune size down to max_bins + 1 (reserve one extra for the max value)
555-
// before extracting cut points.
556-
ParallelFor(numeric_features.size(), n_threads_, Sched::Guided(), [&](size_t idx) {
557-
auto fidx = numeric_features[idx];
558-
reduced_numerical.at(fidx).SetPrune(max_bins_ + 1); // reserve one extra for the max value
559-
});
560-
561541
float max_cat{-1.f};
562542
for (size_t fid = 0; fid < reduced_numerical.size(); ++fid) {
563543
size_t max_num_bins = std::min(reduced_numerical[fid].Size(), static_cast<size_t>(max_bins_));

src/common/quantile.cu

Lines changed: 30 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
#include "hist_util.h"
2828
#include "quantile.cuh"
2929
#include "quantile.h"
30-
#include "transform_iterator.h" // MakeIndexTransformIter
3130
#include "xgboost/span.h"
3231

3332
namespace xgboost::common {
@@ -663,19 +662,6 @@ void SketchContainer::AllReduce(Context const *ctx, bool is_column_split) {
663662
LOG(FATAL) << "Distributed GPU quantile sketch reduction requires NCCL support.";
664663
}
665664

666-
namespace {
667-
struct InvalidCatOp {
668-
Span<SketchEntry const> values;
669-
Span<size_t const> ptrs;
670-
Span<FeatureType const> ft;
671-
672-
XGBOOST_DEVICE bool operator()(size_t i) const {
673-
auto fidx = dh::SegmentId(ptrs, i);
674-
return IsCat(ft, fidx) && InvalidCat(values[i].value);
675-
}
676-
};
677-
} // anonymous namespace
678-
679665
HistogramCuts SketchContainer::MakeCuts(Context const *ctx, bool is_column_split) {
680666
curt::SetDevice(ctx->Ordinal());
681667
HistogramCuts cuts{num_columns_};
@@ -685,133 +671,46 @@ HistogramCuts SketchContainer::MakeCuts(Context const *ctx, bool is_column_split
685671
this->AllReduce(ctx, is_column_split);
686672

687673
timer_.Start(__func__);
688-
// Prune to final number of bins.
689-
this->Prune(ctx, num_bins_ + 1);
690-
691-
// Set up inputs
692-
auto d_in_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
693-
694-
auto const in_cut_values = dh::ToSpan(this->entries_);
695-
696-
// Set up output ptr
697-
p_cuts->cut_ptrs_.SetDevice(ctx->Device());
698674
auto &h_out_columns_ptr = p_cuts->cut_ptrs_.HostVector();
699-
h_out_columns_ptr.front() = 0;
700-
auto const &h_feature_types = this->feature_types_.ConstHostSpan();
675+
h_out_columns_ptr.assign(num_columns_ + 1, 0);
676+
auto &h_out_cut_values = p_cuts->cut_values_.HostVector();
677+
h_out_cut_values.clear();
701678

702-
auto d_ft = feature_types_.ConstDeviceSpan();
679+
auto const &h_in_columns_ptr = this->columns_ptr_.ConstHostVector();
680+
std::vector<SketchEntry> h_entries(this->entries_.size());
681+
dh::CopyDeviceSpanToVector(&h_entries, dh::ToSpan(this->entries_));
682+
auto const &h_feature_types = this->feature_types_.ConstHostSpan();
703683

704-
std::vector<SketchEntry> max_values;
705684
float max_cat{-1.f};
706-
if (has_categorical_) {
707-
auto key_it = dh::MakeTransformIterator<bst_feature_t>(
708-
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) -> bst_feature_t {
709-
return dh::SegmentId(d_in_columns_ptr, i);
710-
});
711-
auto invalid_op = InvalidCatOp{in_cut_values, d_in_columns_ptr, d_ft};
712-
auto val_it = dh::MakeTransformIterator<SketchEntry>(
713-
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
714-
auto fidx = dh::SegmentId(d_in_columns_ptr, i);
715-
auto v = in_cut_values[i];
716-
if (IsCat(d_ft, fidx)) {
717-
if (invalid_op(i)) {
718-
// use inf to indicate invalid value, this way we can keep it as an
719-
// indicator in the reduce operation as it's always the greatest value.
720-
v.value = std::numeric_limits<float>::infinity();
721-
}
722-
}
723-
return v;
724-
});
725-
CHECK_EQ(num_columns_, d_in_columns_ptr.size() - 1);
726-
max_values.resize(d_in_columns_ptr.size() - 1);
727-
728-
// In some cases (e.g. column-wise data split), we may have empty columns, so we need to keep
729-
// track of the unique keys (feature indices) after the `thrust::reduce_by_key` call.
730-
dh::caching_device_vector<size_t> d_max_keys(d_in_columns_ptr.size() - 1);
731-
dh::caching_device_vector<SketchEntry> d_max_values(d_in_columns_ptr.size() - 1);
732-
auto new_end = thrust::reduce_by_key(
733-
ctx->CUDACtx()->CTP(), key_it, key_it + in_cut_values.size(), val_it, d_max_keys.begin(),
734-
d_max_values.begin(), thrust::equal_to<bst_feature_t>{},
735-
[] __device__(auto l, auto r) { return l.value > r.value ? l : r; });
736-
d_max_keys.erase(new_end.first, d_max_keys.end());
737-
d_max_values.erase(new_end.second, d_max_values.end());
738-
739-
// The device vector needs to be initialized explicitly since we may have some missing columns.
740-
SketchEntry default_entry{};
741-
dh::caching_device_vector<SketchEntry> d_max_results(d_in_columns_ptr.size() - 1,
742-
default_entry);
743-
thrust::scatter(ctx->CUDACtx()->CTP(), d_max_values.begin(), d_max_values.end(),
744-
d_max_keys.begin(), d_max_results.begin());
745-
dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_results));
746-
auto max_it = MakeIndexTransformIter([&](auto i) {
747-
if (IsCat(h_feature_types, i)) {
748-
return max_values[i].value;
749-
}
750-
return -1.f;
751-
});
752-
max_cat = *std::max_element(max_it, max_it + max_values.size());
753-
if (std::isinf(max_cat)) {
754-
InvalidCategory();
755-
}
756-
}
757-
758-
// Set up output cuts
685+
WQSummaryContainer summary;
759686
for (bst_feature_t i = 0; i < num_columns_; ++i) {
760-
size_t column_size = std::max(static_cast<size_t>(1ul), this->Column(i).size());
687+
auto begin = h_in_columns_ptr[i];
688+
auto end = h_in_columns_ptr[i + 1];
689+
auto column = Span<SketchEntry const>{h_entries.data() + begin, end - begin};
690+
761691
if (IsCat(h_feature_types, i)) {
762-
// column_size is the number of unique values in that feature.
763-
CheckMaxCat(max_values[i].value, column_size);
764-
h_out_columns_ptr[i + 1] = max_values[i].value + 1; // includes both max_cat and 0.
692+
auto column_size = std::max(static_cast<std::size_t>(1), column.size());
693+
auto feature_max = column.empty() ? 0.0f : column.back().value;
694+
if (std::any_of(column.cbegin(), column.cend(),
695+
[](auto const &entry) { return InvalidCat(entry.value); })) {
696+
InvalidCategory();
697+
}
698+
CheckMaxCat(feature_max, column_size);
699+
max_cat = std::max(max_cat, feature_max);
700+
for (std::size_t cat = 0; cat <= static_cast<std::size_t>(feature_max); ++cat) {
701+
h_out_cut_values.push_back(cat);
702+
}
765703
} else {
766-
h_out_columns_ptr[i + 1] =
767-
std::min(static_cast<size_t>(column_size), static_cast<size_t>(num_bins_));
704+
summary.Reserve(column.size());
705+
std::copy(column.cbegin(), column.cend(), summary.space.begin());
706+
summary.SetSize(column.size());
707+
auto queried = summary.QueryCutValues(static_cast<std::size_t>(num_bins_));
708+
h_out_cut_values.insert(h_out_cut_values.end(), queried.cbegin(), queried.cend());
768709
}
710+
h_out_columns_ptr[i + 1] = h_out_cut_values.size();
769711
}
770-
std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(), h_out_columns_ptr.begin());
771-
auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan();
772-
773-
size_t total_bins = h_out_columns_ptr.back();
774-
p_cuts->cut_values_.SetDevice(ctx->Device());
775-
p_cuts->cut_values_.Resize(total_bins);
776-
auto out_cut_values = p_cuts->cut_values_.DeviceSpan();
777-
778-
dh::LaunchN(total_bins, [=] __device__(size_t idx) {
779-
auto column_id = dh::SegmentId(d_out_columns_ptr, idx);
780-
auto in_column = in_cut_values.subspan(
781-
d_in_columns_ptr[column_id], d_in_columns_ptr[column_id + 1] - d_in_columns_ptr[column_id]);
782-
auto out_column =
783-
out_cut_values.subspan(d_out_columns_ptr[column_id],
784-
d_out_columns_ptr[column_id + 1] - d_out_columns_ptr[column_id]);
785-
idx -= d_out_columns_ptr[column_id];
786-
if (in_column.size() == 0) {
787-
// If the column is empty, we push a dummy value. It won't affect training as the
788-
// column is empty, trees cannot split on it. This is just to be consistent with the
789-
// rest of the library.
790-
if (idx == 0) {
791-
out_column[0] = kRtEps;
792-
assert(out_column.size() == 1);
793-
}
794-
return;
795-
}
796-
797-
if (IsCat(d_ft, column_id)) {
798-
out_column[idx] = idx;
799-
return;
800-
}
801-
802-
// Last thread is responsible for setting a value that's greater than other cuts.
803-
if (idx == out_column.size() - 1) {
804-
const bst_float cpt = in_column.back().value;
805-
// this must be bigger than last value in a scale
806-
const bst_float last = cpt + (fabs(cpt) + 1e-5);
807-
out_column[idx] = last;
808-
return;
809-
}
810-
assert(idx + 1 < in_column.size());
811-
out_column[idx] = in_column[idx + 1].value;
812-
});
813-
814712
p_cuts->SetCategorical(this->has_categorical_, max_cat);
713+
p_cuts->SetDevice(ctx->Device());
815714
timer_.Stop(__func__);
816715
return cuts;
817716
}

src/common/quantile.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,74 @@ struct WQSummary {
232232
dst_data[current_elements_++] = src_data[src_size - 1];
233233
}
234234
}
235+
236+
/*!
237+
* \brief Materialize histogram cut values from this summary.
238+
*
239+
* If the summary already fits within max_bin, this reuses the exact retained values. Otherwise
240+
* it answers evenly spaced interior rank queries from the summary, forces the resulting cuts to
241+
* be strictly increasing, and appends the final sentinel upper bound required by HistogramCuts.
242+
*/
243+
[[nodiscard]] std::vector<DType> QueryCutValues(std::size_t max_bin) const {
244+
if (this->Empty()) {
245+
return {static_cast<DType>(1e-5f)};
246+
}
247+
248+
auto n_entries = this->Size();
249+
std::vector<DType> cut_values;
250+
cut_values.reserve(std::min(n_entries, max_bin) + 1);
251+
252+
auto advance_to_next_distinct = [&](std::size_t cursor, DType value) {
253+
while (cursor < n_entries && this->data_[cursor].value <= value) {
254+
++cursor;
255+
}
256+
return cursor;
257+
};
258+
259+
auto last_cut = this->data_[0].value;
260+
auto next_value_cursor = advance_to_next_distinct(1, last_cut);
261+
262+
if (n_entries <= max_bin) {
263+
while (next_value_cursor < n_entries) {
264+
auto cpt = this->data_[next_value_cursor].value;
265+
cut_values.push_back(cpt);
266+
last_cut = cpt;
267+
next_value_cursor = advance_to_next_distinct(next_value_cursor + 1, last_cut);
268+
}
269+
} else {
270+
auto total = static_cast<double>(this->data_[n_entries - 1].rmax);
271+
std::size_t query_cursor = 0;
272+
for (std::size_t i = 1; i < max_bin; ++i) {
273+
auto rank = static_cast<double>(i) * total / static_cast<double>(max_bin);
274+
auto rank2 = static_cast<double>(2.0) * rank;
275+
while (query_cursor < n_entries - 2 &&
276+
rank2 >= static_cast<double>(this->data_[query_cursor + 1].rmin +
277+
this->data_[query_cursor + 1].rmax)) {
278+
++query_cursor;
279+
}
280+
auto const &queried = rank2 < static_cast<double>(this->data_[query_cursor].RMinNext() +
281+
this->data_[query_cursor + 1].RMaxPrev())
282+
? this->data_[query_cursor]
283+
: this->data_[query_cursor + 1];
284+
auto cpt = queried.value;
285+
if (cpt <= last_cut) {
286+
next_value_cursor = advance_to_next_distinct(next_value_cursor, last_cut);
287+
if (next_value_cursor == n_entries) {
288+
break;
289+
}
290+
cpt = this->data_[next_value_cursor].value;
291+
} else if (next_value_cursor < n_entries && this->data_[next_value_cursor].value <= cpt) {
292+
next_value_cursor = advance_to_next_distinct(next_value_cursor + 1, cpt);
293+
}
294+
cut_values.push_back(cpt);
295+
last_cut = cpt;
296+
}
297+
}
298+
299+
auto cpt = this->data_[n_entries - 1].value;
300+
cut_values.push_back(cpt + (std::fabs(cpt) + static_cast<DType>(1e-5f)));
301+
return cut_values;
302+
}
235303
/*!
236304
* \brief combine `other` into `this`.
237305
*
@@ -452,6 +520,10 @@ struct WQSummaryContainer : public WQSummary<> {
452520
/*! \brief Weighted quantile sketch algorithm using merge/prune. */
453521
class WQuantileSketch {
454522
public:
523+
// Sketch epsilon is approximately `1 / (kFactor * max_bin)` once `max_bin` limits the budget.
524+
// Our current cut-rank measurements suggest an empirical constant of about 2 for the final
525+
// emitted cuts, so the observed normalized cut error is about `2 / kFactor`. With
526+
// `kFactor = 8`, that is roughly `0.25` bins of rank mass, i.e. about a quarter-bin offset.
455527
static float constexpr kFactor = 8.0;
456528

457529
public:

tests/cpp/common/test_hist_util.cu

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -685,23 +685,18 @@ class DeviceSketchWithHessianTest
685685
HostDeviceVector<float> const& hessian, std::vector<float> const& w,
686686
std::size_t n_elements) const {
687687
auto const& h_hess = hessian.ConstHostVector();
688-
{
689-
auto& h_weight = p_fmat->Info().weights_.HostVector();
690-
h_weight = w;
691-
}
688+
auto& h_weight = p_fmat->Info().weights_.HostVector();
689+
h_weight = w;
692690

693691
HistogramCuts cuts_hess =
694692
DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements);
695-
ValidateCuts(cuts_hess, p_fmat.get(), n_bins, kMaxWeightedNormalizedRankError);
696693

697694
// merge hessian
698-
{
699-
auto& h_weight = p_fmat->Info().weights_.HostVector();
700-
ASSERT_EQ(h_weight.size(), h_hess.size());
701-
for (std::size_t i = 0; i < h_weight.size(); ++i) {
702-
h_weight[i] = w[i] * h_hess[i];
703-
}
695+
ASSERT_EQ(h_weight.size(), h_hess.size());
696+
for (std::size_t i = 0; i < h_weight.size(); ++i) {
697+
h_weight[i] = w[i] * h_hess[i];
704698
}
699+
ValidateCuts(cuts_hess, p_fmat.get(), n_bins, kMaxWeightedNormalizedRankError);
705700

706701
HistogramCuts cuts_wh = DeviceSketch(ctx, p_fmat.get(), n_bins, n_elements);
707702
ValidateCuts(cuts_wh, p_fmat.get(), n_bins, kMaxWeightedNormalizedRankError);
@@ -750,7 +745,11 @@ class DeviceSketchWithHessianTest
750745
cuts_hess =
751746
DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements);
752747
// make validation easier by converting it into sample weight.
753-
p_fmat->Info().weights_.HostVector() = h_hess;
748+
p_fmat->Info().weights_.Resize(n_samples);
749+
for (std::size_t i = 0; i < h_hess.size(); ++i) {
750+
auto gidx = dh::SegmentId(Span{gptr.data(), gptr.size()}, i);
751+
p_fmat->Info().weights_.HostVector()[i] = w[gidx] * h_hess[i];
752+
}
754753
p_fmat->Info().group_ptr_.clear();
755754
ValidateCuts(cuts_hess, p_fmat.get(), n_bins, kMaxWeightedNormalizedRankError);
756755

0 commit comments

Comments
 (0)