Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 10 additions & 38 deletions src/common/hist_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ void SortByWeight(Context const* ctx, dh::device_vector<float>* weights,
}

void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
Span<bst_idx_t> d_cuts_ptr,
dh::device_vector<Entry>* p_sorted_entries,
dh::device_vector<float>* p_sorted_weights,
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
Expand Down Expand Up @@ -100,27 +99,9 @@ void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
}
sorted_entries.resize(n_uniques);

// Renew the column scan and cut scan based on categorical data.
dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(info.num_col_ + 1);
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
dh::LaunchN(new_column_scan.size(), ctx->CUDACtx()->Stream(),
[=, d_new_cuts_size = dh::ToSpan(new_cuts_size),
d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan),
d_new_columns_ptr = dh::ToSpan(new_column_scan)] __device__(size_t idx) {
d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
if (idx == d_new_columns_ptr.size() - 1) {
return;
}
if (IsCat(d_feature_types, idx)) {
// Cut size is the same as number of categories in input.
d_new_cuts_size[idx] = d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
} else {
d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx];
}
});
// Turn size into ptr.
thrust::exclusive_scan(ctx->CUDACtx()->CTP(), new_cuts_size.cbegin(), new_cuts_size.cend(),
d_cuts_ptr.data());
// Renew the column scan based on categorical data. Numerical columns preserve their original
// span, while categorical columns shrink to their unique category count.
column_sizes_scan = std::move(new_column_scan);
}
} // namespace detail

Expand All @@ -141,7 +122,7 @@ namespace {
void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo const& info,
std::size_t begin, std::size_t end,
SketchContainer* sketch_container, // <- output sketch
int num_cuts_per_feature, common::Span<float const> sample_weight) {
common::Span<float const> sample_weight) {
dh::device_vector<Entry> sorted_entries;
if (page.data.DeviceCanRead()) {
// direct copy if data is already on device
Expand Down Expand Up @@ -175,35 +156,28 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
detail::EntryCompareOp());
}

HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
dh::caching_device_vector<size_t> column_sizes_scan;
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
sorted_entries.data().get(), [] __device__(Entry const& e) -> data::COOTuple {
return {0, e.index, e.fvalue};  // row_idx is not needed for scanning column size.
});
detail::GetColumnSizesScan(ctx->CUDACtx(), ctx->Device(), info.num_col_, num_cuts_per_feature,
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
detail::GetColumnSizesScan(ctx->CUDACtx(), ctx->Device(), info.num_col_,
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid,
&column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
if (sketch_container->HasCategorical()) {
auto p_weight = entry_weight.empty() ? nullptr : &entry_weight;
detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, p_weight,
&column_sizes_scan);
detail::RemoveDuplicatedCategories(ctx, info, &sorted_entries, p_weight, &column_sizes_scan);
}

auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());

// Add cuts into sketches
auto n_rows_in_batch = RowsInEntrySpan(page, begin, end);
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back(), n_rows_in_batch, dh::ToSpan(entry_weight));
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan),
n_rows_in_batch, dh::ToSpan(entry_weight));

sorted_entries.clear();
sorted_entries.shrink_to_fit();
CHECK_EQ(sorted_entries.capacity(), 0);
CHECK_NE(cuts_ptr.Size(), 0);
}

// Unify group weight, Hessian, and sample weight into sample weight.
Expand Down Expand Up @@ -275,7 +249,6 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b
HostDeviceVector<float> weight;
weight.SetDevice(ctx->Device());

std::size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(max_bin, info.num_row_);
auto sketch_batch_num_elements = detail::kSketchBatchNumElements;

CUDAContext const* cuctx = ctx->CUDACtx();
Expand All @@ -290,8 +263,7 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b
for (auto begin = 0ull; begin < page_nnz; begin += sketch_batch_num_elements) {
std::size_t end =
std::min(page_nnz, static_cast<std::size_t>(begin + sketch_batch_num_elements));
ProcessWeightedBatch(ctx, page, info, begin, end, &sketch_container, num_cuts_per_feature,
d_weight);
ProcessWeightedBatch(ctx, page, info, begin, end, &sketch_container, d_weight);
Comment thread
RAMitchell marked this conversation as resolved.
}
}

Expand Down
67 changes: 18 additions & 49 deletions src/common/hist_util.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -139,23 +139,12 @@ void LaunchGetColumnSizeKernel(CUDAContext const* cuctx, DeviceOrd device,

template <typename BatchIt>
void GetColumnSizesScan(CUDAContext const* cuctx, DeviceOrd device, size_t num_columns,
std::size_t num_cuts_per_feature, IterSpan<BatchIt> batch_iter,
data::IsValidFunctor is_valid,
HostDeviceVector<SketchContainer::OffsetT>* cuts_ptr,
IterSpan<BatchIt> batch_iter, data::IsValidFunctor is_valid,
dh::caching_device_vector<size_t>* column_sizes_scan) {
column_sizes_scan->resize(num_columns + 1);
cuts_ptr->SetDevice(device);
cuts_ptr->Resize(num_columns + 1, 0);

auto d_column_sizes_scan = dh::ToSpan(*column_sizes_scan);
LaunchGetColumnSizeKernel(cuctx, device, batch_iter, is_valid, d_column_sizes_scan);
// Calculate cuts CSC pointer
auto cut_ptr_it = dh::MakeTransformIterator<size_t>(
column_sizes_scan->begin(), [=] __device__(size_t column_size) {
return thrust::min(num_cuts_per_feature, column_size);
});
thrust::exclusive_scan(cuctx->CTP(), cut_ptr_it,
cut_ptr_it + column_sizes_scan->size(), cuts_ptr->DevicePointer());
thrust::exclusive_scan(cuctx->CTP(), column_sizes_scan->begin(), column_sizes_scan->end(),
column_sizes_scan->begin());
}
Expand All @@ -170,8 +159,7 @@ size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows);
template <typename AdapterBatch, typename BatchIter>
void MakeEntriesFromAdapter(CUDAContext const* cuctx, AdapterBatch const& batch,
BatchIter batch_iter, Range1d range, float missing, size_t columns,
size_t cuts_per_feature, DeviceOrd device,
HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
DeviceOrd device,
dh::caching_device_vector<size_t>* column_sizes_scan,
dh::device_vector<Entry>* sorted_entries) {
auto entry_iter = dh::MakeTransformIterator<Entry>(
Expand All @@ -182,8 +170,7 @@ void MakeEntriesFromAdapter(CUDAContext const* cuctx, AdapterBatch const& batch,
auto span = IterSpan{batch_iter + range.begin(), n};
data::IsValidFunctor is_valid(missing);
// Work out how many valid entries we have in each column
GetColumnSizesScan(cuctx, device, columns, cuts_per_feature, span, is_valid, cut_sizes_scan,
column_sizes_scan);
GetColumnSizesScan(cuctx, device, columns, span, is_valid, column_sizes_scan);
size_t num_valid = column_sizes_scan->back();
// Copy current subset of valid elements into temporary storage and sort
sorted_entries->resize(num_valid);
Expand All @@ -195,7 +182,6 @@ void SortByWeight(Context const* ctx, dh::device_vector<float>* weights,
dh::device_vector<Entry>* sorted_entries);

void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
Span<bst_idx_t> d_cuts_ptr,
dh::device_vector<Entry>* p_sorted_entries,
dh::device_vector<float>* p_sorted_weights,
dh::caching_device_vector<size_t>* p_column_sizes_scan);
Expand Down Expand Up @@ -231,43 +217,35 @@ inline HistogramCuts DeviceSketch(Context const* ctx, DMatrix* p_fmat, bst_bin_t
template <typename AdapterBatch>
void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInfo const& info,
size_t n_features, size_t begin, size_t end, float missing,
SketchContainer* sketch_container, int num_cuts,
bst_idx_t approx_n_samples) {
SketchContainer* sketch_container, bst_idx_t approx_n_samples) {
// Copy current subset of valid elements into temporary storage and sort
dh::device_vector<Entry> sorted_entries;
dh::caching_device_vector<size_t> column_sizes_scan;
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
cuts_ptr.SetDevice(ctx->Device());
CUDAContext const* cuctx = ctx->CUDACtx();
detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, n_features,
num_cuts, ctx->Device(), &cuts_ptr, &column_sizes_scan,
&sorted_entries);
ctx->Device(), &column_sizes_scan, &sorted_entries);
thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp());

if (sketch_container->HasCategorical()) {
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, nullptr,
&column_sizes_scan);
detail::RemoveDuplicatedCategories(ctx, info, &sorted_entries, nullptr, &column_sizes_scan);
}

auto d_cuts_ptr = cuts_ptr.DeviceSpan();
auto const& h_cuts_ptr = cuts_ptr.HostVector();
// Extract the cuts from all columns concurrently
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back(), approx_n_samples);
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan),
approx_n_samples);

sorted_entries.clear();
sorted_entries.shrink_to_fit();
}

template <typename Batch>
void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo const& info,
int num_cuts_per_feature, bool is_ranking, float missing,
size_t columns, size_t begin, size_t end,
SketchContainer* sketch_container, bst_idx_t approx_n_samples) {
bool is_ranking, float missing, size_t columns, size_t begin,
size_t end, SketchContainer* sketch_container,
bst_idx_t approx_n_samples) {
curt::SetDevice(ctx->Ordinal());
info.weights_.SetDevice(ctx->Device());
auto weights = info.weights_.ConstDeviceSpan();
Expand All @@ -278,10 +256,8 @@ void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo cons
auto cuctx = ctx->CUDACtx();
dh::device_vector<Entry> sorted_entries;
dh::caching_device_vector<size_t> column_sizes_scan;
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns,
num_cuts_per_feature, ctx->Device(), &cuts_ptr, &column_sizes_scan,
&sorted_entries);
ctx->Device(), &column_sizes_scan, &sorted_entries);
data::IsValidFunctor is_valid(missing);

dh::device_vector<float> temp_weights(sorted_entries.size());
Expand Down Expand Up @@ -323,17 +299,13 @@ void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo cons
detail::SortByWeight(ctx, &temp_weights, &sorted_entries);

if (sketch_container->HasCategorical()) {
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, &temp_weights,
detail::RemoveDuplicatedCategories(ctx, info, &sorted_entries, &temp_weights,
&column_sizes_scan);
}

auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
auto d_cuts_ptr = cuts_ptr.DeviceSpan();

// Extract cuts
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back(), approx_n_samples, dh::ToSpan(temp_weights));
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan),
approx_n_samples, dh::ToSpan(temp_weights));
sorted_entries.clear();
sorted_entries.shrink_to_fit();
}
Expand All @@ -359,7 +331,6 @@ void AdapterDeviceSketch(Context const* ctx, Batch batch, bst_bin_t num_bins, Me
bst_idx_t begin = 0;

while (begin < kRemaining) {
auto num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
auto remaining = kRemaining - begin;
auto sketch_batch_num_elements = std::min(detail::kSketchBatchNumElements, remaining);
// Re-estimate the needed number of cuts based on the size of the sub-batch.
Expand All @@ -370,17 +341,15 @@ void AdapterDeviceSketch(Context const* ctx, Batch batch, bst_bin_t num_bins, Me
// dense assumption.
auto approx_n_samples =
std::max(common::DivRoundUp(sketch_batch_num_elements, num_cols), bst_idx_t{1});
num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, approx_n_samples);
bst_idx_t end =
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));

if (weighted) {
ProcessWeightedSlidingWindow(ctx, batch, info, num_cuts_per_feature,
HostSketchContainer::UseGroup(info), missing, num_cols, begin,
end, sketch_container, approx_n_samples);
ProcessWeightedSlidingWindow(ctx, batch, info, HostSketchContainer::UseGroup(info), missing,
num_cols, begin, end, sketch_container, approx_n_samples);
} else {
ProcessSlidingWindow(ctx, batch, info, num_cols, begin, end, missing, sketch_container,
num_cuts_per_feature, approx_n_samples);
approx_n_samples);
}
Comment thread
RAMitchell marked this conversation as resolved.
begin += sketch_batch_num_elements;
}
Expand Down
Loading
Loading