Skip to content

Commit c9c8b75

Browse files
authored
Preserve exact weighted values in sorted sketch (#12148)
1 parent ab6d3c6 commit c9c8b75

2 files changed

Lines changed: 58 additions & 0 deletions

File tree

src/common/quantile.h

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -118,17 +118,54 @@ struct WQSummary {
118118
auto const *col_data = column.data();
119119
auto const col_size = column.size();
120120
double sum_total{0.0};
121+
std::size_t unique_values{0};
121122
double rmin{0.0};
122123
double wmin{0.0};
123124
bst_float last_fvalue{0.0f};
124125
double next_goal{-1.0f};
125126

126127
// first pass
127128
for (size_t i = 0; i < col_size; ++i) {
129+
if (i == 0 || col_data[i - 1].fvalue != col_data[i].fvalue) {
130+
++unique_values;
131+
}
128132
auto const &c = col_data[i];
129133
sum_total += weights[c.index];
130134
}
131135

136+
if (unique_values <= max_size) {
137+
// When we have enough budget to keep every unique feature value, emit the exact
138+
// weighted summary instead of running the weighted goal-selection logic below.
139+
for (size_t i = 0; i < col_size; ++i) {
140+
auto const &c = col_data[i];
141+
if (i == 0) {
142+
last_fvalue = c.fvalue;
143+
wmin = weights[c.index];
144+
continue;
145+
}
146+
if (last_fvalue == c.fvalue) {
147+
wmin += weights[c.index];
148+
continue;
149+
}
150+
151+
auto rmax = rmin + wmin;
152+
data_[this->Size()] = Entry(static_cast<bst_float>(rmin), static_cast<bst_float>(rmax),
153+
static_cast<bst_float>(wmin), last_fvalue);
154+
this->SetSize(this->Size() + 1);
155+
rmin = rmax;
156+
last_fvalue = c.fvalue;
157+
wmin = weights[c.index];
158+
}
159+
160+
if (col_size != 0) {
161+
auto rmax = rmin + wmin;
162+
data_[this->Size()] = Entry(static_cast<bst_float>(rmin), static_cast<bst_float>(rmax),
163+
static_cast<bst_float>(wmin), last_fvalue);
164+
this->SetSize(this->Size() + 1);
165+
}
166+
return;
167+
}
168+
132169
// second pass
133170
for (size_t i = 0; i < col_size; ++i) {
134171
auto const &c = col_data[i];

tests/cpp/common/test_hist_util.cc

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -179,6 +179,27 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
179179
}
180180
}
181181

182+
// Verifies that, for a small weighted single-feature input, the cuts produced
// by SketchOnDMatrix with its last argument set to true match the cuts
// produced with it set to false, and that those cuts track the input values
// exactly.
// NOTE(review): the variable names suggest true = sorted-column sketch and
// false = row-wise sketch -- confirm against SketchOnDMatrix.
TEST(HistUtil, SortedWeightedExactCuts) {
183+
Context ctx;
184+
// Six distinct feature values, listed in ascending order.
std::vector<float> x{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
185+
// Non-uniform weights, one per entry of x.
std::vector<float> weights{3.0f, 0.25f, 7.0f, 2.0f, 0.5f, 4.0f};
186+
// Presumably builds an x.size()-row by 1-column matrix -- confirm against
// GetDMatrixFromData.
auto dmat = GetDMatrixFromData(x, x.size(), 1);
187+
// Overwrite the matrix's weight vector with the weights above.
dmat->Info().weights_.HostVector() = weights;
188+
189+
// Both sketches receive x.size() as the third argument (presumably the
// maximum number of bins -- confirm); only the trailing boolean differs.
auto sorted_cuts = SketchOnDMatrix(&ctx, dmat.get(), x.size(), true);
190+
auto row_cuts = SketchOnDMatrix(&ctx, dmat.get(), x.size(), false);
191+
192+
// The two sketches must agree on the pointer array and the number of cut
// values, and there must be exactly one cut value per input value.
ASSERT_EQ(sorted_cuts.Ptrs(), row_cuts.Ptrs());
193+
ASSERT_EQ(sorted_cuts.Values().size(), row_cuts.Values().size());
194+
ASSERT_EQ(sorted_cuts.Values().size(), x.size());
195+
// Cut i-1 is expected to equal the next input value x[i], and to be the same
// in both sketches. The final cut is checked separately after the loop.
for (std::size_t i = 1; i < x.size(); ++i) {
196+
EXPECT_FLOAT_EQ(sorted_cuts.Values()[i - 1], x[i]);
197+
EXPECT_FLOAT_EQ(sorted_cuts.Values()[i - 1], row_cuts.Values()[i - 1]);
198+
}
199+
// The last cut must be strictly greater than the largest input value and must
// match between the two sketches.
EXPECT_GT(sorted_cuts.Values().back(), x.back());
200+
EXPECT_FLOAT_EQ(sorted_cuts.Values().back(), row_cuts.Values().back());
201+
}
202+
182203
void TestQuantileWithHessian(bool use_sorted) {
183204
int bin_sizes[] = {2, 16, 256, 512};
184205
int sizes[] = {1000, 1500};

0 commit comments

Comments
 (0)