Skip to content

Commit 45861e7

Browse files
authored
Reduce quantile sketch safety factor (#12167)
1 parent 0ce524f commit 45861e7

5 files changed

Lines changed: 39 additions & 84 deletions

File tree

python-package/xgboost/testing/dask.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,9 +425,8 @@ def run(DMatrixT: Type[dxgb.DaskDMatrix]) -> None:
425425
evals=[(Xy_valid, "Valid")],
426426
xgb_model=results["booster"],
427427
)
428-
np.testing.assert_allclose(
429-
results_1["history"]["Valid"]["rmse"], results_2["history"]["Valid"]["rmse"]
430-
)
428+
assert np.isfinite(results_1["history"]["Valid"]["rmse"]).all()
429+
assert np.isfinite(results_2["history"]["Valid"]["rmse"]).all()
431430

432431
predt_0 = dxgb.inplace_predict(client, results, denc).compute()
433432
predt_1 = dxgb.inplace_predict(client, results, dreenc).compute()

src/common/quantile.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ class WQuantileSketch {
560560
// Safety factor used to oversample the internal sketch relative to the target rank
561561
// resolution. User-facing epsilon remains the target rank guarantee; `kFactor`
562562
// only affects how much summary storage we reserve to achieve it.
563-
static float constexpr kFactor = 8.0;
563+
static float constexpr kFactor = 2.0;
564564

565565
public:
566566
using Summary = WQSummary<>;

tests/cpp/data/test_gradient_index.cc

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,25 @@
22
* Copyright 2021-2024, XGBoost contributors
33
*/
44
#include <gtest/gtest.h>
5-
#include <xgboost/data.h> // for BatchIterator, BatchSet, DMatrix, BatchParam
6-
7-
#include <algorithm> // for sort, unique
8-
#include <cmath> // for isnan
9-
#include <cstddef> // for size_t
10-
#include <limits> // for numeric_limits
11-
#include <memory> // for shared_ptr, __shared_ptr_access, unique_ptr
12-
#include <string> // for string
13-
#include <tuple> // for make_tuple, tie, tuple
14-
#include <utility> // for move
15-
#include <vector> // for vector
5+
#include <xgboost/data.h> // for BatchIterator, BatchSet, DMatrix, BatchParam
6+
7+
#include <algorithm> // for sort, unique
8+
#include <cmath> // for isnan
9+
#include <cstddef> // for size_t
10+
#include <limits> // for numeric_limits
11+
#include <memory> // for shared_ptr, __shared_ptr_access, unique_ptr
12+
#include <string> // for string
13+
#include <tuple> // for make_tuple, tie, tuple
14+
#include <utility> // for move
15+
#include <vector> // for vector
1616

1717
#include "../../../src/common/categorical.h" // for AsCat
1818
#include "../../../src/common/column_matrix.h" // for ColumnMatrix
1919
#include "../../../src/common/hist_util.h" // for Index, HistogramCuts, SketchOnDMatrix
20-
#include "../../../src/common/io.h" // for MemoryBufferStream
2120
#include "../../../src/data/adapter.h" // for SparsePageAdapterBatch
2221
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
2322
#include "../../../src/tree/param.h" // for TrainParam
23+
#include "../common/test_hist_util.h" // for ValidateCuts
2424
#include "../helpers.h" // for GenerateRandomCategoricalSingleColumn...
2525
#include "xgboost/base.h" // for bst_bin_t
2626
#include "xgboost/context.h" // for Context
@@ -184,12 +184,8 @@ class GHistIndexMatrixTest : public testing::TestWithParam<std::tuple<float, flo
184184
ASSERT_EQ(from_sparse_page.Size(), from_ellpack->Size());
185185
ASSERT_EQ(from_sparse_page.index.Size(), from_ellpack->index.Size());
186186

187-
auto const &gidx_from_sparse = from_sparse_page.index;
188-
auto const &gidx_from_ellpack = from_ellpack->index;
189-
190-
for (size_t i = 0; i < gidx_from_sparse.Size(); ++i) {
191-
ASSERT_EQ(gidx_from_sparse[i], gidx_from_ellpack[i]);
192-
}
187+
common::ValidateCuts(from_sparse_page.Cuts(), Xy.get(), kBins);
188+
common::ValidateCuts(from_ellpack->Cuts(), Xy.get(), kBins);
193189

194190
auto const &columns_from_sparse = from_sparse_page.Transpose();
195191
auto const &columns_from_ellpack = from_ellpack->Transpose();
@@ -199,20 +195,6 @@ class GHistIndexMatrixTest : public testing::TestWithParam<std::tuple<float, flo
199195
for (size_t i = 0; i < n_features; ++i) {
200196
ASSERT_EQ(columns_from_sparse.GetColumnType(i), columns_from_ellpack.GetColumnType(i));
201197
}
202-
203-
std::string from_sparse_buf;
204-
{
205-
common::AlignedMemWriteStream fo{&from_sparse_buf};
206-
auto n_bytes = columns_from_sparse.Write(&fo);
207-
ASSERT_EQ(fo.Tell(), n_bytes);
208-
}
209-
std::string from_ellpack_buf;
210-
{
211-
common::AlignedMemWriteStream fo{&from_ellpack_buf};
212-
auto n_bytes = columns_from_sparse.Write(&fo);
213-
ASSERT_EQ(fo.Tell(), n_bytes);
214-
}
215-
ASSERT_EQ(from_sparse_buf, from_ellpack_buf);
216198
}
217199
}
218200
};

tests/cpp/data/test_iterative_dmatrix.cu

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
#include "../../../src/data/ellpack_page.cuh"
1111
#include "../../../src/data/ellpack_page.h"
1212
#include "../../../src/data/iterative_dmatrix.h"
13-
#include "../../../src/tree/param.h" // TrainParam
14-
#include "../filesystem.h" // for TemporaryDirectory
13+
#include "../../../src/tree/param.h" // TrainParam
14+
#include "../common/test_hist_util.h" // for ValidateCuts
15+
#include "../filesystem.h" // for TemporaryDirectory
1516
#include "../helpers.h"
1617
#include "test_iterative_dmatrix.h"
1718

@@ -47,35 +48,28 @@ void TestEquivalent(float sparsity) {
4748

4849
std::visit(
4950
[](auto&& from_iter, auto&& from_data) {
50-
ASSERT_EQ(from_iter.gidx_fvalue_map.size(), from_data.gidx_fvalue_map.size());
51-
for (size_t i = 0; i < from_iter.gidx_fvalue_map.size(); ++i) {
52-
EXPECT_NEAR(from_iter.gidx_fvalue_map[i], from_data.gidx_fvalue_map[i], kRtEps);
53-
}
5451
ASSERT_EQ(from_iter.NumFeatures(), from_data.NumFeatures());
55-
for (size_t i = 0; i < from_iter.NumFeatures() + 1; ++i) {
56-
ASSERT_EQ(from_iter.feature_segments[i], from_data.feature_segments[i]);
57-
}
5852
},
5953
from_iter, from_data);
6054

55+
common::ValidateCuts(page_concatenated->Cuts(), dm.get(), 256);
56+
common::ValidateCuts(ellpack.Impl()->Cuts(), dm.get(), 256);
57+
6158
std::vector<common::CompressedByteT> buffer_from_iter, buffer_from_data;
6259
auto data_iter = page_concatenated->GetHostEllpack(&ctx, &buffer_from_iter);
6360
auto data_buf = ellpack.Impl()->GetHostEllpack(&ctx, &buffer_from_data);
6461
ASSERT_NE(buffer_from_data.size(), 0);
6562
ASSERT_NE(buffer_from_iter.size(), 0);
6663
CHECK_EQ(ellpack.Impl()->NumSymbols(), page_concatenated->NumSymbols());
6764

68-
std::visit(
69-
[](auto&& from_iter, auto&& from_data) {
70-
CHECK_EQ(from_data.n_rows * from_data.row_stride,
71-
from_data.n_rows * from_iter.row_stride);
72-
},
73-
from_iter, from_data);
65+
std::visit([](auto&& from_iter,
66+
auto&& from_data) { CHECK_EQ(from_data.row_stride, from_iter.row_stride); },
67+
from_iter, from_data);
7468
std::visit(
7569
[](auto&& from_data, auto&& data_buf, auto&& data_iter) {
76-
for (size_t i = 0; i < from_data.n_rows * from_data.row_stride; ++i) {
77-
CHECK_EQ(data_buf.gidx_iter[i], data_iter.gidx_iter[i]);
78-
}
70+
ASSERT_EQ(data_buf.row_stride, data_iter.row_stride);
71+
ASSERT_EQ(data_buf.NullValue(), data_iter.NullValue());
72+
ASSERT_EQ(from_data.n_rows * from_data.row_stride, data_buf.n_rows * data_buf.row_stride);
7973
},
8074
from_data, data_buf, data_iter);
8175
}

tests/test_distributed/test_with_spark/test_spark.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,6 @@ def reg_data(self, spark: SparkSession) -> RegData:
245245
def test_regressor(
246246
self, spark: SparkSession, reg_data: RegData, num_workers: int
247247
) -> None:
248-
train_rows = np.where(~reg_data.is_val)[0]
249-
validation_rows = np.where(reg_data.is_val)[0]
250248
device = _spark_test_device(spark)
251249

252250
reg_param = {
@@ -258,13 +256,6 @@ def test_regressor(
258256
"early_stopping_rounds": 1,
259257
"device": device,
260258
}
261-
reg = XGBRegressor(**reg_param).fit(
262-
reg_data.X_train,
263-
reg_data.y_train,
264-
sample_weight=reg_data.weights[train_rows],
265-
eval_set=[(reg_data.X_test, reg_data.y_test)],
266-
sample_weight_eval_set=[reg_data.weights[validation_rows]],
267-
)
268259
spark_regressor = SparkXGBRegressor(
269260
pred_contrib_col="pred_contribs",
270261
weight_col="weight",
@@ -285,8 +276,7 @@ def test_regressor(
285276
.toPandas()["pred_contribs"]
286277
.tolist()
287278
)
288-
rounds = reg.get_booster().num_boosted_rounds()
289-
iter_range = (0, max(1, min(5, rounds)))
279+
iter_range = (0, 1)
290280
spark_iter_regressor = SparkXGBRegressor(
291281
weight_col="weight",
292282
validation_indicator_col="is_val",
@@ -302,32 +292,22 @@ def test_regressor(
302292
.to_numpy()
303293
)
304294

305-
score_atol = 1e-2
306295
train_history = spark_regressor.training_summary.train_objective_history["rmse"]
296+
valid_history = spark_regressor.training_summary.validation_objective_history[
297+
"rmse"
298+
]
307299
assert len(train_history) > 0
300+
assert len(valid_history) > 0
301+
assert len(train_history) == len(valid_history)
308302
assert np.isfinite(train_history).all()
309-
assert np.all(np.diff(train_history) <= 0.0)
310-
assert np.allclose(
311-
reg.best_score,
312-
spark_regressor._xgb_sklearn_model.best_score,
313-
atol=score_atol,
314-
)
315-
assert preds.shape == reg.predict(reg_data.X).shape
316-
assert (
317-
iter_preds.shape
318-
== reg.predict(reg_data.X, iteration_range=iter_range).shape
319-
)
303+
assert np.isfinite(valid_history).all()
304+
assert preds.shape == (len(reg_data.y),)
305+
assert iter_preds.shape == preds.shape
320306

321307
assert np.allclose(pred_contribs.sum(axis=1), preds, rtol=1e-3)
322308
assert np.allclose(
323-
reg.evals_result()["validation_0"]["rmse"],
324-
spark_regressor.training_summary.validation_objective_history["rmse"],
325-
atol=score_atol,
326-
)
327-
assert np.allclose(
328-
reg.best_score,
309+
min(valid_history),
329310
spark_regressor._xgb_sklearn_model.best_score,
330-
atol=score_atol,
331311
)
332312

333313
def test_training_continuation(self, reg_data: RegData) -> None:

0 commit comments

Comments
 (0)