Skip to content

Commit 5a9ea63

Browse files
Allow public tuning of cub::DeviceMergeSort (#8600)
Fixes: #8574
1 parent fc452a9 commit 5a9ea63

10 files changed

Lines changed: 183 additions & 101 deletions

cub/benchmarks/bench/merge_sort/keys.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
template <typename KeyT>
2020
struct policy_selector
2121
{
22-
[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto operator()(cuda::compute_capability) const
23-
-> cub::detail::merge_sort::merge_sort_policy
22+
[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto operator()(cuda::compute_capability) const -> cub::MergeSortPolicy
2423
{
25-
return cub::detail::merge_sort::merge_sort_policy{
24+
return cub::MergeSortPolicy{
2625
TUNE_THREADS_PER_BLOCK,
2726
cub::Nominal4BItemsToItems<KeyT>(TUNE_ITEMS_PER_THREAD),
2827
(TUNE_TRANSPOSE == 0 ? cub::BLOCK_LOAD_DIRECT : cub::BLOCK_LOAD_WARP_TRANSPOSE),

cub/benchmarks/bench/merge_sort/pairs.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
template <typename KeyT>
2020
struct policy_selector
2121
{
22-
[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto operator()(cuda::compute_capability) const
23-
-> cub::detail::merge_sort::merge_sort_policy
22+
[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto operator()(cuda::compute_capability) const -> cub::MergeSortPolicy
2423
{
25-
return cub::detail::merge_sort::merge_sort_policy{
24+
return cub::MergeSortPolicy{
2625
TUNE_THREADS_PER_BLOCK,
2726
cub::Nominal4BItemsToItems<KeyT>(TUNE_ITEMS_PER_THREAD),
2827
(TUNE_TRANSPOSE == 0 ? cub::BLOCK_LOAD_DIRECT : cub::BLOCK_LOAD_WARP_TRANSPOSE),

cub/cub/agent/agent_merge_sort.cuh

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ struct AgentBlockSort
4343

4444
static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;
4545

46-
static constexpr merge_sort_policy policy = PolicyGetter{}();
47-
static constexpr int BLOCK_THREADS = policy.threads_per_block;
48-
static constexpr int ITEMS_PER_THREAD = policy.items_per_thread;
49-
static constexpr int ITEMS_PER_TILE = policy.items_per_tile();
46+
static constexpr MergeSortPolicy policy = PolicyGetter{}();
47+
static constexpr int BLOCK_THREADS = policy.threads_per_block;
48+
static constexpr int ITEMS_PER_THREAD = policy.items_per_thread;
49+
static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
5050

5151
using BlockMergeSortT = BlockMergeSort<KeyT, BLOCK_THREADS, ITEMS_PER_THREAD, ValueT>;
5252

@@ -378,10 +378,10 @@ struct AgentMerge
378378

379379
static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;
380380

381-
static constexpr merge_sort_policy policy = PolicyGetter{}();
382-
static constexpr int BLOCK_THREADS = policy.threads_per_block;
383-
static constexpr int ITEMS_PER_THREAD = policy.items_per_thread;
384-
static constexpr int ITEMS_PER_TILE = policy.items_per_tile();
381+
static constexpr MergeSortPolicy policy = PolicyGetter{}();
382+
static constexpr int BLOCK_THREADS = policy.threads_per_block;
383+
static constexpr int ITEMS_PER_THREAD = policy.items_per_thread;
384+
static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
385385

386386
using KeysLoadPingIt = try_make_cache_modified_iterator_t<policy.load_modifier, KeyIteratorT>;
387387
using ItemsLoadPingIt = try_make_cache_modified_iterator_t<policy.load_modifier, ValueIteratorT>;
@@ -412,8 +412,8 @@ struct AgentMerge
412412
typename BlockStoreKeysPong::TempStorage store_keys_pong;
413413
typename BlockStoreItemsPong::TempStorage store_items_pong;
414414

415-
KeyT keys_shared[policy.items_per_tile() + 1];
416-
ValueT items_shared[policy.items_per_tile() + 1];
415+
KeyT keys_shared[ITEMS_PER_TILE + 1];
416+
ValueT items_shared[ITEMS_PER_TILE + 1];
417417
};
418418

419419
/// Alias wrapper allowing storage to be unioned

cub/cub/device/device_merge_sort.cuh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,24 @@ CUB_NAMESPACE_BEGIN
8686
* CustomLess());
8787
* @endcode
8888
*
89+
* @par Tuning
90+
* @rst
91+
* All algorithms in DeviceMergeSort that accept an environment can be tuned by passing a custom :ref:`policy selector
92+
* <cub-policy-selectors>` that returns a @ref MergeSortPolicy, as shown in the example below:
93+
*
94+
* .. literalinclude:: ../../../cub/test/catch2_test_device_merge_sort_env_api.cu
95+
* :language: c++
96+
* :dedent:
97+
* :start-after: example-begin sort-pairs-policy-selector
98+
* :end-before: example-end sort-pairs-policy-selector
99+
*
100+
* .. literalinclude:: ../../../cub/test/catch2_test_device_merge_sort_env_api.cu
101+
* :language: c++
102+
* :dedent:
103+
* :start-after: example-begin sort-pairs-tuning
104+
* :end-before: example-end sort-pairs-tuning
105+
* @endrst
106+
*
89107
* [LessThan Comparable]: https://en.cppreference.com/w/cpp/named_req/LessThanComparable
90108
*/
91109
struct DeviceMergeSort

cub/cub/device/dispatch/dispatch_merge_sort.cuh

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ template <typename KeyInputIteratorT,
111111
typename KernelLauncherFactory = CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER_FACTORY,
112112
typename KeyT = cub::detail::it_value_t<KeyIteratorT>,
113113
typename ValueT = cub::detail::it_value_t<ValueIteratorT>>
114-
struct DispatchMergeSort
114+
struct CCCL_DEPRECATED_BECAUSE("Please use DeviceMergeSort") DispatchMergeSort
115115
{
116116
/// Whether or not there are values to be trucked along with keys
117117
static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;
@@ -185,7 +185,7 @@ private:
185185
template <typename ActivePolicyT>
186186
struct policy_getter
187187
{
188-
_CCCL_HOST_DEVICE_API constexpr auto operator()() -> detail::merge_sort::merge_sort_policy
188+
_CCCL_HOST_DEVICE_API constexpr auto operator()() -> MergeSortPolicy
189189
{
190190
using mp = typename ActivePolicyT::MergeSortPolicy;
191191
return {mp::BLOCK_THREADS, mp::ITEMS_PER_THREAD, mp::LOAD_ALGORITHM, mp::LOAD_MODIFIER, mp::STORE_ALGORITHM};
@@ -195,7 +195,7 @@ private:
195195
public:
196196
// Invocation
197197
template <typename ActivePolicyT>
198-
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke([[maybe_unused]] ActivePolicyT policy = {})
198+
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke([[maybe_unused]] ActivePolicyT = {})
199199
{
200200
if (num_items == 0)
201201
{
@@ -206,18 +206,21 @@ public:
206206
return cudaSuccess;
207207
}
208208

209-
constexpr auto tile_size =
210-
detail::merge_sort::merge_sort_vsmem_helper_t<
211-
policy_getter<ActivePolicyT>,
212-
KeyInputIteratorT,
213-
ValueInputIteratorT,
214-
KeyIteratorT,
215-
ValueIteratorT,
216-
OffsetT,
217-
CompareOpT,
218-
KeyT,
219-
ValueT>::policy.items_per_tile();
220-
const auto num_tiles = ::cuda::ceil_div(num_items, tile_size);
209+
static constexpr auto policy = detail::merge_sort::merge_sort_vsmem_helper_t<
210+
policy_getter<ActivePolicyT>,
211+
KeyInputIteratorT,
212+
ValueInputIteratorT,
213+
KeyIteratorT,
214+
ValueIteratorT,
215+
OffsetT,
216+
CompareOpT,
217+
KeyT,
218+
ValueT>::policy;
219+
static_assert(1 <= policy.threads_per_block && policy.threads_per_block <= 1024,
220+
"Number of threads per block need to be inside [1;1024]");
221+
static_assert(1 <= policy.items_per_thread, "Number of items per thread needs to be at least 1");
222+
constexpr auto tile_size = policy.threads_per_block * policy.items_per_thread;
223+
const auto num_tiles = ::cuda::ceil_div(num_items, tile_size);
221224

222225
const auto merge_partitions_size = static_cast<size_t>(1 + num_tiles) * sizeof(OffsetT);
223226
const auto temporary_keys_storage_size = static_cast<size_t>(num_items * kernel_source.KeySize());
@@ -279,17 +282,7 @@ public:
279282
auto keys_buffer = static_cast<KeyT*>(allocations[1]);
280283
auto items_buffer = static_cast<ValueT*>(allocations[2]);
281284

282-
const int threads_per_block =
283-
detail::merge_sort::merge_sort_vsmem_helper_t<
284-
policy_getter<ActivePolicyT>,
285-
KeyInputIteratorT,
286-
ValueInputIteratorT,
287-
KeyIteratorT,
288-
ValueIteratorT,
289-
OffsetT,
290-
CompareOpT,
291-
KeyT,
292-
ValueT>::policy.threads_per_block;
285+
const int threads_per_block = policy.threads_per_block;
293286

294287
// Invoke DeviceMergeSortBlockSortKernel
295288
launcher_factory(
@@ -495,7 +488,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE auto dispatch(
495488

496489
return detail::dispatch_compute_cap(policy_selector, cc, [&](auto policy_getter) -> cudaError_t {
497490
#ifdef CUB_DEFINE_RUNTIME_POLICIES
498-
const merge_sort_policy active_policy = policy_getter();
491+
const MergeSortPolicy active_policy = policy_getter();
499492
#else // CUB_DEFINE_RUNTIME_POLICIES
500493
using vsmem_adapted_agents = merge_sort_vsmem_helper_t<
501494
decltype(policy_getter),
@@ -507,7 +500,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE auto dispatch(
507500
CompareOpT,
508501
KeyT,
509502
ValueT>;
510-
constexpr merge_sort_policy active_policy = vsmem_adapted_agents::policy;
503+
constexpr MergeSortPolicy active_policy = vsmem_adapted_agents::policy;
511504
#endif // CUB_DEFINE_RUNTIME_POLICIES
512505

513506
#if _CCCL_HOSTED() && defined(CUB_DEBUG_LOG)
@@ -521,7 +514,10 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE auto dispatch(
521514
}))
522515
#endif // _CCCL_HOSTED() && defined(CUB_DEBUG_LOG)
523516

524-
const auto tile_size = active_policy.items_per_tile();
517+
_CCCL_ASSERT(1 <= active_policy.threads_per_block && active_policy.threads_per_block <= 1024,
518+
"Number of threads per block need to be inside [1;1024]");
519+
_CCCL_ASSERT(1 <= active_policy.items_per_thread, "Number of items per thread needs to be at least 1");
520+
const auto tile_size = active_policy.threads_per_block * active_policy.items_per_thread;
525521
const auto num_tiles = ::cuda::ceil_div(num_items, tile_size);
526522

527523
const auto merge_partitions_size = static_cast<size_t>(1 + num_tiles) * sizeof(OffsetT);

cub/cub/device/dispatch/kernels/kernel_merge_sort.cuh

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ struct fallback_policy_getter
3030
_CCCL_EXEC_CHECK_DISABLE
3131
_CCCL_HOST_DEVICE_API _CCCL_FORCEINLINE constexpr auto operator()() const
3232
{
33-
merge_sort_policy policy = DefaultPolicyGetter{}();
33+
MergeSortPolicy policy = DefaultPolicyGetter{}();
3434
policy.threads_per_block = 64;
3535
policy.items_per_thread = 1;
3636
return policy;
@@ -75,7 +75,7 @@ class merge_sort_vsmem_helper_impl
7575
(max_default_size > max_smem_per_block) && (max_fallback_size <= max_smem_per_block);
7676

7777
public:
78-
static constexpr merge_sort_policy policy = uses_fallback_policy ? FallbackPolicyGetter{}() : DefaultPolicyGetter{}();
78+
static constexpr MergeSortPolicy policy = uses_fallback_policy ? FallbackPolicyGetter{}() : DefaultPolicyGetter{}();
7979
using block_sort_agent_t =
8080
::cuda::std::_If<uses_fallback_policy, fallback_block_sort_agent_t, default_block_sort_agent_t>;
8181
using merge_agent_t = ::cuda::std::_If<uses_fallback_policy, fallback_merge_agent_t, default_merge_agent_t>;
@@ -166,9 +166,9 @@ __launch_bounds__(
166166
KeyT,
167167
ValueT>;
168168

169-
static constexpr merge_sort_policy active_policy = vsmem_adapted_agents::policy;
170-
using agent_block_sort_t = typename vsmem_adapted_agents::block_sort_agent_t;
171-
using vsmem_helper_t = vsmem_helper_impl<agent_block_sort_t>;
169+
static constexpr MergeSortPolicy active_policy = vsmem_adapted_agents::policy;
170+
using agent_block_sort_t = typename vsmem_adapted_agents::block_sort_agent_t;
171+
using vsmem_helper_t = vsmem_helper_impl<agent_block_sort_t>;
172172

173173
// Static shared memory allocation
174174
__shared__ typename vsmem_helper_t::static_temp_storage_t static_temp_storage;
@@ -267,9 +267,9 @@ __launch_bounds__(
267267
KeyT,
268268
ValueT>;
269269

270-
static constexpr merge_sort_policy active_policy = vsmem_adapted_agents::policy;
271-
using agent_merge_t = typename vsmem_adapted_agents::merge_agent_t;
272-
using vsmem_helper_t = vsmem_helper_impl<agent_merge_t>;
270+
static constexpr MergeSortPolicy active_policy = vsmem_adapted_agents::policy;
271+
using agent_merge_t = typename vsmem_adapted_agents::merge_agent_t;
272+
using vsmem_helper_t = vsmem_helper_impl<agent_merge_t>;
273273

274274
// Static shared memory allocation
275275
__shared__ typename vsmem_helper_t::static_temp_storage_t static_temp_storage;

cub/cub/device/dispatch/tuning/tuning_merge_sort.cuh

Lines changed: 39 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,39 @@ struct AgentMergeSortPolicy
4141
static constexpr cub::BlockStoreAlgorithm STORE_ALGORITHM = StoreAlgorithm;
4242
};
4343

44+
//! The tuning policy for all algorithms in @ref DeviceMergeSort.
45+
struct MergeSortPolicy
46+
{
47+
int threads_per_block; //!< Number of threads in a CUDA block
48+
int items_per_thread; //!< Number of items processed per thread
49+
BlockLoadAlgorithm load_algorithm; //!< The @ref BlockLoadAlgorithm used for loading items from global memory
50+
CacheLoadModifier load_modifier; //!< The @ref CacheLoadModifier used for loading items from global memory
51+
BlockStoreAlgorithm store_algorithm; //!< The @ref BlockStoreAlgorithm used for storing items to global memory
52+
53+
[[nodiscard]] _CCCL_HOST_DEVICE_API constexpr friend bool
54+
operator==(const MergeSortPolicy& lhs, const MergeSortPolicy& rhs)
55+
{
56+
return lhs.threads_per_block == rhs.threads_per_block && lhs.items_per_thread == rhs.items_per_thread
57+
&& lhs.load_algorithm == rhs.load_algorithm && lhs.load_modifier == rhs.load_modifier
58+
&& lhs.store_algorithm == rhs.store_algorithm;
59+
}
60+
61+
[[nodiscard]] _CCCL_HOST_DEVICE_API constexpr friend bool
62+
operator!=(const MergeSortPolicy& lhs, const MergeSortPolicy& rhs)
63+
{
64+
return !(lhs == rhs);
65+
}
66+
67+
#if _CCCL_HOSTED()
68+
friend ::std::ostream& operator<<(::std::ostream& os, const MergeSortPolicy& p)
69+
{
70+
return os << "MergeSortPolicy { .threads_per_block = " << p.threads_per_block
71+
<< ", .items_per_thread = " << p.items_per_thread << ", .load_algorithm = " << p.load_algorithm
72+
<< ", .load_modifier = " << p.load_modifier << ", .store_algorithm = " << p.store_algorithm << " }";
73+
}
74+
#endif // _CCCL_HOSTED()
75+
};
76+
4477
namespace detail::merge_sort
4578
{
4679
// TODO(bgruber): drop in CCCL 4.0 when we remove all public CUB dispatchers
@@ -87,56 +120,19 @@ struct policy_hub
87120
using MaxPolicy = Policy600;
88121
};
89122

90-
struct merge_sort_policy
91-
{
92-
int threads_per_block;
93-
int items_per_thread;
94-
BlockLoadAlgorithm load_algorithm;
95-
CacheLoadModifier load_modifier;
96-
BlockStoreAlgorithm store_algorithm;
97-
98-
[[nodiscard]] _CCCL_HOST_DEVICE_API constexpr int items_per_tile() const
99-
{
100-
return threads_per_block * items_per_thread;
101-
}
102-
103-
[[nodiscard]] _CCCL_HOST_DEVICE_API constexpr friend bool
104-
operator==(const merge_sort_policy& lhs, const merge_sort_policy& rhs)
105-
{
106-
return lhs.threads_per_block == rhs.threads_per_block && lhs.items_per_thread == rhs.items_per_thread
107-
&& lhs.load_algorithm == rhs.load_algorithm && lhs.load_modifier == rhs.load_modifier
108-
&& lhs.store_algorithm == rhs.store_algorithm;
109-
}
110-
111-
[[nodiscard]] _CCCL_HOST_DEVICE_API constexpr friend bool
112-
operator!=(const merge_sort_policy& lhs, const merge_sort_policy& rhs)
113-
{
114-
return !(lhs == rhs);
115-
}
116-
117-
#if _CCCL_HOSTED()
118-
friend ::std::ostream& operator<<(::std::ostream& os, const merge_sort_policy& p)
119-
{
120-
return os << "merge_sort_policy { .threads_per_block = " << p.threads_per_block
121-
<< ", .items_per_thread = " << p.items_per_thread << ", .load_algorithm = " << p.load_algorithm
122-
<< ", .load_modifier = " << p.load_modifier << ", .store_algorithm = " << p.store_algorithm << " }";
123-
}
124-
#endif // _CCCL_HOSTED()
125-
};
126-
127123
#if _CCCL_HAS_CONCEPTS()
128124
template <typename T>
129-
concept merge_sort_policy_selector = policy_selector<T, merge_sort_policy>;
125+
concept merge_sort_policy_selector = policy_selector<T, MergeSortPolicy>;
130126
#endif // _CCCL_HAS_CONCEPTS()
131127

132128
struct policy_selector
133129
{
134130
int key_size;
135131

136-
[[nodiscard]] _CCCL_HOST_DEVICE_API constexpr auto operator()(::cuda::compute_capability) const -> merge_sort_policy
132+
[[nodiscard]] _CCCL_HOST_DEVICE_API constexpr auto operator()(::cuda::compute_capability) const -> MergeSortPolicy
137133
{
138134
// from SM60
139-
return merge_sort_policy{
135+
return MergeSortPolicy{
140136
256,
141137
detail::nominal_4B_items_to_items(17, key_size),
142138
BLOCK_LOAD_WARP_TRANSPOSE,
@@ -152,8 +148,7 @@ static_assert(merge_sort_policy_selector<policy_selector>);
152148
template <typename KeyIteratorT>
153149
struct policy_selector_from_types
154150
{
155-
[[nodiscard]] _CCCL_HOST_DEVICE_API constexpr auto operator()(::cuda::compute_capability cc) const
156-
-> merge_sort_policy
151+
[[nodiscard]] _CCCL_HOST_DEVICE_API constexpr auto operator()(::cuda::compute_capability cc) const -> MergeSortPolicy
157152
{
158153
return policy_selector{int{sizeof(it_value_t<KeyIteratorT>)}}(cc);
159154
}
@@ -164,11 +159,11 @@ template <typename PolicyHub>
164159
struct policy_selector_from_hub
165160
{
166161
// this is only called in device code, so we can ignore the cc parameter
167-
_CCCL_DEVICE_API constexpr auto operator()(::cuda::compute_capability /*cc*/) const -> merge_sort_policy
162+
_CCCL_DEVICE_API constexpr auto operator()(::cuda::compute_capability /*cc*/) const -> MergeSortPolicy
168163
{
169164
using ap = typename PolicyHub::MaxPolicy::ActivePolicy;
170165
using mp = typename ap::MergeSortPolicy;
171-
return merge_sort_policy{
166+
return MergeSortPolicy{
172167
mp::BLOCK_THREADS, mp::ITEMS_PER_THREAD, mp::LOAD_ALGORITHM, mp::LOAD_MODIFIER, mp::STORE_ALGORITHM};
173168
}
174169
};

cub/test/catch2_test_device_merge_sort_custom_policy_hub.cu

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
33

4+
// TODO(bgruber): drop this test with CCCL 4.0 when we drop the merge sort dispatcher
5+
6+
// disable deprecation warnings for DispatchMergeSort
7+
#define CCCL_IGNORE_DEPRECATED_API
8+
49
#include "insert_nested_NVTX_range_guard.h"
510

611
#include <cub/device/device_merge_sort.cuh>
@@ -14,8 +19,6 @@
1419

1520
using namespace cub;
1621

17-
// TODO(bgruber): drop this test with CCCL 4.0 when we drop the merge sort dispatcher after publishing the tuning API
18-
1922
template <typename KeyIteratorT>
2023
struct my_policy_hub
2124
{

0 commit comments

Comments
 (0)