Skip to content

Commit 6383b91

Browse files
authored
[cub] Replace cub parameter framework with cuda::argument (#9074)
* Replace cub params framework with argument framework
* Rebase fixes
* Review feedback
* Namespace fix
1 parent: c62fff1 | commit: 6383b91

8 files changed

Lines changed: 197 additions & 347 deletions

File tree

cub/benchmarks/bench/segmented_topk/fixed/keys.cu

Lines changed: 7 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
#include <cub/detail/choose_offset.cuh>
55
#include <cub/device/dispatch/dispatch_batched_topk.cuh>
66

7+
#include <cuda/__argument_>
78
#include <cuda/iterator>
89

910
#include <nvbench_helper.cuh>
@@ -44,33 +45,13 @@ void fixed_seg_size_topk_keys(
4445
nvbench::state& state,
4546
nvbench::type_list<KeyT, nvbench::enum_type<MaxSegmentSize>, nvbench::enum_type<MaxNumSelected>>)
4647
{
47-
// Range of guaranteed total number of items
48-
constexpr auto min_num_total_items = 1;
49-
constexpr auto max_num_total_items = ::cuda::std::numeric_limits<::cuda::std::int32_t>::max();
50-
51-
// Static segment size
52-
using seg_size_t = cub::detail::batched_topk::segment_size_static<MaxSegmentSize>;
53-
54-
// Static k (number of selected output elements per segment)
55-
using k_value_t = cub::detail::batched_topk::k_static<MaxNumSelected>;
56-
57-
// Static selection direction (max)
58-
using select_direction_value_t = cub::detail::batched_topk::select_direction_static<cub::detail::topk::select::max>;
59-
60-
// Number of segments is a host-accessible value
61-
using num_segments_uniform_t = cub::detail::batched_topk::num_segments_uniform<>;
62-
63-
// Total number of items guarantee type
64-
using total_num_items_guarantee_t =
65-
cub::detail::batched_topk::total_num_items_guarantee<min_num_total_items, max_num_total_items>;
66-
6748
// Retrieve axis parameters
6849
const auto max_elements = static_cast<size_t>(state.get_int64("Elements{io}"));
6950
const auto segment_size = static_cast<::cuda::std::ptrdiff_t>(MaxSegmentSize);
7051
const auto selected_elements = static_cast<::cuda::std::ptrdiff_t>(MaxNumSelected);
7152
const auto num_segments = ::cuda::std::max<std::size_t>(1, (max_elements / segment_size));
7253
const auto elements = num_segments * segment_size;
73-
const auto total_num_items = total_num_items_guarantee_t{static_cast<::cuda::std::int64_t>(elements)};
54+
const auto total_num_items = ::cuda::__argument::__immediate{static_cast<::cuda::std::int64_t>(elements)};
7455
const bit_entropy entropy = str_to_entropy(state.get_string("Entropy"));
7556

7657
// Skip workloads where k exceeds the segment size
@@ -87,9 +68,9 @@ void fixed_seg_size_topk_keys(
8768
auto d_keys_in = cuda::make_strided_iterator(cuda::make_counting_iterator(d_keys_in_ptr), segment_size);
8869
auto d_keys_out = cuda::make_strided_iterator(cuda::make_counting_iterator(d_keys_out_ptr), selected_elements);
8970

90-
auto segment_sizes = seg_size_t{};
91-
auto k = k_value_t{};
92-
auto select_directions = select_direction_value_t{};
71+
auto segment_sizes = ::cuda::__argument::__constant<MaxSegmentSize>{};
72+
auto k = ::cuda::__argument::__constant<MaxNumSelected>{};
73+
auto select_direction = ::cuda::__argument::__constant<cub::detail::topk::select::max>{};
9374

9475
state.add_element_count(elements, "NumElements");
9576
state.add_element_count(segment_size, "SegmentSize");
@@ -117,8 +98,8 @@ void fixed_seg_size_topk_keys(
11798
static_cast<cub::NullType**>(nullptr),
11899
segment_sizes,
119100
k,
120-
select_directions,
121-
num_segments_uniform_t{static_cast<::cuda::std::int64_t>(num_segments)},
101+
select_direction,
102+
::cuda::__argument::__immediate{static_cast<::cuda::std::int64_t>(num_segments)},
122103
total_num_items,
123104
env);
124105
});

cub/benchmarks/bench/segmented_topk/variable/keys.cu

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include <thrust/reduce.h>
99
#include <thrust/tabulate.h>
1010

11+
#include <cuda/__argument_>
1112
#include <cuda/iterator>
1213
#include <cuda/random>
1314
#include <cuda/std/algorithm>
@@ -171,20 +172,17 @@ void variable_seg_size_topk_keys(nvbench::state& state,
171172
static_cast<cuda::std::int64_t>(MaxSegmentSize));
172173
const auto input_elements = thrust::reduce(d_segment_sizes.begin(), d_segment_sizes.end());
173174
const auto output_elements = static_cast<std::size_t>(num_segments) * K;
174-
const auto total_num_items =
175-
cub::detail::batched_topk::total_num_items_guarantee<1, cuda::std::numeric_limits<cuda::std::int64_t>::max()>{
176-
static_cast<cuda::std::int64_t>(input_elements)};
175+
const auto total_num_items = ::cuda::__argument::__immediate{static_cast<cuda::std::int64_t>(input_elements)};
177176

178177
auto in_keys_buffer = gen_data<MaxSegmentSize, K>(
179178
num_segments, string_to_pattern(state.get_string("Pattern")), thrust::raw_pointer_cast(d_segment_sizes.data()));
180179
auto out_keys_buffer = thrust::device_vector<KeyT>(output_elements, thrust::no_init);
181180

182-
cub::detail::batched_topk::segment_size_per_segment<const cuda::std::int64_t*, 1, MaxSegmentSize> segment_sizes_param{
183-
thrust::raw_pointer_cast(d_segment_sizes.data())};
184-
cub::detail::batched_topk::k_static<K> k_param{};
185-
cub::detail::batched_topk::select_direction_static<cub::detail::topk::select::max> select_directions{};
186-
cub::detail::batched_topk::num_segments_uniform<> num_segments_uniform_param{
187-
static_cast<cuda::std::int64_t>(num_segments)};
181+
auto segment_sizes_param = ::cuda::__argument::__immediate_sequence{
182+
thrust::raw_pointer_cast(d_segment_sizes.data()), ::cuda::__argument::__bounds<1, MaxSegmentSize>()};
183+
auto k_param = ::cuda::__argument::__constant<K>{};
184+
auto select_direction = ::cuda::__argument::__constant<cub::detail::topk::select::max>{};
185+
auto num_segments_param = ::cuda::__argument::__immediate{static_cast<cuda::std::int64_t>(num_segments)};
188186

189187
auto d_keys_in = cuda::make_strided_iterator(
190188
cuda::make_counting_iterator(thrust::raw_pointer_cast(in_keys_buffer.data())),
@@ -210,8 +208,8 @@ void variable_seg_size_topk_keys(nvbench::state& state,
210208
static_cast<cub::NullType**>(nullptr),
211209
segment_sizes_param,
212210
k_param,
213-
select_directions,
214-
num_segments_uniform_param,
211+
select_direction,
212+
num_segments_param,
215213
total_num_items,
216214
env);
217215
});

cub/cub/agent/agent_batched_topk.cuh

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
#include <cub/device/dispatch/tuning/tuning_batched_topk.cuh>
2424
#include <cub/util_type.cuh>
2525

26+
#include <cuda/__argument_>
2627
#include <cuda/__cmath/ceil_div.h>
2728

2829
CUB_NAMESPACE_BEGIN
@@ -72,8 +73,8 @@ struct agent_batched_topk_worker_per_segment
7273
using key_t = it_value_t<key_it_t>;
7374
using value_t = it_value_t<value_it_t>;
7475

75-
using segment_size_val_t = typename SegmentSizeParameterT::value_type;
76-
using num_segments_val_t = typename NumSegmentsParameterT::value_type;
76+
using segment_size_val_t = typename ::cuda::__argument::__traits<SegmentSizeParameterT>::element_type;
77+
using num_segments_val_t = typename ::cuda::__argument::__traits<NumSegmentsParameterT>::element_type;
7778
using counters_t = batched_topk_counters<num_segments_val_t>;
7879

7980
static constexpr auto policy = PolicyGetter{}();
@@ -94,7 +95,7 @@ struct agent_batched_topk_worker_per_segment
9495
multi_worker_per_segment_policy.threads_per_block * multi_worker_per_segment_policy.items_per_thread;
9596

9697
// Check if there could be large segments present
97-
static constexpr bool only_small_segments = params::static_max_value_v<SegmentSizeParameterT> <= tile_size;
98+
static constexpr bool only_small_segments = ::cuda::__argument::__traits<SegmentSizeParameterT>::max <= tile_size;
9899

99100
// Check if we are dealing with keys-only or key-value pairs
100101
static constexpr bool is_keys_only = ::cuda::std::is_same_v<value_t, cub::NullType>;
@@ -190,16 +191,16 @@ struct agent_batched_topk_worker_per_segment
190191

191192
// Boundary check
192193
// TODO (elstehle): consider skipping boundary check if we can safely assume the right grid dimensions
193-
if (segment_id >= num_segments.get_param(0))
194+
if (segment_id >= params::get_param(num_segments, 0))
194195
{
195196
return;
196197
}
197198

198-
constexpr bool is_full_tile = params::has_single_static_value_v<SegmentSizeParameterT>
199-
&& params::static_min_value_v<SegmentSizeParameterT> == tile_size;
199+
constexpr bool is_full_tile = ::cuda::__argument::__traits<SegmentSizeParameterT>::is_constant
200+
&& ::cuda::__argument::__traits<SegmentSizeParameterT>::lowest == tile_size;
200201

201202
// Resolve Segment Parameters
202-
const auto segment_size = segment_sizes.get_param(segment_id);
203+
const auto segment_size = params::get_param(segment_sizes, segment_id);
203204
if (!only_small_segments && segment_size > tile_size)
204205
{
205206
// Enqueue large segment
@@ -215,8 +216,8 @@ struct agent_batched_topk_worker_per_segment
215216
else
216217
{
217218
// Process small segment
218-
const auto k = (::cuda::std::min) (k_param.get_param(segment_id),
219-
static_cast<decltype(k_param.get_param(segment_id))>(segment_size));
219+
const auto k = (::cuda::std::min) (params::get_param(k_param, segment_id),
220+
static_cast<decltype(params::get_param(k_param, segment_id))>(segment_size));
220221
const auto direction = select_directions.get_param(segment_id);
221222

222223
// Determine padding key based on direction

0 commit comments

Comments
 (0)