44#include < cub/detail/choose_offset.cuh>
55#include < cub/device/dispatch/dispatch_batched_topk.cuh>
66
7+ #include < cuda/__argument_>
78#include < cuda/iterator>
89
910#include < nvbench_helper.cuh>
@@ -44,33 +45,13 @@ void fixed_seg_size_topk_keys(
4445 nvbench::state& state,
4546 nvbench::type_list<KeyT, nvbench::enum_type<MaxSegmentSize>, nvbench::enum_type<MaxNumSelected>>)
4647{
47- // Range of guaranteed total number of items
48- constexpr auto min_num_total_items = 1 ;
49- constexpr auto max_num_total_items = ::cuda::std::numeric_limits<::cuda::std::int32_t >::max ();
50-
51- // Static segment size
52- using seg_size_t = cub::detail::batched_topk::segment_size_static<MaxSegmentSize>;
53-
54- // Static k (number of selected output elements per segment)
55- using k_value_t = cub::detail::batched_topk::k_static<MaxNumSelected>;
56-
57- // Static selection direction (max)
58- using select_direction_value_t = cub::detail::batched_topk::select_direction_static<cub::detail::topk::select::max>;
59-
60- // Number of segments is a host-accessible value
61- using num_segments_uniform_t = cub::detail::batched_topk::num_segments_uniform<>;
62-
63- // Total number of items guarantee type
64- using total_num_items_guarantee_t =
65- cub::detail::batched_topk::total_num_items_guarantee<min_num_total_items, max_num_total_items>;
66-
6748 // Retrieve axis parameters
6849 const auto max_elements = static_cast <size_t >(state.get_int64 (" Elements{io}" ));
6950 const auto segment_size = static_cast <::cuda::std::ptrdiff_t >(MaxSegmentSize);
7051 const auto selected_elements = static_cast <::cuda::std::ptrdiff_t >(MaxNumSelected);
7152 const auto num_segments = ::cuda::std::max<std::size_t >(1 , (max_elements / segment_size));
7253 const auto elements = num_segments * segment_size;
73- const auto total_num_items = total_num_items_guarantee_t {static_cast <::cuda::std::int64_t >(elements)};
54+ const auto total_num_items = ::cuda::__argument::__immediate {static_cast <::cuda::std::int64_t >(elements)};
7455 const bit_entropy entropy = str_to_entropy (state.get_string (" Entropy" ));
7556
7657 // Skip workloads where k exceeds the segment size
@@ -87,9 +68,9 @@ void fixed_seg_size_topk_keys(
8768 auto d_keys_in = cuda::make_strided_iterator (cuda::make_counting_iterator (d_keys_in_ptr), segment_size);
8869 auto d_keys_out = cuda::make_strided_iterator (cuda::make_counting_iterator (d_keys_out_ptr), selected_elements);
8970
90- auto segment_sizes = seg_size_t {};
91- auto k = k_value_t {};
92- auto select_directions = select_direction_value_t {};
71+ auto segment_sizes = ::cuda::__argument::__constant<MaxSegmentSize> {};
72+ auto k = ::cuda::__argument::__constant<MaxNumSelected> {};
73+ auto select_direction = ::cuda::__argument::__constant<cub::detail::topk::select::max> {};
9374
9475 state.add_element_count (elements, " NumElements" );
9576 state.add_element_count (segment_size, " SegmentSize" );
@@ -117,8 +98,8 @@ void fixed_seg_size_topk_keys(
11798 static_cast <cub::NullType**>(nullptr ),
11899 segment_sizes,
119100 k,
120- select_directions ,
121- num_segments_uniform_t {static_cast <::cuda::std::int64_t >(num_segments)},
101+ select_direction ,
102+ ::cuda::__argument::__immediate {static_cast <::cuda::std::int64_t >(num_segments)},
122103 total_num_items,
123104 env);
124105 });
0 commit comments