@@ -33,7 +33,7 @@ template <class Unit, class Level, class Hierarchy, class Group>
3333__device__ void test_common_properties (const Hierarchy&, Group& group)
3434{
3535 // Assert that Group satisfies the group concept.
36- static_assert (cudax::group <Group>);
36+ static_assert (cudax::is_group <Group>);
3737
3838 // Test types
3939 static_assert (cuda::std::is_same_v<Unit, typename Group::unit_type>);
@@ -413,7 +413,7 @@ __device__ void test_this_group(const Config& config)
413413}
414414
415415template <class Level , cuda::std::size_t N, class Hierarchy , class Sync >
416- __device__ void test_queries (const cudax::thread_group< Level, cudax::group_by<N>, Hierarchy, Sync>& group)
416+ __device__ void test_queries (const cudax::group<cuda::thread_level, Level, cudax::group_by<N>, Hierarchy, Sync>& group)
417417{
418418 // todo(dabayer): These queries end up in `error: expression must have a constant value`, when group is taken by
419419 // reference. Can we find a solution that works without copying the group?
@@ -428,18 +428,21 @@ __device__ void test_queries(const cudax::thread_group<Level, cudax::group_by<N>
428428 CUDAX_REQUIRE (cuda::gpu_thread.is_part_of (group));
429429}
430430
431- template <class Unit , template < class ...> class GroupTempl , class Level , class Config , cuda::std::size_t N>
431+ template <class Unit , class Level , class Config , cuda::std::size_t N>
432432__device__ void test_group_by_group (const Config& config)
433433{
434434 // Test statically known group size
435435 {
436436 using Mapping = cudax::group_by<N>;
437437
438- GroupTempl group{Level{}, Mapping{}, config};
438+ cudax::group group{Unit{}, Level{}, Mapping{}, config};
439439 static_assert (
440- cuda::std::is_same_v<
441- GroupTempl<Level, Mapping, typename Config::hierarchy_type, cudax::__synchronizer_select_t <Unit, Level, Mapping>>,
442- decltype (group)>);
440+ cuda::std::is_same_v<cudax::group<Unit,
441+ Level,
442+ Mapping,
443+ typename Config::hierarchy_type,
444+ cudax::__synchronizer_select_t <Unit, Level, Mapping>>,
445+ decltype (group)>);
443446
444447 test_common_properties<Unit, Level>(config.hierarchy (), group);
445448 test_queries<Level>(group);
@@ -460,10 +463,11 @@ __device__ void test_group_by_group(const Config& config)
460463
461464 auto & barriers = reinterpret_cast <Barrier (&)[nbarriers]>(barrier_storage);
462465
463- GroupTempl group{Level{}, Mapping{}, config, barriers};
466+ cudax::group group{Unit{}, Level{}, Mapping{}, config, barriers};
464467 static_assert (
465468 cuda::std::is_same_v<
466- GroupTempl<Level, Mapping, typename Config::hierarchy_type, cudax::__barrier_synchronizer<Unit, Level, Mapping>>,
469+ cudax::
470+ group<Unit, Level, Mapping, typename Config::hierarchy_type, cudax::__barrier_synchronizer<Unit, Level, Mapping>>,
467471 decltype (group)>);
468472
469473 test_common_properties<Unit, Level>(config.hierarchy (), group);
@@ -475,29 +479,29 @@ __device__ void test_group_by_group(const Config& config)
475479 // {
476480 // using Mapping = cudax::group_by<cuda::std::dynamic_extent>;
477481
478- // GroupTempl group{Level{}, Mapping{static_cast<unsigned>(N)}, config};
482+ // cudax::group group{Unit{}, Level{}, Mapping{static_cast<unsigned>(N)}, config};
479483 // static_assert(
480- // cuda::std::is_same_v<GroupTempl< Level, Mapping, typename Config::hierarchy_type,
484+ // cuda::std::is_same_v<cudax::group<Unit, Level, Mapping, typename Config::hierarchy_type,
481485 // cudax::__syncwarp_synchronizer<Unit, Level, Mapping>>, decltype(group)>);
482486
483487 // test_common_properties<Unit, Level>(config.hierarchy(), group);
484488 // test_queries<Level>(group);
485489 // }
486490}
487491
488- template <class Unit , template < class ...> class GroupTempl , class Level , class Config >
492+ template <class Unit , class Level , class Config >
489493__device__ void test_group_by_group (const Config& config)
490494{
491495 // powers of 2
492- test_group_by_group<Unit, GroupTempl, Level, Config, 1 >(config);
493- test_group_by_group<Unit, GroupTempl, Level, Config, 4 >(config);
494- test_group_by_group<Unit, GroupTempl, Level, Config, 16 >(config);
495- test_group_by_group<Unit, GroupTempl, Level, Config, 32 >(config);
496+ test_group_by_group<Unit, Level, Config, 1 >(config);
497+ test_group_by_group<Unit, Level, Config, 4 >(config);
498+ test_group_by_group<Unit, Level, Config, 16 >(config);
499+ test_group_by_group<Unit, Level, Config, 32 >(config);
496500
497501 if constexpr (cuda::std::is_same_v<Level, cuda::block_level>)
498502 {
499- test_group_by_group<Unit, GroupTempl, Level, Config, 64 >(config);
500- test_group_by_group<Unit, GroupTempl, Level, Config, 128 >(config);
503+ test_group_by_group<Unit, Level, Config, 64 >(config);
504+ test_group_by_group<Unit, Level, Config, 128 >(config);
501505 }
502506}
503507
@@ -515,13 +519,13 @@ struct TestKernel
515519 test_this_group<cuda::grid_level, cudax::this_grid>(config);
516520
517521 // todo: allow this once hierarchy is queryable for missing levels
518- test_group_by_group<cuda::thread_level, cudax::thread_group, cuda::warp_level>(config);
519- test_group_by_group<cuda::thread_level, cudax::thread_group, cuda::block_level>(config);
522+ test_group_by_group<cuda::thread_level, cuda::warp_level>(config);
523+ test_group_by_group<cuda::thread_level, cuda::block_level>(config);
520524 if constexpr (Hierarchy::has_level (cuda::cluster))
521525 {
522- test_group_by_group<cuda::thread_level, cudax::thread_group, cuda::cluster_level>(config);
526+ test_group_by_group<cuda::thread_level, cuda::cluster_level>(config);
523527 }
524- test_group_by_group<cuda::thread_level, cudax::thread_group, cuda::grid_level>(config);
528+ test_group_by_group<cuda::thread_level, cuda::grid_level>(config);
525529 }
526530};
527531
0 commit comments