@@ -50,6 +50,20 @@ auto MakeFullRowSplitDMatrix(std::size_t rows_per_worker, std::size_t cols, std:
5050 }
5151 return GetDMatrixFromData (full_data, rows_per_worker * world, cols);
5252}
53+
54+ auto MakeHostSummary (std::vector<std::pair<float , float >> const & items)
55+ -> common::WQSummaryContainer {
56+ common::WQSummaryContainer summary;
57+ summary.Reserve (items.size ());
58+ summary.SetFromSorted (items);
59+ return summary;
60+ }
61+
62+ auto CopySummaryEntries (common::WQSummaryContainer const & summary)
63+ -> std::vector<common::SketchEntry> {
64+ auto entries = summary.Entries ();
65+ return {entries.cbegin (), entries.cend ()};
66+ }
5367} // namespace
5468
5569namespace common {
@@ -251,14 +265,19 @@ TEST(GPUQuantile, MergeBasic) {
251265 auto columns_ptr = sketch_0.ColumnsPtr ();
252266 std::vector<bst_idx_t > h_columns_ptr (columns_ptr.size ());
253267 dh::CopyDeviceSpanToVector (&h_columns_ptr, columns_ptr);
254- ASSERT_EQ (h_columns_ptr.back (), sketch_1.Data ().size () + size_before_merge);
268+ ASSERT_LE (h_columns_ptr.back (), sketch_1.Data ().size () + size_before_merge);
255269
256270 std::vector<SketchEntry> h_data (sketch_0.Data ().size ());
257271 dh::CopyDeviceSpanToVector (&h_data, sketch_0.Data ());
272+ ASSERT_EQ (static_cast <std::size_t >(h_columns_ptr.back ()), h_data.size ());
258273 for (size_t i = 1 ; i < h_columns_ptr.size (); ++i) {
259274 auto begin = h_columns_ptr[i - 1 ];
260275 auto column = Span<SketchEntry>{h_data}.subspan (begin, h_columns_ptr[i] - begin);
261276 ASSERT_TRUE (std::is_sorted (column.begin (), column.end (), IsSorted{}));
277+ ASSERT_TRUE (std::adjacent_find (column.begin (), column.end (),
278+ [](SketchEntry const & l, SketchEntry const & r) {
279+ return l.value == r.value ;
280+ }) == column.end ());
262281 }
263282 });
264283}
@@ -309,14 +328,19 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
309328 auto columns_ptr = sketch_0.ColumnsPtr ();
310329 std::vector<bst_idx_t > h_columns_ptr (columns_ptr.size ());
311330 dh::CopyDeviceSpanToVector (&h_columns_ptr, columns_ptr);
312- ASSERT_EQ (h_columns_ptr.back (), sketch_1.Data ().size () + size_before_merge);
331+ ASSERT_LE (h_columns_ptr.back (), sketch_1.Data ().size () + size_before_merge);
313332
314333 std::vector<SketchEntry> h_data (sketch_0.Data ().size ());
315334 dh::CopyDeviceSpanToVector (&h_data, sketch_0.Data ());
335+ ASSERT_EQ (static_cast <std::size_t >(h_columns_ptr.back ()), h_data.size ());
316336 for (size_t i = 1 ; i < h_columns_ptr.size (); ++i) {
317337 auto begin = h_columns_ptr[i - 1 ];
318338 auto column = Span<SketchEntry>{h_data}.subspan (begin, h_columns_ptr[i] - begin);
319339 ASSERT_TRUE (std::is_sorted (column.begin (), column.end (), IsSorted{}));
340+ ASSERT_TRUE (std::adjacent_find (column.begin (), column.end (),
341+ [](SketchEntry const & l, SketchEntry const & r) {
342+ return l.value == r.value ;
343+ }) == column.end ());
320344 }
321345}
322346
@@ -370,6 +394,84 @@ TEST(GPUQuantile, MergeCategorical) {
370394 }) == cat_column.end ());
371395}
372396
397+ TEST (GPUQuantile, MergeSameValue) {
398+ auto ctx = MakeCUDACtx (0 );
399+ constexpr bst_feature_t kCols = 1 ;
400+ bst_bin_t n_bins = 16 ;
401+
402+ HostDeviceVector<FeatureType> ft;
403+ SketchContainer sketch_0 (ft, n_bins, kCols , ctx.Device ());
404+ SketchContainer sketch_1 (ft, n_bins, kCols , ctx.Device ());
405+
406+ std::vector<Entry> entries_0{{0 , 0 .5f }};
407+ std::vector<Entry> entries_1{{0 , 0 .5f }};
408+ dh::device_vector<Entry> d_entries_0{entries_0};
409+ dh::device_vector<Entry> d_entries_1{entries_1};
410+ dh::device_vector<size_t > columns_ptr{0 , 1 };
411+ dh::device_vector<size_t > cuts_ptr{0 , 1 };
412+
413+ sketch_0.Push (&ctx, dh::ToSpan (d_entries_0), dh::ToSpan (columns_ptr), dh::ToSpan (cuts_ptr), 1 , 1 ,
414+ {});
415+ sketch_1.Push (&ctx, dh::ToSpan (d_entries_1), dh::ToSpan (columns_ptr), dh::ToSpan (cuts_ptr), 1 , 1 ,
416+ {});
417+
418+ sketch_0.Merge (&ctx, sketch_1.ColumnsPtr (), sketch_1.Data ());
419+
420+ std::vector<bst_idx_t > h_columns_ptr (sketch_0.ColumnsPtr ().size ());
421+ dh::CopyDeviceSpanToVector (&h_columns_ptr, sketch_0.ColumnsPtr ());
422+ std::vector<SketchEntry> h_data (sketch_0.Data ().size ());
423+ dh::CopyDeviceSpanToVector (&h_data, sketch_0.Data ());
424+
425+ ASSERT_EQ (h_columns_ptr.back (), 1 );
426+ ASSERT_EQ (h_data.size (), 1 );
427+ EXPECT_FLOAT_EQ (h_data.front ().value , 0 .5f );
428+ EXPECT_FLOAT_EQ (h_data.front ().rmin , 0 .0f );
429+ EXPECT_FLOAT_EQ (h_data.front ().wmin , 2 .0f );
430+ EXPECT_FLOAT_EQ (h_data.front ().rmax , 2 .0f );
431+ }
432+
433+ TEST (GPUQuantile, MergeMatchesCpuCombine) {
434+ auto ctx = MakeCUDACtx (0 );
435+ constexpr bst_feature_t kCols = 1 ;
436+ bst_bin_t n_bins = 16 ;
437+
438+ auto lhs = MakeHostSummary ({{0 .1f , 1 .0f }, {0 .3f , 2 .0f }, {0 .5f , 1 .0f }});
439+ auto rhs = MakeHostSummary ({{0 .3f , 1 .5f }, {0 .4f , 1 .0f }, {0 .5f , 0 .5f }});
440+
441+ common::WQSummaryContainer expected;
442+ expected.Reserve (lhs.Size () + rhs.Size ());
443+ expected.CopyFrom (lhs);
444+ expected.SetCombine (rhs);
445+
446+ auto lhs_entries = CopySummaryEntries (lhs);
447+ auto rhs_entries = CopySummaryEntries (rhs);
448+
449+ dh::device_vector<SketchEntry> d_lhs{lhs_entries};
450+ dh::device_vector<SketchEntry> d_rhs{rhs_entries};
451+ dh::device_vector<size_t > lhs_ptr{0 , lhs.Size ()};
452+ dh::device_vector<size_t > rhs_ptr{0 , rhs.Size ()};
453+
454+ HostDeviceVector<FeatureType> ft;
455+ SketchContainer sketch (ft, n_bins, kCols , ctx.Device ());
456+ sketch.Merge (&ctx, dh::ToSpan (lhs_ptr), dh::ToSpan (d_lhs));
457+ sketch.Merge (&ctx, dh::ToSpan (rhs_ptr), dh::ToSpan (d_rhs));
458+
459+ std::vector<bst_idx_t > h_columns_ptr (sketch.ColumnsPtr ().size ());
460+ dh::CopyDeviceSpanToVector (&h_columns_ptr, sketch.ColumnsPtr ());
461+ auto h_data = std::vector<SketchEntry>(sketch.Data ().size ());
462+ dh::CopyDeviceSpanToVector (&h_data, sketch.Data ());
463+
464+ ASSERT_EQ (h_columns_ptr.back (), expected.Size ());
465+ auto expected_entries = expected.Entries ();
466+ ASSERT_EQ (h_data.size (), expected_entries.size ());
467+ for (std::size_t i = 0 ; i < h_data.size (); ++i) {
468+ EXPECT_FLOAT_EQ (h_data[i].value , expected_entries[i].value );
469+ EXPECT_FLOAT_EQ (h_data[i].rmin , expected_entries[i].rmin );
470+ EXPECT_FLOAT_EQ (h_data[i].rmax , expected_entries[i].rmax );
471+ EXPECT_FLOAT_EQ (h_data[i].wmin , expected_entries[i].wmin );
472+ }
473+ }
474+
373475TEST (GPUQuantile, MultiMerge) {
374476 constexpr size_t kRows = 20 , kCols = 1 ;
375477 int32_t world = 2 ;
0 commit comments