@@ -41,6 +41,39 @@ struct AgentMergeSortPolicy
4141 static constexpr cub::BlockStoreAlgorithm STORE_ALGORITHM = StoreAlgorithm;
4242};
4343
44+ // ! The tuning policy for all algorithms in @ref DeviceMergeSort.
45+ struct MergeSortPolicy
46+ {
47+ int threads_per_block; // !< Number of threads in a CUDA block
48+ int items_per_thread; // !< Number of items processed per thread
49+ BlockLoadAlgorithm load_algorithm; // !< The @ref BlockLoadAlgorithm used for loading items from global memory
50+ CacheLoadModifier load_modifier; // !< The @ref CacheLoadModifier used for loading items from global memory
51+ BlockStoreAlgorithm store_algorithm; // !< The @ref BlockStoreAlgorithm used for storing items to global memory
52+
53+ [[nodiscard]] _CCCL_HOST_DEVICE_API constexpr friend bool
54+ operator ==(const MergeSortPolicy& lhs, const MergeSortPolicy& rhs)
55+ {
56+ return lhs.threads_per_block == rhs.threads_per_block && lhs.items_per_thread == rhs.items_per_thread
57+ && lhs.load_algorithm == rhs.load_algorithm && lhs.load_modifier == rhs.load_modifier
58+ && lhs.store_algorithm == rhs.store_algorithm ;
59+ }
60+
61+ [[nodiscard]] _CCCL_HOST_DEVICE_API constexpr friend bool
62+ operator !=(const MergeSortPolicy& lhs, const MergeSortPolicy& rhs)
63+ {
64+ return !(lhs == rhs);
65+ }
66+
67+ #if _CCCL_HOSTED()
68+ friend ::std::ostream& operator <<(::std::ostream& os, const MergeSortPolicy& p)
69+ {
70+ return os << " MergeSortPolicy { .threads_per_block = " << p.threads_per_block
71+ << " , .items_per_thread = " << p.items_per_thread << " , .load_algorithm = " << p.load_algorithm
72+ << " , .load_modifier = " << p.load_modifier << " , .store_algorithm = " << p.store_algorithm << " }" ;
73+ }
74+ #endif // _CCCL_HOSTED()
75+ };
76+
4477namespace detail ::merge_sort
4578{
4679// TODO(bgruber): drop in CCCL 4.0 when we remove all public CUB dispatchers
@@ -87,56 +120,19 @@ struct policy_hub
87120 using MaxPolicy = Policy600;
88121};
89122
90- struct merge_sort_policy
91- {
92- int threads_per_block;
93- int items_per_thread;
94- BlockLoadAlgorithm load_algorithm;
95- CacheLoadModifier load_modifier;
96- BlockStoreAlgorithm store_algorithm;
97-
98- [[nodiscard]] _CCCL_HOST_DEVICE_API constexpr int items_per_tile () const
99- {
100- return threads_per_block * items_per_thread;
101- }
102-
103- [[nodiscard]] _CCCL_HOST_DEVICE_API constexpr friend bool
104- operator ==(const merge_sort_policy& lhs, const merge_sort_policy& rhs)
105- {
106- return lhs.threads_per_block == rhs.threads_per_block && lhs.items_per_thread == rhs.items_per_thread
107- && lhs.load_algorithm == rhs.load_algorithm && lhs.load_modifier == rhs.load_modifier
108- && lhs.store_algorithm == rhs.store_algorithm ;
109- }
110-
111- [[nodiscard]] _CCCL_HOST_DEVICE_API constexpr friend bool
112- operator !=(const merge_sort_policy& lhs, const merge_sort_policy& rhs)
113- {
114- return !(lhs == rhs);
115- }
116-
117- #if _CCCL_HOSTED()
118- friend ::std::ostream& operator <<(::std::ostream& os, const merge_sort_policy& p)
119- {
120- return os << " merge_sort_policy { .threads_per_block = " << p.threads_per_block
121- << " , .items_per_thread = " << p.items_per_thread << " , .load_algorithm = " << p.load_algorithm
122- << " , .load_modifier = " << p.load_modifier << " , .store_algorithm = " << p.store_algorithm << " }" ;
123- }
124- #endif // _CCCL_HOSTED()
125- };
126-
127123#if _CCCL_HAS_CONCEPTS()
128124template <typename T>
129- concept merge_sort_policy_selector = policy_selector<T, merge_sort_policy >;
125+ concept merge_sort_policy_selector = policy_selector<T, MergeSortPolicy >;
130126#endif // _CCCL_HAS_CONCEPTS()
131127
132128struct policy_selector
133129{
134130 int key_size;
135131
136- [[nodiscard]] _CCCL_HOST_DEVICE_API constexpr auto operator ()(::cuda::compute_capability) const -> merge_sort_policy
132+ [[nodiscard]] _CCCL_HOST_DEVICE_API constexpr auto operator ()(::cuda::compute_capability) const -> MergeSortPolicy
137133 {
138134 // from SM60
139- return merge_sort_policy {
135+ return MergeSortPolicy {
140136 256 ,
141137 detail::nominal_4B_items_to_items (17 , key_size),
142138 BLOCK_LOAD_WARP_TRANSPOSE ,
@@ -152,8 +148,7 @@ static_assert(merge_sort_policy_selector<policy_selector>);
152148template <typename KeyIteratorT>
153149struct policy_selector_from_types
154150{
155- [[nodiscard]] _CCCL_HOST_DEVICE_API constexpr auto operator ()(::cuda::compute_capability cc) const
156- -> merge_sort_policy
151+ [[nodiscard]] _CCCL_HOST_DEVICE_API constexpr auto operator ()(::cuda::compute_capability cc) const -> MergeSortPolicy
157152 {
158153 return policy_selector{int {sizeof (it_value_t <KeyIteratorT>)}}(cc);
159154 }
@@ -164,11 +159,11 @@ template <typename PolicyHub>
164159struct policy_selector_from_hub
165160{
166161 // this is only called in device code, so we can ignore the cc parameter
167- _CCCL_DEVICE_API constexpr auto operator ()(::cuda::compute_capability /* cc*/ ) const -> merge_sort_policy
162+ _CCCL_DEVICE_API constexpr auto operator ()(::cuda::compute_capability /* cc*/ ) const -> MergeSortPolicy
168163 {
169164 using ap = typename PolicyHub::MaxPolicy::ActivePolicy;
170165 using mp = typename ap::MergeSortPolicy;
171- return merge_sort_policy {
166+ return MergeSortPolicy {
172167 mp::BLOCK_THREADS , mp::ITEMS_PER_THREAD , mp::LOAD_ALGORITHM , mp::LOAD_MODIFIER , mp::STORE_ALGORITHM };
173168 }
174169};
0 commit comments