Skip to content

Commit 4eb7bcc

Browse files
Refactor
1 parent 17265ef commit 4eb7bcc

1 file changed

Lines changed: 39 additions & 50 deletions

File tree

cub/test/catch2_test_device_merge_sort_env.cu

Lines changed: 39 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -41,35 +41,6 @@ struct merge_sort_tuning
4141
}
4242
};
4343

44-
struct unrelated_policy
45-
{};
46-
47-
struct unrelated_tuning
48-
{
49-
// should never be called
50-
auto operator()(cuda::arch_id /*arch*/) const -> unrelated_policy
51-
{
52-
throw 1337;
53-
}
54-
};
55-
56-
struct block_size_compare_t
57-
{
58-
unsigned int* block_size;
59-
60-
__device__ bool operator()(int lhs, int rhs) const
61-
{
62-
if (threadIdx.x == 0)
63-
{
64-
// use an atomic operation to write the block dim in case multiple blocks are launched
65-
atomicMin(block_size, blockDim.x);
66-
}
67-
return lhs < rhs;
68-
}
69-
};
70-
71-
using block_sizes = c2h::type_list<cuda::std::integral_constant<int, 64>, cuda::std::integral_constant<int, 128>>;
72-
7344
#if TEST_LAUNCH == 0
7445

7546
TEST_CASE("DeviceMergeSort::SortPairs works with default environment", "[merge_sort][device]")
@@ -475,14 +446,32 @@ TEST_CASE("DeviceMergeSort::StableSortKeysCopy uses custom stream", "[merge_sort
475446
REQUIRE(d_keys_out == expected_keys);
476447
}
477448

449+
struct block_size_compare_t
450+
{
451+
unsigned int* block_size;
452+
453+
__device__ bool operator()(int lhs, int rhs) const
454+
{
455+
if (threadIdx.x == 0)
456+
{
457+
// use an atomic operation to write the block dim in case multiple blocks are launched
458+
atomicMin(block_size, blockDim.x);
459+
}
460+
return lhs < rhs;
461+
}
462+
};
463+
464+
using block_sizes =
465+
c2h::type_list<cuda::std::integral_constant<unsigned int, 64>, cuda::std::integral_constant<unsigned int, 128>>;
466+
478467
C2H_TEST("DeviceMergeSort::SortPairs can be tuned", "[merge_sort][device]", block_sizes)
479468
{
480-
constexpr int target_block_size = c2h::get<0, TestType>::value;
469+
constexpr unsigned int target_block_size = c2h::get<0, TestType>::value;
481470
c2h::device_vector<int> d_keys{4, 1, 3, 2};
482471
c2h::device_vector<int> d_values{0, 1, 2, 3};
483-
c2h::device_vector<int> d_block_size(1);
472+
c2h::device_vector<unsigned int> d_block_size(1);
484473
auto compare_op = block_size_compare_t{thrust::raw_pointer_cast(d_block_size.data())};
485-
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{}, unrelated_tuning{});
474+
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{});
486475

487476
REQUIRE(cudaSuccess
488477
== cub::DeviceMergeSort::SortPairs(
@@ -492,14 +481,14 @@ C2H_TEST("DeviceMergeSort::SortPairs can be tuned", "[merge_sort][device]", bloc
492481

493482
C2H_TEST("DeviceMergeSort::SortPairsCopy can be tuned", "[merge_sort][device]", block_sizes)
494483
{
495-
constexpr int target_block_size = c2h::get<0, TestType>::value;
484+
constexpr unsigned int target_block_size = c2h::get<0, TestType>::value;
496485
c2h::device_vector<int> d_keys_in{4, 1, 3, 2};
497486
c2h::device_vector<int> d_values_in{0, 1, 2, 3};
498487
c2h::device_vector<int> d_keys_out(4);
499488
c2h::device_vector<int> d_values_out(4);
500-
c2h::device_vector<int> d_block_size(1);
489+
c2h::device_vector<unsigned int> d_block_size(1);
501490
auto compare_op = block_size_compare_t{thrust::raw_pointer_cast(d_block_size.data())};
502-
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{}, unrelated_tuning{});
491+
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{});
503492

504493
REQUIRE(
505494
cudaSuccess
@@ -516,11 +505,11 @@ C2H_TEST("DeviceMergeSort::SortPairsCopy can be tuned", "[merge_sort][device]",
516505

517506
C2H_TEST("DeviceMergeSort::SortKeys can be tuned", "[merge_sort][device]", block_sizes)
518507
{
519-
constexpr int target_block_size = c2h::get<0, TestType>::value;
508+
constexpr unsigned int target_block_size = c2h::get<0, TestType>::value;
520509
c2h::device_vector<int> d_keys{4, 1, 3, 2};
521-
c2h::device_vector<int> d_block_size(1);
510+
c2h::device_vector<unsigned int> d_block_size(1);
522511
auto compare_op = block_size_compare_t{thrust::raw_pointer_cast(d_block_size.data())};
523-
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{}, unrelated_tuning{});
512+
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{});
524513

525514
REQUIRE(cudaSuccess
526515
== cub::DeviceMergeSort::SortKeys(d_keys.data().get(), static_cast<int>(d_keys.size()), compare_op, env));
@@ -529,12 +518,12 @@ C2H_TEST("DeviceMergeSort::SortKeys can be tuned", "[merge_sort][device]", block
529518

530519
C2H_TEST("DeviceMergeSort::SortKeysCopy can be tuned", "[merge_sort][device]", block_sizes)
531520
{
532-
constexpr int target_block_size = c2h::get<0, TestType>::value;
521+
constexpr unsigned int target_block_size = c2h::get<0, TestType>::value;
533522
c2h::device_vector<int> d_keys_in{4, 1, 3, 2};
534523
c2h::device_vector<int> d_keys_out(4);
535-
c2h::device_vector<int> d_block_size(1);
524+
c2h::device_vector<unsigned int> d_block_size(1);
536525
auto compare_op = block_size_compare_t{thrust::raw_pointer_cast(d_block_size.data())};
537-
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{}, unrelated_tuning{});
526+
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{});
538527

539528
REQUIRE(cudaSuccess
540529
== cub::DeviceMergeSort::SortKeysCopy(
@@ -544,12 +533,12 @@ C2H_TEST("DeviceMergeSort::SortKeysCopy can be tuned", "[merge_sort][device]", b
544533

545534
C2H_TEST("DeviceMergeSort::StableSortPairs can be tuned", "[merge_sort][device]", block_sizes)
546535
{
547-
constexpr int target_block_size = c2h::get<0, TestType>::value;
536+
constexpr unsigned int target_block_size = c2h::get<0, TestType>::value;
548537
c2h::device_vector<int> d_keys{4, 1, 3, 2};
549538
c2h::device_vector<int> d_values{0, 1, 2, 3};
550-
c2h::device_vector<int> d_block_size(1);
539+
c2h::device_vector<unsigned int> d_block_size(1);
551540
auto compare_op = block_size_compare_t{thrust::raw_pointer_cast(d_block_size.data())};
552-
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{}, unrelated_tuning{});
541+
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{});
553542

554543
REQUIRE(cudaSuccess
555544
== cub::DeviceMergeSort::StableSortPairs(
@@ -559,11 +548,11 @@ C2H_TEST("DeviceMergeSort::StableSortPairs can be tuned", "[merge_sort][device]"
559548

560549
C2H_TEST("DeviceMergeSort::StableSortKeys can be tuned", "[merge_sort][device]", block_sizes)
561550
{
562-
constexpr int target_block_size = c2h::get<0, TestType>::value;
551+
constexpr unsigned int target_block_size = c2h::get<0, TestType>::value;
563552
c2h::device_vector<int> d_keys{4, 1, 3, 2};
564-
c2h::device_vector<int> d_block_size(1);
553+
c2h::device_vector<unsigned int> d_block_size(1);
565554
auto compare_op = block_size_compare_t{thrust::raw_pointer_cast(d_block_size.data())};
566-
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{}, unrelated_tuning{});
555+
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{});
567556

568557
REQUIRE(
569558
cudaSuccess
@@ -573,12 +562,12 @@ C2H_TEST("DeviceMergeSort::StableSortKeys can be tuned", "[merge_sort][device]",
573562

574563
C2H_TEST("DeviceMergeSort::StableSortKeysCopy can be tuned", "[merge_sort][device]", block_sizes)
575564
{
576-
constexpr int target_block_size = c2h::get<0, TestType>::value;
565+
constexpr unsigned int target_block_size = c2h::get<0, TestType>::value;
577566
c2h::device_vector<int> d_keys_in{4, 1, 3, 2};
578567
c2h::device_vector<int> d_keys_out(4);
579-
c2h::device_vector<int> d_block_size(1);
568+
c2h::device_vector<unsigned int> d_block_size(1);
580569
auto compare_op = block_size_compare_t{thrust::raw_pointer_cast(d_block_size.data())};
581-
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{}, unrelated_tuning{});
570+
auto env = cuda::execution::__tune(merge_sort_tuning<target_block_size>{});
582571

583572
REQUIRE(cudaSuccess
584573
== cub::DeviceMergeSort::StableSortKeysCopy(

0 commit comments

Comments
 (0)