2727#include " hist_util.h"
2828#include " quantile.cuh"
2929#include " quantile.h"
30- #include " transform_iterator.h" // MakeIndexTransformIter
3130#include " xgboost/span.h"
3231
3332namespace xgboost ::common {
@@ -663,19 +662,6 @@ void SketchContainer::AllReduce(Context const *ctx, bool is_column_split) {
663662 LOG (FATAL ) << " Distributed GPU quantile sketch reduction requires NCCL support." ;
664663}
665664
666- namespace {
667- struct InvalidCatOp {
668- Span<SketchEntry const > values;
669- Span<size_t const > ptrs;
670- Span<FeatureType const > ft;
671-
672- XGBOOST_DEVICE bool operator ()(size_t i) const {
673- auto fidx = dh::SegmentId (ptrs, i);
674- return IsCat (ft, fidx) && InvalidCat (values[i].value );
675- }
676- };
677- } // anonymous namespace
678-
679665HistogramCuts SketchContainer::MakeCuts (Context const *ctx, bool is_column_split) {
680666 curt::SetDevice (ctx->Ordinal ());
681667 HistogramCuts cuts{num_columns_};
@@ -685,133 +671,46 @@ HistogramCuts SketchContainer::MakeCuts(Context const *ctx, bool is_column_split
685671 this ->AllReduce (ctx, is_column_split);
686672
687673 timer_.Start (__func__);
688- // Prune to final number of bins.
689- this ->Prune (ctx, num_bins_ + 1 );
690-
691- // Set up inputs
692- auto d_in_columns_ptr = this ->columns_ptr_ .ConstDeviceSpan ();
693-
694- auto const in_cut_values = dh::ToSpan (this ->entries_ );
695-
696- // Set up output ptr
697- p_cuts->cut_ptrs_ .SetDevice (ctx->Device ());
698674 auto &h_out_columns_ptr = p_cuts->cut_ptrs_ .HostVector ();
699- h_out_columns_ptr.front () = 0 ;
700- auto const &h_feature_types = this ->feature_types_ .ConstHostSpan ();
675+ h_out_columns_ptr.assign (num_columns_ + 1 , 0 );
676+ auto &h_out_cut_values = p_cuts->cut_values_ .HostVector ();
677+ h_out_cut_values.clear ();
701678
702- auto d_ft = feature_types_.ConstDeviceSpan ();
679+ auto const &h_in_columns_ptr = this ->columns_ptr_ .ConstHostVector ();
680+ std::vector<SketchEntry> h_entries (this ->entries_ .size ());
681+ dh::CopyDeviceSpanToVector (&h_entries, dh::ToSpan (this ->entries_ ));
682+ auto const &h_feature_types = this ->feature_types_ .ConstHostSpan ();
703683
704- std::vector<SketchEntry> max_values;
705684 float max_cat{-1 .f };
706- if (has_categorical_) {
707- auto key_it = dh::MakeTransformIterator<bst_feature_t >(
708- thrust::make_counting_iterator (0ul ), [=] XGBOOST_DEVICE (size_t i) -> bst_feature_t {
709- return dh::SegmentId (d_in_columns_ptr, i);
710- });
711- auto invalid_op = InvalidCatOp{in_cut_values, d_in_columns_ptr, d_ft};
712- auto val_it = dh::MakeTransformIterator<SketchEntry>(
713- thrust::make_counting_iterator (0ul ), [=] XGBOOST_DEVICE (size_t i) {
714- auto fidx = dh::SegmentId (d_in_columns_ptr, i);
715- auto v = in_cut_values[i];
716- if (IsCat (d_ft, fidx)) {
717- if (invalid_op (i)) {
718- // use inf to indicate invalid value, this way we can keep it as in
719- // indicator in the reduce operation as it's always the greatest value.
720- v.value = std::numeric_limits<float >::infinity ();
721- }
722- }
723- return v;
724- });
725- CHECK_EQ (num_columns_, d_in_columns_ptr.size () - 1 );
726- max_values.resize (d_in_columns_ptr.size () - 1 );
727-
728- // In some cases (e.g. column-wise data split), we may have empty columns, so we need to keep
729- // track of the unique keys (feature indices) after the thrust::reduce_by_key` call.
730- dh::caching_device_vector<size_t > d_max_keys (d_in_columns_ptr.size () - 1 );
731- dh::caching_device_vector<SketchEntry> d_max_values (d_in_columns_ptr.size () - 1 );
732- auto new_end = thrust::reduce_by_key (
733- ctx->CUDACtx ()->CTP (), key_it, key_it + in_cut_values.size (), val_it, d_max_keys.begin (),
734- d_max_values.begin (), thrust::equal_to<bst_feature_t >{},
735- [] __device__ (auto l, auto r) { return l.value > r.value ? l : r; });
736- d_max_keys.erase (new_end.first , d_max_keys.end ());
737- d_max_values.erase (new_end.second , d_max_values.end ());
738-
739- // The device vector needs to be initialized explicitly since we may have some missing columns.
740- SketchEntry default_entry{};
741- dh::caching_device_vector<SketchEntry> d_max_results (d_in_columns_ptr.size () - 1 ,
742- default_entry);
743- thrust::scatter (ctx->CUDACtx ()->CTP (), d_max_values.begin (), d_max_values.end (),
744- d_max_keys.begin (), d_max_results.begin ());
745- dh::CopyDeviceSpanToVector (&max_values, dh::ToSpan (d_max_results));
746- auto max_it = MakeIndexTransformIter ([&](auto i) {
747- if (IsCat (h_feature_types, i)) {
748- return max_values[i].value ;
749- }
750- return -1 .f ;
751- });
752- max_cat = *std::max_element (max_it, max_it + max_values.size ());
753- if (std::isinf (max_cat)) {
754- InvalidCategory ();
755- }
756- }
757-
758- // Set up output cuts
685+ WQSummaryContainer summary;
759686 for (bst_feature_t i = 0 ; i < num_columns_; ++i) {
760- size_t column_size = std::max (static_cast <size_t >(1ul ), this ->Column (i).size ());
687+ auto begin = h_in_columns_ptr[i];
688+ auto end = h_in_columns_ptr[i + 1 ];
689+ auto column = Span<SketchEntry const >{h_entries.data () + begin, end - begin};
690+
761691 if (IsCat (h_feature_types, i)) {
762- // column_size is the number of unique values in that feature.
763- CheckMaxCat (max_values[i].value , column_size);
764- h_out_columns_ptr[i + 1 ] = max_values[i].value + 1 ; // includes both max_cat and 0.
692+ auto column_size = std::max (static_cast <std::size_t >(1 ), column.size ());
693+ auto feature_max = column.empty () ? 0 .0f : column.back ().value ;
694+ if (std::any_of (column.cbegin (), column.cend (),
695+ [](auto const &entry) { return InvalidCat (entry.value ); })) {
696+ InvalidCategory ();
697+ }
698+ CheckMaxCat (feature_max, column_size);
699+ max_cat = std::max (max_cat, feature_max);
700+ for (std::size_t cat = 0 ; cat <= static_cast <std::size_t >(feature_max); ++cat) {
701+ h_out_cut_values.push_back (cat);
702+ }
765703 } else {
766- h_out_columns_ptr[i + 1 ] =
767- std::min (static_cast <size_t >(column_size), static_cast <size_t >(num_bins_));
704+ summary.Reserve (column.size ());
705+ std::copy (column.cbegin (), column.cend (), summary.space .begin ());
706+ summary.SetSize (column.size ());
707+ auto queried = summary.QueryCutValues (static_cast <std::size_t >(num_bins_));
708+ h_out_cut_values.insert (h_out_cut_values.end (), queried.cbegin (), queried.cend ());
768709 }
710+ h_out_columns_ptr[i + 1 ] = h_out_cut_values.size ();
769711 }
770- std::partial_sum (h_out_columns_ptr.begin (), h_out_columns_ptr.end (), h_out_columns_ptr.begin ());
771- auto d_out_columns_ptr = p_cuts->cut_ptrs_ .ConstDeviceSpan ();
772-
773- size_t total_bins = h_out_columns_ptr.back ();
774- p_cuts->cut_values_ .SetDevice (ctx->Device ());
775- p_cuts->cut_values_ .Resize (total_bins);
776- auto out_cut_values = p_cuts->cut_values_ .DeviceSpan ();
777-
778- dh::LaunchN (total_bins, [=] __device__ (size_t idx) {
779- auto column_id = dh::SegmentId (d_out_columns_ptr, idx);
780- auto in_column = in_cut_values.subspan (
781- d_in_columns_ptr[column_id], d_in_columns_ptr[column_id + 1 ] - d_in_columns_ptr[column_id]);
782- auto out_column =
783- out_cut_values.subspan (d_out_columns_ptr[column_id],
784- d_out_columns_ptr[column_id + 1 ] - d_out_columns_ptr[column_id]);
785- idx -= d_out_columns_ptr[column_id];
786- if (in_column.size () == 0 ) {
787- // If the column is empty, we push a dummy value. It won't affect training as the
788- // column is empty, trees cannot split on it. This is just to be consistent with
789- // rest of the library.
790- if (idx == 0 ) {
791- out_column[0 ] = kRtEps ;
792- assert (out_column.size () == 1 );
793- }
794- return ;
795- }
796-
797- if (IsCat (d_ft, column_id)) {
798- out_column[idx] = idx;
799- return ;
800- }
801-
802- // Last thread is responsible for setting a value that's greater than other cuts.
803- if (idx == out_column.size () - 1 ) {
804- const bst_float cpt = in_column.back ().value ;
805- // this must be bigger than last value in a scale
806- const bst_float last = cpt + (fabs (cpt) + 1e-5 );
807- out_column[idx] = last;
808- return ;
809- }
810- assert (idx + 1 < in_column.size ());
811- out_column[idx] = in_column[idx + 1 ].value ;
812- });
813-
814712 p_cuts->SetCategorical (this ->has_categorical_ , max_cat);
713+ p_cuts->SetDevice (ctx->Device ());
815714 timer_.Stop (__func__);
816715 return cuts;
817716}
0 commit comments