Skip to content

Commit 26a1379

Browse files
authored
Use context for thrust caching policy. (#12191)
1 parent d46abbe commit 26a1379

6 files changed

Lines changed: 59 additions & 68 deletions

File tree

src/common/device_helpers.cuh

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -764,11 +764,6 @@ template <cudaMemcpyKind kind, typename T, typename U>
764764
#endif // CUDART_VERSION >= 12080
765765
}
766766

767-
inline auto CachingThrustPolicy() {
768-
XGBCachingDeviceAllocator<char> alloc;
769-
return thrust::cuda::par_nosync(alloc).on(::xgboost::curt::DefaultStream());
770-
}
771-
772767
// Force nvcc to load data as constant
773768
template <typename T>
774769
class LDGIterator {

src/tree/constraints.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,9 @@ __global__ void InteractionConstraintSplitKernel(LBitField64 feature, int32_t fe
281281
}
282282
}
283283

284-
void FeatureInteractionConstraintDevice::Split(bst_node_t node_id, bst_feature_t feature_id,
285-
bst_node_t left_id, bst_node_t right_id) {
284+
void FeatureInteractionConstraintDevice::Split(Context const* ctx, bst_node_t node_id,
285+
bst_feature_t feature_id, bst_node_t left_id,
286+
bst_node_t right_id) {
286287
if (!has_constraint_) {
287288
return;
288289
}
@@ -310,7 +311,7 @@ void FeatureInteractionConstraintDevice::Split(bst_node_t node_id, bst_feature_t
310311
launch_split(InteractionConstraintSplitKernel, feature_buffer_, feature_id, node, left, right);
311312

312313
// clear the buffer after use
313-
thrust::fill_n(dh::CachingThrustPolicy(), feature_buffer_.Data(), feature_buffer_.NumValues(), 0);
314+
thrust::fill_n(ctx->CUDACtx()->CTP(), feature_buffer_.Data(), feature_buffer_.NumValues(), 0);
314315
}
315316

316317
} // namespace xgboost

src/tree/constraints.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ struct FeatureInteractionConstraintDevice {
9292
common::Span<bst_feature_t const> Query(common::Span<bst_feature_t const> feature_list,
9393
bst_node_t nidx);
9494
/*! \brief Apply split for node_id. */
95-
void Split(bst_node_t node_id, bst_feature_t feature_id, bst_node_t left_id, bst_node_t right_id);
95+
void Split(Context const* ctx, bst_node_t node_id, bst_feature_t feature_id, bst_node_t left_id,
96+
bst_node_t right_id);
9697
};
9798

9899
} // namespace xgboost

src/tree/updater_gpu_hist.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ struct GPUHistMakerDevice {
650650
evaluator_.ApplyTreeSplit(candidate, p_tree);
651651

652652
const auto& parent = tree[candidate.nidx];
653-
interaction_constraints.Split(candidate.nidx, parent.SplitIndex(), parent.LeftChild(),
653+
interaction_constraints.Split(ctx_, candidate.nidx, parent.SplitIndex(), parent.LeftChild(),
654654
parent.RightChild());
655655
}
656656

tests/cpp/common/test_device_vector.cu

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2024-2025, XGBoost Contributors
2+
* Copyright 2024-2026, XGBoost Contributors
33
*/
44
#include <gtest/gtest.h>
55
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
@@ -8,9 +8,11 @@
88
#include <numeric> // for iota
99
#include <thread> // for thread
1010

11+
#include "../../../src/common/cuda_context.cuh" // for CUDAContext
1112
#include "../../../src/common/cuda_rt_utils.h" // for DrVersion
12-
#include "../../../src/common/device_helpers.cuh" // for CachingThrustPolicy, PinnedMemory
13+
#include "../../../src/common/device_helpers.cuh" // for PinnedMemory
1314
#include "../../../src/common/device_vector.cuh"
15+
#include "../helpers.h" // for MakeCUDACtx
1416
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
1517
#include "xgboost/windefs.h" // for xgboost_IS_WIN
1618

@@ -33,6 +35,7 @@ TEST(AsyncPoolAllocator, Basic) {
3335
#endif // !defined(XGBOOST_USE_RMM)
3436

3537
TEST(DeviceUVector, Basic) {
38+
auto ctx = xgboost::MakeCUDACtx(0);
3639
GlobalMemoryLogger().Clear();
3740
std::int32_t verbosity{3};
3841
std::swap(verbosity, xgboost::GlobalConfigThreadLocalStore::Get()->verbosity);
@@ -51,11 +54,11 @@ TEST(DeviceUVector, Basic) {
5154
ASSERT_EQ(std::distance(uvec1.begin(), uvec1.end()), uvec1.size());
5255
auto orig = uvec1.size();
5356

54-
thrust::sequence(dh::CachingThrustPolicy(), uvec1.begin(), uvec1.end(), 0);
57+
thrust::sequence(ctx.CUDACtx()->CTP(), uvec1.begin(), uvec1.end(), 0);
5558
uvec1.resize(32);
5659
ASSERT_EQ(uvec1.size(), 32);
5760
ASSERT_EQ(uvec1.Capacity(), 32);
58-
auto eq = thrust::equal(dh::CachingThrustPolicy(), uvec1.cbegin(), uvec1.cbegin() + orig,
61+
auto eq = thrust::equal(ctx.CUDACtx()->CTP(), uvec1.cbegin(), uvec1.cbegin() + orig,
5962
thrust::make_counting_iterator(0));
6063
ASSERT_TRUE(eq);
6164

@@ -69,6 +72,7 @@ namespace {
6972
class TestVirtualMem : public ::testing::TestWithParam<CUmemLocationType> {
7073
public:
7174
void Run() {
75+
auto ctx = xgboost::MakeCUDACtx(0);
7276
auto type = this->GetParam();
7377
detail::GrowOnlyVirtualMemVec vec{type};
7478
auto prop = xgboost::cudr::MakeAllocProp(type);
@@ -86,7 +90,7 @@ class TestVirtualMem : public ::testing::TestWithParam<CUmemLocationType> {
8690
};
8791
auto fill = [&](std::int32_t n_orig, xgboost::common::Span<std::int32_t> data) {
8892
if (type == CU_MEM_LOCATION_TYPE_DEVICE) {
89-
thrust::sequence(dh::CachingThrustPolicy(), data.data() + n_orig, data.data() + data.size(),
93+
thrust::sequence(ctx.CUDACtx()->CTP(), data.data() + n_orig, data.data() + data.size(),
9094
n_orig);
9195
dh::safe_cuda(cudaMemcpy(h_data.data(), data.data(), data.size_bytes(), cudaMemcpyDefault));
9296
} else {
@@ -151,7 +155,7 @@ TEST(TestVirtualMem, Version) {
151155
PinnedMemory pinned;
152156
#if defined(xgboost_IS_WIN)
153157
ASSERT_FALSE(pinned.IsVm());
154-
#else // defined(xgboost_IS_WIN)
158+
#else // defined(xgboost_IS_WIN)
155159
if (major == 12 && minor >= 5 || major > 12) {
156160
ASSERT_TRUE(pinned.IsVm());
157161
} else {

tests/cpp/tree/test_constraints.cu

Lines changed: 42 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,11 @@ struct FConstraintWrapper : public FeatureInteractionConstraintDevice {
2121
common::Span<LBitField64> GetNodeConstraints() {
2222
return FeatureInteractionConstraintDevice::s_node_constraints_;
2323
}
24-
FConstraintWrapper(tree::TrainParam param, bst_feature_t n_features) :
25-
FeatureInteractionConstraintDevice(param, n_features) {}
24+
FConstraintWrapper(tree::TrainParam param, bst_feature_t n_features)
25+
: FeatureInteractionConstraintDevice(param, n_features) {}
2626

27-
dh::device_vector<bst_feature_t> const& GetDSets() const {
28-
return d_sets_;
29-
}
30-
dh::device_vector<size_t> const& GetDSetsPtr() const {
31-
return d_sets_ptr_;
32-
}
27+
dh::device_vector<bst_feature_t> const& GetDSets() const { return d_sets_; }
28+
dh::device_vector<size_t> const& GetDSetsPtr() const { return d_sets_ptr_; }
3329
};
3430

3531
std::string GetConstraintsStr() {
@@ -46,12 +42,11 @@ tree::TrainParam GetParameter() {
4642

4743
void CompareBitField(LBitField64 d_field, std::set<uint32_t> positions) {
4844
std::vector<LBitField64::value_type> h_field_storage(d_field.Bits().size());
49-
thrust::copy(thrust::device_ptr<LBitField64::value_type>(d_field.Bits().data()),
50-
thrust::device_ptr<LBitField64::value_type>(
51-
d_field.Bits().data() + d_field.Bits().size()),
52-
h_field_storage.data());
53-
LBitField64 h_field{ {h_field_storage.data(),
54-
h_field_storage.data() + h_field_storage.size()} };
45+
thrust::copy(
46+
thrust::device_ptr<LBitField64::value_type>(d_field.Bits().data()),
47+
thrust::device_ptr<LBitField64::value_type>(d_field.Bits().data() + d_field.Bits().size()),
48+
h_field_storage.data());
49+
LBitField64 h_field{{h_field_storage.data(), h_field_storage.data() + h_field_storage.size()}};
5550

5651
for (size_t i = 0; i < h_field.Capacity(); ++i) {
5752
if (positions.find(i) != positions.cend()) {
@@ -64,7 +59,6 @@ void CompareBitField(LBitField64 d_field, std::set<uint32_t> positions) {
6459

6560
} // anonymous namespace
6661

67-
6862
TEST(GPUFeatureInteractionConstraint, Init) {
6963
{
7064
int32_t constexpr kFeatures = 6;
@@ -75,12 +69,10 @@ TEST(GPUFeatureInteractionConstraint, Init) {
7569
for (LBitField64 const& d_node : s_nodes_constraints) {
7670
std::vector<LBitField64::value_type> h_node_storage(d_node.Bits().size());
7771
thrust::copy(thrust::device_ptr<LBitField64::value_type const>(d_node.Bits().data()),
78-
thrust::device_ptr<LBitField64::value_type const>(
79-
d_node.Bits().data() + d_node.Bits().size()),
72+
thrust::device_ptr<LBitField64::value_type const>(d_node.Bits().data() +
73+
d_node.Bits().size()),
8074
h_node_storage.data());
81-
LBitField64 h_node {
82-
{h_node_storage.data(), h_node_storage.data() + h_node_storage.size()}
83-
};
75+
LBitField64 h_node{{h_node_storage.data(), h_node_storage.data() + h_node_storage.size()}};
8476
// no feature is attached to node.
8577
for (size_t i = 0; i < h_node.Capacity(); ++i) {
8678
ASSERT_FALSE(h_node.Check(i));
@@ -94,8 +86,8 @@ TEST(GPUFeatureInteractionConstraint, Init) {
9486
tree::TrainParam param = GetParameter();
9587
param.interaction_constraints = R"([[0, 1, 3], [3, 5, 6]])";
9688
FConstraintWrapper constraints(param, kFeatures);
97-
std::vector<bst_feature_t> h_sets {0, 0, 0, 1, 1, 1};
98-
std::vector<size_t> h_sets_ptr {0, 1, 2, 2, 4, 4, 5, 6};
89+
std::vector<bst_feature_t> h_sets{0, 0, 0, 1, 1, 1};
90+
std::vector<size_t> h_sets_ptr{0, 1, 2, 2, 4, 4, 5, 6};
9991
auto d_sets = constraints.GetDSets();
10092
ASSERT_EQ(h_sets.size(), d_sets.size());
10193
auto d_sets_ptr = constraints.GetDSetsPtr();
@@ -120,18 +112,19 @@ TEST(GPUFeatureInteractionConstraint, Init) {
120112
auto _128_end = d_sets_ptr[128 + 1];
121113
ASSERT_EQ(_128_end - _128_beg, 2);
122114
ASSERT_EQ(d_sets[_128_beg], 1);
123-
ASSERT_EQ(d_sets[_128_end-1], 2);
115+
ASSERT_EQ(d_sets[_128_end - 1], 2);
124116
}
125117
}
126118

127119
TEST(GPUFeatureInteractionConstraint, Split) {
120+
auto ctx = MakeCUDACtx(0);
128121
tree::TrainParam param = GetParameter();
129122
int32_t constexpr kFeatures = 6;
130123
FConstraintWrapper constraints(param, kFeatures);
131124

132125
{
133126
LBitField64 d_node[3];
134-
constraints.Split(0, /*feature_id=*/1, 1, 2);
127+
constraints.Split(&ctx, 0, /*feature_id=*/1, 1, 2);
135128
for (size_t nid = 0; nid < 3; ++nid) {
136129
d_node[nid] = constraints.GetNodeConstraints()[nid];
137130
ASSERT_EQ(d_node[nid].Bits().size(), 1);
@@ -141,7 +134,7 @@ TEST(GPUFeatureInteractionConstraint, Split) {
141134

142135
{
143136
LBitField64 d_node[5];
144-
constraints.Split(1, /*feature_id=*/0, /*left_id=*/3, /*right_id=*/4);
137+
constraints.Split(&ctx, 1, /*feature_id=*/0, /*left_id=*/3, /*right_id=*/4);
145138
for (auto nid : {1, 3, 4}) {
146139
d_node[nid] = constraints.GetNodeConstraints()[nid];
147140
CompareBitField(d_node[nid], {0, 1, 2});
@@ -165,24 +158,22 @@ TEST(GPUFeatureInteractionConstraint, QueryNode) {
165158
}
166159

167160
{
168-
constraints.Split(/*node_id=*/ 0, /*feature_id=*/ 1, 1, 2);
161+
constraints.Split(&ctx, /*node_id=*/0, /*feature_id=*/1, 1, 2);
169162
auto span = constraints.QueryNode(&ctx, 0);
170-
std::vector<bst_feature_t> h_result (span.size());
163+
std::vector<bst_feature_t> h_result(span.size());
171164
thrust::copy(thrust::device_ptr<bst_feature_t>(span.data()),
172-
thrust::device_ptr<bst_feature_t>(span.data() + span.size()),
173-
h_result.begin());
165+
thrust::device_ptr<bst_feature_t>(span.data() + span.size()), h_result.begin());
174166
ASSERT_EQ(h_result.size(), 2);
175167
ASSERT_EQ(h_result[0], 1);
176168
ASSERT_EQ(h_result[1], 2);
177169
}
178170

179171
{
180-
constraints.Split(1, /*feature_id=*/0, 3, 4);
172+
constraints.Split(&ctx, 1, /*feature_id=*/0, 3, 4);
181173
auto span = constraints.QueryNode(&ctx, 1);
182-
std::vector<bst_feature_t> h_result (span.size());
174+
std::vector<bst_feature_t> h_result(span.size());
183175
thrust::copy(thrust::device_ptr<bst_feature_t>(span.data()),
184-
thrust::device_ptr<bst_feature_t>(span.data() + span.size()),
185-
h_result.begin());
176+
thrust::device_ptr<bst_feature_t>(span.data() + span.size()), h_result.begin());
186177
ASSERT_EQ(h_result.size(), 3);
187178
ASSERT_EQ(h_result[0], 0);
188179
ASSERT_EQ(h_result[1], 1);
@@ -192,8 +183,7 @@ TEST(GPUFeatureInteractionConstraint, QueryNode) {
192183
span = constraints.QueryNode(&ctx, 3);
193184
h_result.resize(span.size());
194185
thrust::copy(thrust::device_ptr<bst_feature_t>(span.data()),
195-
thrust::device_ptr<bst_feature_t>(span.data() + span.size()),
196-
h_result.begin());
186+
thrust::device_ptr<bst_feature_t>(span.data() + span.size()), h_result.begin());
197187
ASSERT_EQ(h_result.size(), 3);
198188
ASSERT_EQ(h_result[0], 0);
199189
ASSERT_EQ(h_result[1], 1);
@@ -204,12 +194,11 @@ TEST(GPUFeatureInteractionConstraint, QueryNode) {
204194
tree::TrainParam large_param = GetParameter();
205195
large_param.interaction_constraints = R"([[1, 139], [244, 0], [139, 221]])";
206196
FConstraintWrapper large_features(large_param, 256);
207-
large_features.Split(0, 139, 1, 2);
197+
large_features.Split(&ctx, 0, 139, 1, 2);
208198
auto span = large_features.QueryNode(&ctx, 0);
209-
std::vector<bst_feature_t> h_result (span.size());
199+
std::vector<bst_feature_t> h_result(span.size());
210200
thrust::copy(thrust::device_ptr<bst_feature_t>(span.data()),
211-
thrust::device_ptr<bst_feature_t>(span.data() + span.size()),
212-
h_result.begin());
201+
thrust::device_ptr<bst_feature_t>(span.data() + span.size()), h_result.begin());
213202
ASSERT_EQ(h_result.size(), 3);
214203
ASSERT_EQ(h_result[0], 1);
215204
ASSERT_EQ(h_result[1], 139);
@@ -230,12 +219,13 @@ void CompareFeatureList(common::Span<bst_feature_t const> s_output,
230219
} // anonymous namespace
231220

232221
TEST(GPUFeatureInteractionConstraint, Query) {
222+
auto ctx = MakeCUDACtx(0);
233223
{
234224
tree::TrainParam param = GetParameter();
235225
bst_feature_t constexpr kFeatures = 6;
236226
FConstraintWrapper constraints(param, kFeatures);
237-
std::vector<bst_feature_t> h_input_feature_list {0, 1, 2, 3, 4, 5};
238-
dh::device_vector<bst_feature_t> d_input_feature_list (h_input_feature_list);
227+
std::vector<bst_feature_t> h_input_feature_list{0, 1, 2, 3, 4, 5};
228+
dh::device_vector<bst_feature_t> d_input_feature_list(h_input_feature_list);
239229
common::Span<bst_feature_t> s_input_feature_list = dh::ToSpan(d_input_feature_list);
240230

241231
auto s_output = constraints.Query(s_input_feature_list, 0);
@@ -245,9 +235,9 @@ TEST(GPUFeatureInteractionConstraint, Query) {
245235
tree::TrainParam param = GetParameter();
246236
bst_feature_t constexpr kFeatures = 6;
247237
FConstraintWrapper constraints(param, kFeatures);
248-
constraints.Split(/*node_id=*/0, /*feature_id=*/1, /*left_id=*/1, /*right_id=*/2);
249-
constraints.Split(/*node_id=*/1, /*feature_id=*/0, /*left_id=*/3, /*right_id=*/4);
250-
constraints.Split(/*node_id=*/4, /*feature_id=*/3, /*left_id=*/5, /*right_id=*/6);
238+
constraints.Split(&ctx, /*node_id=*/0, /*feature_id=*/1, /*left_id=*/1, /*right_id=*/2);
239+
constraints.Split(&ctx, /*node_id=*/1, /*feature_id=*/0, /*left_id=*/3, /*right_id=*/4);
240+
constraints.Split(&ctx, /*node_id=*/4, /*feature_id=*/3, /*left_id=*/5, /*right_id=*/6);
251241
/*
252242
* (node id) [allowed features]
253243
*
@@ -263,8 +253,8 @@ TEST(GPUFeatureInteractionConstraint, Query) {
263253
*
264254
*/
265255

266-
std::vector<bst_feature_t> h_input_feature_list {0, 1, 2, 3, 4, 5};
267-
dh::device_vector<bst_feature_t> d_input_feature_list (h_input_feature_list);
256+
std::vector<bst_feature_t> h_input_feature_list{0, 1, 2, 3, 4, 5};
257+
dh::device_vector<bst_feature_t> d_input_feature_list(h_input_feature_list);
268258
common::Span<bst_feature_t> s_input_feature_list = dh::ToSpan(d_input_feature_list);
269259

270260
auto s_output = constraints.Query(s_input_feature_list, 1);
@@ -289,10 +279,10 @@ TEST(GPUFeatureInteractionConstraint, Query) {
289279
param.interaction_constraints = constraints_str;
290280

291281
FConstraintWrapper constraints(param, kFeatures);
292-
constraints.Split(/*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2);
282+
constraints.Split(&ctx, /*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2);
293283

294-
std::vector<bst_feature_t> h_input_feature_list {0, 1, 2, 3, 4, 5};
295-
dh::device_vector<bst_feature_t> d_input_feature_list (h_input_feature_list);
284+
std::vector<bst_feature_t> h_input_feature_list{0, 1, 2, 3, 4, 5};
285+
dh::device_vector<bst_feature_t> d_input_feature_list(h_input_feature_list);
296286
common::Span<bst_feature_t> s_input_feature_list = dh::ToSpan(d_input_feature_list);
297287

298288
auto s_output = constraints.Query(s_input_feature_list, 1);
@@ -306,10 +296,10 @@ TEST(GPUFeatureInteractionConstraint, Query) {
306296
std::string const constraints_str = R"constraint([[0, 1]])constraint";
307297
param.interaction_constraints = constraints_str;
308298
FConstraintWrapper constraints(param, kFeatures);
309-
std::vector<bst_feature_t> h_input_feature_list {0, 1, 2, 3, 4, 5};
310-
dh::device_vector<bst_feature_t> d_input_feature_list (h_input_feature_list);
299+
std::vector<bst_feature_t> h_input_feature_list{0, 1, 2, 3, 4, 5};
300+
dh::device_vector<bst_feature_t> d_input_feature_list(h_input_feature_list);
311301
common::Span<bst_feature_t> s_input_feature_list = dh::ToSpan(d_input_feature_list);
312-
constraints.Split(/*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2);
302+
constraints.Split(&ctx, /*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2);
313303
auto s_output = constraints.Query(s_input_feature_list, 1);
314304
CompareFeatureList(s_output, {2});
315305
s_output = constraints.Query(s_input_feature_list, 2);

0 commit comments

Comments
 (0)