Skip to content

Commit c13f95b

Browse files
authored
[cudax] Make groups unit generic (#8521)
1 parent 19ca1f3 commit c13f95b

5 files changed

Lines changed: 65 additions & 65 deletions

File tree

cudax/include/cuda/experimental/__group/concepts.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
namespace cuda::experimental
3434
{
3535
template <class _Group>
36-
_CCCL_CONCEPT group = _CCCL_REQUIRES_EXPR((_Group), _Group&& __g, const _Group&& __cg)(
36+
_CCCL_CONCEPT is_group = _CCCL_REQUIRES_EXPR((_Group), _Group&& __g, const _Group&& __cg)(
3737
typename(typename _Group::unit_type),
3838
requires(__is_hierarchy_level_v<typename _Group::unit_type>),
3939
typename(typename _Group::level_type),

cudax/include/cuda/experimental/__group/fwd.cuh

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ using __implicit_hierarchy_t =
4444
hierarchy_level_desc<cluster_level, ::cuda::std::dims<3, unsigned>>,
4545
hierarchy_level_desc<block_level, ::cuda::std::dims<3, unsigned>>>;
4646

47-
// this groups
47+
// groups
4848

4949
template <class _Level, class _Hierarchy>
5050
class __this_group_base;
@@ -60,16 +60,8 @@ class this_cluster;
6060
template <class _Hierarchy>
6161
class this_grid;
6262

63-
// other groups
64-
65-
template <class _Level, class _Mapping, class _Hierarchy, class _Synchronizer>
66-
class thread_group;
67-
template <class _Level, class _Mapping, class _Hierarchy, class _Synchronizer>
68-
class warp_group;
69-
template <class _Level, class _Mapping, class _Hierarchy, class _Synchronizer>
70-
class block_group;
71-
template <class _Level, class _Mapping, class _Hierarchy, class _Synchronizer>
72-
class cluster_group;
63+
template <class _Unit, class _Level, class _Mapping, class _Hierarchy, class _Synchronizer>
64+
class group;
7365

7466
// mappings
7567

cudax/include/cuda/experimental/__group/group.cuh

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,10 @@
4444

4545
namespace cuda::experimental
4646
{
47-
// todo(dabayer): Make groups be based on another group, not level.
48-
template <class _Level, class _Mapping, class _Hierarchy, class _Synchronizer>
49-
class thread_group
47+
template <class _Unit, class _Level, class _Mapping, class _Hierarchy, class _Synchronizer>
48+
class group
5049
{
51-
using _MappingResult = __group_mapping_result_t<_Mapping, thread_level, _Level, _Hierarchy>;
50+
using _MappingResult = __group_mapping_result_t<_Mapping, _Unit, _Level, _Hierarchy>;
5251
static_assert(__group_mapping_result<_MappingResult>);
5352

5453
_Hierarchy __hier_;
@@ -57,7 +56,7 @@ class thread_group
5756
_Synchronizer __synchronizer_;
5857

5958
public:
60-
using unit_type = thread_level;
59+
using unit_type = _Unit;
6160
using level_type = _Level;
6261
using mapping_type = _Mapping;
6362
using __mapping_result_type = _MappingResult;
@@ -67,11 +66,11 @@ public:
6766
// todo(dabayer): Do we want default behaviour like this, or do we want some kind of cuda::auto_sync_mechanism{} tag?
6867
_CCCL_TEMPLATE(class _HierarchyLike)
6968
_CCCL_REQUIRES(::cuda::std::is_same_v<_Hierarchy, __hierarchy_type_of<_HierarchyLike>>)
70-
_CCCL_DEVICE_API explicit thread_group(
71-
const _Level&, const _Mapping& __mapping, const _HierarchyLike& __hier_like) noexcept
69+
_CCCL_DEVICE_API explicit group(
70+
const _Unit&, const _Level&, const _Mapping& __mapping, const _HierarchyLike& __hier_like) noexcept
7271
: __hier_{::cuda::__unpack_hierarchy_if_needed(__hier_like)}
7372
, __mapping_{__mapping}
74-
, __mapping_result_{__mapping_.map(thread_level{}, _Level{}, ::cuda::__unpack_hierarchy_if_needed(__hier_like))}
73+
, __mapping_result_{__mapping_.map(_Unit{}, _Level{}, ::cuda::__unpack_hierarchy_if_needed(__hier_like))}
7574
, __synchronizer_{__mapping_result_}
7675
{
7776
::cuda::experimental::__check_mapping_result(__mapping_result_);
@@ -80,14 +79,15 @@ public:
8079
_CCCL_TEMPLATE(class _Synchronizer2 = _Synchronizer, class _MappingResult2 = _MappingResult, class _HierarchyLike)
8180
_CCCL_REQUIRES(__is_barrier_synchronizer<_Synchronizer2>
8281
_CCCL_AND ::cuda::std::is_same_v<_Hierarchy, __hierarchy_type_of<_HierarchyLike>>)
83-
_CCCL_DEVICE_API explicit thread_group(
82+
_CCCL_DEVICE_API explicit group(
83+
const _Unit&,
8484
const _Level&,
8585
const _Mapping& __mapping,
8686
const _HierarchyLike& __hier_like,
8787
::cuda::std::span<typename _Synchronizer2::__barrier_type, _MappingResult::static_group_count()> __barriers) noexcept
8888
: __hier_{::cuda::__unpack_hierarchy_if_needed(__hier_like)}
8989
, __mapping_{__mapping}
90-
, __mapping_result_{__mapping_.map(thread_level{}, _Level{}, ::cuda::__unpack_hierarchy_if_needed(__hier_like))}
90+
, __mapping_result_{__mapping_.map(_Unit{}, _Level{}, ::cuda::__unpack_hierarchy_if_needed(__hier_like))}
9191
, __synchronizer_{__mapping_result_, __barriers}
9292
{
9393
::cuda::experimental::__check_mapping_result(__mapping_result_);
@@ -130,24 +130,28 @@ public:
130130
}
131131
};
132132

133-
_CCCL_TEMPLATE(class _Level, ::cuda::std::size_t _Np, class _HierarchyLike)
134-
_CCCL_REQUIRES(__is_hierarchy_level_v<_Level> _CCCL_AND __is_or_has_hierarchy_member_v<_HierarchyLike>)
135-
_CCCL_HOST_DEVICE thread_group(const _Level&, const group_by<_Np>&, const _HierarchyLike&)
136-
-> thread_group<_Level,
137-
group_by<_Np>,
138-
__hierarchy_type_of<_HierarchyLike>,
139-
__synchronizer_select_t<thread_level, _Level, group_by<_Np>>>;
140-
141-
_CCCL_TEMPLATE(class _Level,
133+
_CCCL_TEMPLATE(class _Unit, class _Level, ::cuda::std::size_t _Np, class _HierarchyLike)
134+
_CCCL_REQUIRES(__is_hierarchy_level_v<_Unit> _CCCL_AND __is_hierarchy_level_v<_Level> _CCCL_AND
135+
__is_or_has_hierarchy_member_v<_HierarchyLike>)
136+
_CCCL_DEVICE group(const _Unit&, const _Level&, const group_by<_Np>&, const _HierarchyLike&)
137+
-> group<_Unit,
138+
_Level,
139+
group_by<_Np>,
140+
__hierarchy_type_of<_HierarchyLike>,
141+
__synchronizer_select_t<_Unit, _Level, group_by<_Np>>>;
142+
143+
_CCCL_TEMPLATE(class _Unit,
144+
class _Level,
142145
::cuda::std::size_t _Np,
143146
class _HierarchyLike,
144147
class _SyncParam,
145-
class _Synchronizer = __barrier_synchronizer<thread_level, _Level, group_by<_Np>>)
148+
class _Synchronizer = __barrier_synchronizer<_Unit, _Level, group_by<_Np>>)
146149
_CCCL_REQUIRES(
147-
__is_hierarchy_level_v<_Level> _CCCL_AND __is_or_has_hierarchy_member_v<_HierarchyLike>
148-
_CCCL_AND ::cuda::std::is_constructible_v<::cuda::std::span<typename _Synchronizer::__barrier_type>, _SyncParam>)
149-
_CCCL_HOST_DEVICE thread_group(const _Level&, const group_by<_Np>&, const _HierarchyLike&, _SyncParam&&)
150-
-> thread_group<_Level, group_by<_Np>, __hierarchy_type_of<_HierarchyLike>, _Synchronizer>;
150+
__is_hierarchy_level_v<_Unit> _CCCL_AND __is_hierarchy_level_v<_Level> _CCCL_AND
151+
__is_or_has_hierarchy_member_v<_HierarchyLike>
152+
_CCCL_AND ::cuda::std::is_constructible_v<::cuda::std::span<typename _Synchronizer::__barrier_type>, _SyncParam>)
153+
_CCCL_DEVICE group(const _Unit&, const _Level&, const group_by<_Np>&, const _HierarchyLike&, _SyncParam&&)
154+
-> group<_Unit, _Level, group_by<_Np>, __hierarchy_type_of<_HierarchyLike>, _Synchronizer>;
151155
} // namespace cuda::experimental
152156

153157
#endif // !_CCCL_DOXYGEN_INVOKED

cudax/test/group/group.cu

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ template <class Unit, class Level, class Hierarchy, class Group>
3333
__device__ void test_common_properties(const Hierarchy&, Group& group)
3434
{
3535
// Assert that Group satisfies the is_group concept.
36-
static_assert(cudax::group<Group>);
36+
static_assert(cudax::is_group<Group>);
3737

3838
// Test types
3939
static_assert(cuda::std::is_same_v<Unit, typename Group::unit_type>);
@@ -413,7 +413,7 @@ __device__ void test_this_group(const Config& config)
413413
}
414414

415415
template <class Level, cuda::std::size_t N, class Hierarchy, class Sync>
416-
__device__ void test_queries(const cudax::thread_group<Level, cudax::group_by<N>, Hierarchy, Sync>& group)
416+
__device__ void test_queries(const cudax::group<cuda::thread_level, Level, cudax::group_by<N>, Hierarchy, Sync>& group)
417417
{
418418
// todo(dabayer): These queries end up in `error: expression must have a constant value`, when group is taken by
419419
// reference. Can we find a solution that works without copying the group?
@@ -428,18 +428,21 @@ __device__ void test_queries(const cudax::thread_group<Level, cudax::group_by<N>
428428
CUDAX_REQUIRE(cuda::gpu_thread.is_part_of(group));
429429
}
430430

431-
template <class Unit, template <class...> class GroupTempl, class Level, class Config, cuda::std::size_t N>
431+
template <class Unit, class Level, class Config, cuda::std::size_t N>
432432
__device__ void test_group_by_group(const Config& config)
433433
{
434434
// Test statically known group size
435435
{
436436
using Mapping = cudax::group_by<N>;
437437

438-
GroupTempl group{Level{}, Mapping{}, config};
438+
cudax::group group{Unit{}, Level{}, Mapping{}, config};
439439
static_assert(
440-
cuda::std::is_same_v<
441-
GroupTempl<Level, Mapping, typename Config::hierarchy_type, cudax::__synchronizer_select_t<Unit, Level, Mapping>>,
442-
decltype(group)>);
440+
cuda::std::is_same_v<cudax::group<Unit,
441+
Level,
442+
Mapping,
443+
typename Config::hierarchy_type,
444+
cudax::__synchronizer_select_t<Unit, Level, Mapping>>,
445+
decltype(group)>);
443446

444447
test_common_properties<Unit, Level>(config.hierarchy(), group);
445448
test_queries<Level>(group);
@@ -460,10 +463,11 @@ __device__ void test_group_by_group(const Config& config)
460463

461464
auto& barriers = reinterpret_cast<Barrier(&)[nbarriers]>(barrier_storage);
462465

463-
GroupTempl group{Level{}, Mapping{}, config, barriers};
466+
cudax::group group{Unit{}, Level{}, Mapping{}, config, barriers};
464467
static_assert(
465468
cuda::std::is_same_v<
466-
GroupTempl<Level, Mapping, typename Config::hierarchy_type, cudax::__barrier_synchronizer<Unit, Level, Mapping>>,
469+
cudax::
470+
group<Unit, Level, Mapping, typename Config::hierarchy_type, cudax::__barrier_synchronizer<Unit, Level, Mapping>>,
467471
decltype(group)>);
468472

469473
test_common_properties<Unit, Level>(config.hierarchy(), group);
@@ -475,29 +479,29 @@ __device__ void test_group_by_group(const Config& config)
475479
// {
476480
// using Mapping = cudax::group_by<cuda::std::dynamic_extent>;
477481

478-
// GroupTempl group{Level{}, Mapping{static_cast<unsigned>(N)}, config};
482+
// cudax::group group{Unit{}, Level{}, Mapping{static_cast<unsigned>(N)}, config};
479483
// static_assert(
480-
// cuda::std::is_same_v<GroupTempl<Level, Mapping, typename Config::hierarchy_type,
484+
// cuda::std::is_same_v<cudax::group<Unit, Level, Mapping, typename Config::hierarchy_type,
481485
// cudax::__syncwarp_synchronizer<Unit, Level, Mapping>>, decltype(group)>);
482486

483487
// test_common_properties<Unit, Level>(config.hierarchy(), group);
484488
// test_queries<Level>(group);
485489
// }
486490
}
487491

488-
template <class Unit, template <class...> class GroupTempl, class Level, class Config>
492+
template <class Unit, class Level, class Config>
489493
__device__ void test_group_by_group(const Config& config)
490494
{
491495
// powers of 2
492-
test_group_by_group<Unit, GroupTempl, Level, Config, 1>(config);
493-
test_group_by_group<Unit, GroupTempl, Level, Config, 4>(config);
494-
test_group_by_group<Unit, GroupTempl, Level, Config, 16>(config);
495-
test_group_by_group<Unit, GroupTempl, Level, Config, 32>(config);
496+
test_group_by_group<Unit, Level, Config, 1>(config);
497+
test_group_by_group<Unit, Level, Config, 4>(config);
498+
test_group_by_group<Unit, Level, Config, 16>(config);
499+
test_group_by_group<Unit, Level, Config, 32>(config);
496500

497501
if constexpr (cuda::std::is_same_v<Level, cuda::block_level>)
498502
{
499-
test_group_by_group<Unit, GroupTempl, Level, Config, 64>(config);
500-
test_group_by_group<Unit, GroupTempl, Level, Config, 128>(config);
503+
test_group_by_group<Unit, Level, Config, 64>(config);
504+
test_group_by_group<Unit, Level, Config, 128>(config);
501505
}
502506
}
503507

@@ -515,13 +519,13 @@ struct TestKernel
515519
test_this_group<cuda::grid_level, cudax::this_grid>(config);
516520

517521
// todo: allow this once hierarchy is queryable for missing levels
518-
test_group_by_group<cuda::thread_level, cudax::thread_group, cuda::warp_level>(config);
519-
test_group_by_group<cuda::thread_level, cudax::thread_group, cuda::block_level>(config);
522+
test_group_by_group<cuda::thread_level, cuda::warp_level>(config);
523+
test_group_by_group<cuda::thread_level, cuda::block_level>(config);
520524
if constexpr (Hierarchy::has_level(cuda::cluster))
521525
{
522-
test_group_by_group<cuda::thread_level, cudax::thread_group, cuda::cluster_level>(config);
526+
test_group_by_group<cuda::thread_level, cuda::cluster_level>(config);
523527
}
524-
test_group_by_group<cuda::thread_level, cudax::thread_group, cuda::grid_level>(config);
528+
test_group_by_group<cuda::thread_level, cuda::grid_level>(config);
525529
}
526530
};
527531

libcudacxx/include/cuda/__hierarchy/hierarchy_level_base.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,49 +174,49 @@ struct hierarchy_level_base
174174
# if _CCCL_CUDA_COMPILATION()
175175

176176
_CCCL_TEMPLATE(class _Group)
177-
_CCCL_REQUIRES(::cuda::experimental::group<_Group>)
177+
_CCCL_REQUIRES(::cuda::experimental::is_group<_Group>)
178178
[[nodiscard]] _CCCL_API static constexpr ::cuda::std::size_t static_count(const _Group&) noexcept
179179
{
180180
return ::cuda::experimental::__static_count_query_group<_Level, _Group>();
181181
}
182182

183183
_CCCL_TEMPLATE(class _Group)
184-
_CCCL_REQUIRES(::cuda::experimental::group<_Group>)
184+
_CCCL_REQUIRES(::cuda::experimental::is_group<_Group>)
185185
[[nodiscard]] _CCCL_API static constexpr auto count(const _Group& __group) noexcept
186186
{
187187
return count_as<__default_1d_query_type<typename _Group::unit_type>>(__group);
188188
}
189189

190190
_CCCL_TEMPLATE(class _Group)
191-
_CCCL_REQUIRES(::cuda::experimental::group<_Group>)
191+
_CCCL_REQUIRES(::cuda::experimental::is_group<_Group>)
192192
[[nodiscard]] _CCCL_API static auto rank(const _Group& __group) noexcept
193193
{
194194
return rank_as<__default_1d_query_type<typename _Group::unit_type>>(__group);
195195
}
196196

197197
_CCCL_TEMPLATE(class _Tp, class _Group)
198-
_CCCL_REQUIRES(::cuda::std::__cccl_is_integer_v<_Tp> _CCCL_AND ::cuda::experimental::group<_Group>)
198+
_CCCL_REQUIRES(::cuda::std::__cccl_is_integer_v<_Tp> _CCCL_AND ::cuda::experimental::is_group<_Group>)
199199
[[nodiscard]] _CCCL_API static constexpr _Tp count_as(const _Group& __group) noexcept
200200
{
201201
return ::cuda::experimental::__count_query_group<_Tp, _Level>(__group);
202202
}
203203

204204
_CCCL_TEMPLATE(class _Tp, class _Group)
205-
_CCCL_REQUIRES(::cuda::std::__cccl_is_integer_v<_Tp> _CCCL_AND ::cuda::experimental::group<_Group>)
205+
_CCCL_REQUIRES(::cuda::std::__cccl_is_integer_v<_Tp> _CCCL_AND ::cuda::experimental::is_group<_Group>)
206206
[[nodiscard]] _CCCL_API static _Tp rank_as(const _Group& __group) noexcept
207207
{
208208
return ::cuda::experimental::__rank_query_group<_Tp, _Level>(__group);
209209
}
210210

211211
_CCCL_TEMPLATE(class _Group)
212-
_CCCL_REQUIRES(::cuda::experimental::group<_Group>)
212+
_CCCL_REQUIRES(::cuda::experimental::is_group<_Group>)
213213
[[nodiscard]] _CCCL_DEVICE_API static constexpr bool is_root_rank(const _Group& __group) noexcept
214214
{
215215
return _Level::rank(__group) == 0;
216216
}
217217

218218
_CCCL_TEMPLATE(class _Group)
219-
_CCCL_REQUIRES(::cuda::experimental::group<_Group>)
219+
_CCCL_REQUIRES(::cuda::experimental::is_group<_Group>)
220220
[[nodiscard]] _CCCL_API static constexpr bool is_part_of(const _Group& __group) noexcept
221221
{
222222
// todo: static_assert that the _Level <= _Group::unit_type

0 commit comments

Comments
 (0)