@@ -21,15 +21,11 @@ struct FConstraintWrapper : public FeatureInteractionConstraintDevice {
2121 common::Span<LBitField64> GetNodeConstraints () {
2222 return FeatureInteractionConstraintDevice::s_node_constraints_;
2323 }
24- FConstraintWrapper (tree::TrainParam param, bst_feature_t n_features) :
25- FeatureInteractionConstraintDevice (param, n_features) {}
24+ FConstraintWrapper (tree::TrainParam param, bst_feature_t n_features)
25+ : FeatureInteractionConstraintDevice(param, n_features) {}
2626
27- dh::device_vector<bst_feature_t > const & GetDSets () const {
28- return d_sets_;
29- }
30- dh::device_vector<size_t > const & GetDSetsPtr () const {
31- return d_sets_ptr_;
32- }
27+ dh::device_vector<bst_feature_t > const & GetDSets () const { return d_sets_; }
28+ dh::device_vector<size_t > const & GetDSetsPtr () const { return d_sets_ptr_; }
3329};
3430
3531std::string GetConstraintsStr () {
@@ -46,12 +42,11 @@ tree::TrainParam GetParameter() {
4642
4743void CompareBitField (LBitField64 d_field, std::set<uint32_t > positions) {
4844 std::vector<LBitField64::value_type> h_field_storage (d_field.Bits ().size ());
49- thrust::copy (thrust::device_ptr<LBitField64::value_type>(d_field.Bits ().data ()),
50- thrust::device_ptr<LBitField64::value_type>(
51- d_field.Bits ().data () + d_field.Bits ().size ()),
52- h_field_storage.data ());
53- LBitField64 h_field{ {h_field_storage.data (),
54- h_field_storage.data () + h_field_storage.size ()} };
45+ thrust::copy (
46+ thrust::device_ptr<LBitField64::value_type>(d_field.Bits ().data ()),
47+ thrust::device_ptr<LBitField64::value_type>(d_field.Bits ().data () + d_field.Bits ().size ()),
48+ h_field_storage.data ());
49+ LBitField64 h_field{{h_field_storage.data (), h_field_storage.data () + h_field_storage.size ()}};
5550
5651 for (size_t i = 0 ; i < h_field.Capacity (); ++i) {
5752 if (positions.find (i) != positions.cend ()) {
@@ -64,7 +59,6 @@ void CompareBitField(LBitField64 d_field, std::set<uint32_t> positions) {
6459
6560} // anonymous namespace
6661
67-
6862TEST (GPUFeatureInteractionConstraint, Init) {
6963 {
7064 int32_t constexpr kFeatures = 6 ;
@@ -75,12 +69,10 @@ TEST(GPUFeatureInteractionConstraint, Init) {
7569 for (LBitField64 const & d_node : s_nodes_constraints) {
7670 std::vector<LBitField64::value_type> h_node_storage (d_node.Bits ().size ());
7771 thrust::copy (thrust::device_ptr<LBitField64::value_type const >(d_node.Bits ().data ()),
78- thrust::device_ptr<LBitField64::value_type const >(
79- d_node. Bits (). data () + d_node.Bits ().size ()),
72+ thrust::device_ptr<LBitField64::value_type const >(d_node. Bits (). data () +
73+ d_node.Bits ().size ()),
8074 h_node_storage.data ());
81- LBitField64 h_node {
82- {h_node_storage.data (), h_node_storage.data () + h_node_storage.size ()}
83- };
75+ LBitField64 h_node{{h_node_storage.data (), h_node_storage.data () + h_node_storage.size ()}};
8476 // no feature is attached to node.
8577 for (size_t i = 0 ; i < h_node.Capacity (); ++i) {
8678 ASSERT_FALSE (h_node.Check (i));
@@ -94,8 +86,8 @@ TEST(GPUFeatureInteractionConstraint, Init) {
9486 tree::TrainParam param = GetParameter ();
9587 param.interaction_constraints = R"( [[0, 1, 3], [3, 5, 6]])" ;
9688 FConstraintWrapper constraints (param, kFeatures );
97- std::vector<bst_feature_t > h_sets {0 , 0 , 0 , 1 , 1 , 1 };
98- std::vector<size_t > h_sets_ptr {0 , 1 , 2 , 2 , 4 , 4 , 5 , 6 };
89+ std::vector<bst_feature_t > h_sets{0 , 0 , 0 , 1 , 1 , 1 };
90+ std::vector<size_t > h_sets_ptr{0 , 1 , 2 , 2 , 4 , 4 , 5 , 6 };
9991 auto d_sets = constraints.GetDSets ();
10092 ASSERT_EQ (h_sets.size (), d_sets.size ());
10193 auto d_sets_ptr = constraints.GetDSetsPtr ();
@@ -120,18 +112,19 @@ TEST(GPUFeatureInteractionConstraint, Init) {
120112 auto _128_end = d_sets_ptr[128 + 1 ];
121113 ASSERT_EQ (_128_end - _128_beg, 2 );
122114 ASSERT_EQ (d_sets[_128_beg], 1 );
123- ASSERT_EQ (d_sets[_128_end- 1 ], 2 );
115+ ASSERT_EQ (d_sets[_128_end - 1 ], 2 );
124116 }
125117}
126118
127119TEST (GPUFeatureInteractionConstraint, Split) {
120+ auto ctx = MakeCUDACtx (0 );
128121 tree::TrainParam param = GetParameter ();
129122 int32_t constexpr kFeatures = 6 ;
130123 FConstraintWrapper constraints (param, kFeatures );
131124
132125 {
133126 LBitField64 d_node[3 ];
134- constraints.Split (0 , /* feature_id=*/ 1 , 1 , 2 );
127+ constraints.Split (&ctx, 0 , /* feature_id=*/ 1 , 1 , 2 );
135128 for (size_t nid = 0 ; nid < 3 ; ++nid) {
136129 d_node[nid] = constraints.GetNodeConstraints ()[nid];
137130 ASSERT_EQ (d_node[nid].Bits ().size (), 1 );
@@ -141,7 +134,7 @@ TEST(GPUFeatureInteractionConstraint, Split) {
141134
142135 {
143136 LBitField64 d_node[5 ];
144- constraints.Split (1 , /* feature_id=*/ 0 , /* left_id=*/ 3 , /* right_id=*/ 4 );
137+ constraints.Split (&ctx, 1 , /* feature_id=*/ 0 , /* left_id=*/ 3 , /* right_id=*/ 4 );
145138 for (auto nid : {1 , 3 , 4 }) {
146139 d_node[nid] = constraints.GetNodeConstraints ()[nid];
147140 CompareBitField (d_node[nid], {0 , 1 , 2 });
@@ -165,24 +158,22 @@ TEST(GPUFeatureInteractionConstraint, QueryNode) {
165158 }
166159
167160 {
168- constraints.Split (/* node_id=*/ 0 , /* feature_id=*/ 1 , 1 , 2 );
161+ constraints.Split (&ctx, /* node_id=*/ 0 , /* feature_id=*/ 1 , 1 , 2 );
169162 auto span = constraints.QueryNode (&ctx, 0 );
170- std::vector<bst_feature_t > h_result (span.size ());
163+ std::vector<bst_feature_t > h_result (span.size ());
171164 thrust::copy (thrust::device_ptr<bst_feature_t >(span.data ()),
172- thrust::device_ptr<bst_feature_t >(span.data () + span.size ()),
173- h_result.begin ());
165+ thrust::device_ptr<bst_feature_t >(span.data () + span.size ()), h_result.begin ());
174166 ASSERT_EQ (h_result.size (), 2 );
175167 ASSERT_EQ (h_result[0 ], 1 );
176168 ASSERT_EQ (h_result[1 ], 2 );
177169 }
178170
179171 {
180- constraints.Split (1 , /* feature_id=*/ 0 , 3 , 4 );
172+ constraints.Split (&ctx, 1 , /* feature_id=*/ 0 , 3 , 4 );
181173 auto span = constraints.QueryNode (&ctx, 1 );
182- std::vector<bst_feature_t > h_result (span.size ());
174+ std::vector<bst_feature_t > h_result (span.size ());
183175 thrust::copy (thrust::device_ptr<bst_feature_t >(span.data ()),
184- thrust::device_ptr<bst_feature_t >(span.data () + span.size ()),
185- h_result.begin ());
176+ thrust::device_ptr<bst_feature_t >(span.data () + span.size ()), h_result.begin ());
186177 ASSERT_EQ (h_result.size (), 3 );
187178 ASSERT_EQ (h_result[0 ], 0 );
188179 ASSERT_EQ (h_result[1 ], 1 );
@@ -192,8 +183,7 @@ TEST(GPUFeatureInteractionConstraint, QueryNode) {
192183 span = constraints.QueryNode (&ctx, 3 );
193184 h_result.resize (span.size ());
194185 thrust::copy (thrust::device_ptr<bst_feature_t >(span.data ()),
195- thrust::device_ptr<bst_feature_t >(span.data () + span.size ()),
196- h_result.begin ());
186+ thrust::device_ptr<bst_feature_t >(span.data () + span.size ()), h_result.begin ());
197187 ASSERT_EQ (h_result.size (), 3 );
198188 ASSERT_EQ (h_result[0 ], 0 );
199189 ASSERT_EQ (h_result[1 ], 1 );
@@ -204,12 +194,11 @@ TEST(GPUFeatureInteractionConstraint, QueryNode) {
204194 tree::TrainParam large_param = GetParameter ();
205195 large_param.interaction_constraints = R"( [[1, 139], [244, 0], [139, 221]])" ;
206196 FConstraintWrapper large_features (large_param, 256 );
207- large_features.Split (0 , 139 , 1 , 2 );
197+ large_features.Split (&ctx, 0 , 139 , 1 , 2 );
208198 auto span = large_features.QueryNode (&ctx, 0 );
209- std::vector<bst_feature_t > h_result (span.size ());
199+ std::vector<bst_feature_t > h_result (span.size ());
210200 thrust::copy (thrust::device_ptr<bst_feature_t >(span.data ()),
211- thrust::device_ptr<bst_feature_t >(span.data () + span.size ()),
212- h_result.begin ());
201+ thrust::device_ptr<bst_feature_t >(span.data () + span.size ()), h_result.begin ());
213202 ASSERT_EQ (h_result.size (), 3 );
214203 ASSERT_EQ (h_result[0 ], 1 );
215204 ASSERT_EQ (h_result[1 ], 139 );
@@ -230,12 +219,13 @@ void CompareFeatureList(common::Span<bst_feature_t const> s_output,
230219} // anonymous namespace
231220
232221TEST (GPUFeatureInteractionConstraint, Query) {
222+ auto ctx = MakeCUDACtx (0 );
233223 {
234224 tree::TrainParam param = GetParameter ();
235225 bst_feature_t constexpr kFeatures = 6 ;
236226 FConstraintWrapper constraints (param, kFeatures );
237- std::vector<bst_feature_t > h_input_feature_list {0 , 1 , 2 , 3 , 4 , 5 };
238- dh::device_vector<bst_feature_t > d_input_feature_list (h_input_feature_list);
227+ std::vector<bst_feature_t > h_input_feature_list{0 , 1 , 2 , 3 , 4 , 5 };
228+ dh::device_vector<bst_feature_t > d_input_feature_list (h_input_feature_list);
239229 common::Span<bst_feature_t > s_input_feature_list = dh::ToSpan (d_input_feature_list);
240230
241231 auto s_output = constraints.Query (s_input_feature_list, 0 );
@@ -245,9 +235,9 @@ TEST(GPUFeatureInteractionConstraint, Query) {
245235 tree::TrainParam param = GetParameter ();
246236 bst_feature_t constexpr kFeatures = 6 ;
247237 FConstraintWrapper constraints (param, kFeatures );
248- constraints.Split (/* node_id=*/ 0 , /* feature_id=*/ 1 , /* left_id=*/ 1 , /* right_id=*/ 2 );
249- constraints.Split (/* node_id=*/ 1 , /* feature_id=*/ 0 , /* left_id=*/ 3 , /* right_id=*/ 4 );
250- constraints.Split (/* node_id=*/ 4 , /* feature_id=*/ 3 , /* left_id=*/ 5 , /* right_id=*/ 6 );
238+ constraints.Split (&ctx, /* node_id=*/ 0 , /* feature_id=*/ 1 , /* left_id=*/ 1 , /* right_id=*/ 2 );
239+ constraints.Split (&ctx, /* node_id=*/ 1 , /* feature_id=*/ 0 , /* left_id=*/ 3 , /* right_id=*/ 4 );
240+ constraints.Split (&ctx, /* node_id=*/ 4 , /* feature_id=*/ 3 , /* left_id=*/ 5 , /* right_id=*/ 6 );
251241 /*
252242 * (node id) [allowed features]
253243 *
@@ -263,8 +253,8 @@ TEST(GPUFeatureInteractionConstraint, Query) {
263253 *
264254 */
265255
266- std::vector<bst_feature_t > h_input_feature_list {0 , 1 , 2 , 3 , 4 , 5 };
267- dh::device_vector<bst_feature_t > d_input_feature_list (h_input_feature_list);
256+ std::vector<bst_feature_t > h_input_feature_list{0 , 1 , 2 , 3 , 4 , 5 };
257+ dh::device_vector<bst_feature_t > d_input_feature_list (h_input_feature_list);
268258 common::Span<bst_feature_t > s_input_feature_list = dh::ToSpan (d_input_feature_list);
269259
270260 auto s_output = constraints.Query (s_input_feature_list, 1 );
@@ -289,10 +279,10 @@ TEST(GPUFeatureInteractionConstraint, Query) {
289279 param.interaction_constraints = constraints_str;
290280
291281 FConstraintWrapper constraints (param, kFeatures );
292- constraints.Split (/* node_id=*/ 0 , /* feature_id=*/ 2 , /* left_id=*/ 1 , /* right_id=*/ 2 );
282+ constraints.Split (&ctx, /* node_id=*/ 0 , /* feature_id=*/ 2 , /* left_id=*/ 1 , /* right_id=*/ 2 );
293283
294- std::vector<bst_feature_t > h_input_feature_list {0 , 1 , 2 , 3 , 4 , 5 };
295- dh::device_vector<bst_feature_t > d_input_feature_list (h_input_feature_list);
284+ std::vector<bst_feature_t > h_input_feature_list{0 , 1 , 2 , 3 , 4 , 5 };
285+ dh::device_vector<bst_feature_t > d_input_feature_list (h_input_feature_list);
296286 common::Span<bst_feature_t > s_input_feature_list = dh::ToSpan (d_input_feature_list);
297287
298288 auto s_output = constraints.Query (s_input_feature_list, 1 );
@@ -306,10 +296,10 @@ TEST(GPUFeatureInteractionConstraint, Query) {
306296 std::string const constraints_str = R"constraint( [[0, 1]])constraint" ;
307297 param.interaction_constraints = constraints_str;
308298 FConstraintWrapper constraints (param, kFeatures );
309- std::vector<bst_feature_t > h_input_feature_list {0 , 1 , 2 , 3 , 4 , 5 };
310- dh::device_vector<bst_feature_t > d_input_feature_list (h_input_feature_list);
299+ std::vector<bst_feature_t > h_input_feature_list{0 , 1 , 2 , 3 , 4 , 5 };
300+ dh::device_vector<bst_feature_t > d_input_feature_list (h_input_feature_list);
311301 common::Span<bst_feature_t > s_input_feature_list = dh::ToSpan (d_input_feature_list);
312- constraints.Split (/* node_id=*/ 0 , /* feature_id=*/ 2 , /* left_id=*/ 1 , /* right_id=*/ 2 );
302+ constraints.Split (&ctx, /* node_id=*/ 0 , /* feature_id=*/ 2 , /* left_id=*/ 1 , /* right_id=*/ 2 );
313303 auto s_output = constraints.Query (s_input_feature_list, 1 );
314304 CompareFeatureList (s_output, {2 });
315305 s_output = constraints.Query (s_input_feature_list, 2 );
0 commit comments