Skip to content

Commit 2b21bec

Browse files
authored
Use the new tuning API internally for detail::find::dispatch (#9240)
* add new dispatch tuning API to cub::find * reviews
1 parent 89c81d7 commit 2b21bec

3 files changed

Lines changed: 83 additions & 23 deletions

File tree

cub/benchmarks/bench/find_if/base.cu

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,31 @@
99

1010
#include <nvbench_helper.cuh>
1111

12+
// %RANGE% TUNE_LOAD ld 0:2:1
13+
// %RANGE% TUNE_ITEMS_PER_THREAD ipt 7:24:1
14+
// %RANGE% TUNE_THREADS_PER_BLOCK_POW2 tpb 6:10:1
15+
16+
#if !TUNE_BASE
17+
# if TUNE_LOAD == 0
18+
# define TUNE_LOAD_MODIFIER cub::LOAD_DEFAULT
19+
# elif TUNE_LOAD == 1
20+
# define TUNE_LOAD_MODIFIER cub::LOAD_LDG
21+
# else // TUNE_LOAD == 2
22+
# define TUNE_LOAD_MODIFIER cub::LOAD_CA
23+
# endif // TUNE_LOAD
24+
25+
template <typename T>
26+
struct bench_policy_selector
27+
{
28+
[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto operator()(::cuda::compute_capability) const
29+
-> cub::detail::find::find_policy
30+
{
31+
return cub::detail::find::find_policy{
32+
(1 << TUNE_THREADS_PER_BLOCK_POW2), cub::Nominal4BItemsToItems<T>(TUNE_ITEMS_PER_THREAD), 4, TUNE_LOAD_MODIFIER};
33+
}
34+
};
35+
#endif // !TUNE_BASE
36+
1237
template <typename T, typename OffsetT>
1338
void find_if(nvbench::state& state, nvbench::type_list<T, OffsetT>)
1439
{
@@ -23,33 +48,27 @@ void find_if(nvbench::state& state, nvbench::type_list<T, OffsetT>)
2348
thrust::fill(dinput.begin() + mismatch_point, dinput.end(), val);
2449
thrust::device_vector<OffsetT> d_result(1, thrust::no_init);
2550

26-
void* d_temp_storage = nullptr;
27-
size_t temp_storage_bytes{};
28-
2951
state.add_global_memory_reads<T>(mismatch_point);
3052
state.add_global_memory_writes<OffsetT>(1);
3153

32-
cub::DeviceFind::FindIf(
33-
d_temp_storage,
34-
temp_storage_bytes,
35-
thrust::raw_pointer_cast(dinput.data()),
36-
thrust::raw_pointer_cast(d_result.data()),
37-
cuda::equal_to_value<T>(val),
38-
static_cast<OffsetT>(dinput.size()),
39-
nullptr);
40-
41-
thrust::device_vector<uint8_t> temp_storage(temp_storage_bytes, thrust::no_init);
42-
d_temp_storage = thrust::raw_pointer_cast(temp_storage.data());
43-
44-
state.exec(nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) {
45-
cub::DeviceFind::FindIf(
46-
d_temp_storage,
47-
temp_storage_bytes,
54+
caching_allocator_t alloc;
55+
state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) {
56+
auto env = cub_bench_env(
57+
alloc,
58+
launch
59+
#if !TUNE_BASE
60+
,
61+
cuda::execution::tune(bench_policy_selector<T>{})
62+
#endif // !TUNE_BASE
63+
);
64+
_CCCL_TRY_CUDA_API(
65+
cub::DeviceFind::FindIf,
66+
"FindIf failed",
4867
thrust::raw_pointer_cast(dinput.data()),
4968
thrust::raw_pointer_cast(d_result.data()),
5069
cuda::equal_to_value<T>(val),
5170
static_cast<OffsetT>(dinput.size()),
52-
launch.get_stream());
71+
env);
5372
});
5473
}
5574

cub/cub/device/device_find.cuh

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,13 @@ struct DeviceFind
428428

429429
using OffsetT = detail::choose_offset_t<NumItemsT>;
430430

431-
return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) {
432-
return detail::find::dispatch(storage, bytes, d_in, d_out, static_cast<OffsetT>(num_items), scan_op, stream);
433-
});
431+
using default_policy_selector = detail::find::policy_selector_from_types<detail::it_value_t<InputIteratorT>>;
432+
433+
return detail::dispatch_with_env_and_tuning<default_policy_selector>(
434+
env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) {
435+
return detail::find::dispatch(
436+
storage, bytes, d_in, d_out, static_cast<OffsetT>(num_items), scan_op, stream, policy_selector);
437+
});
434438
}
435439

436440
//! @rst

cub/test/catch2_test_device_find_env.cu

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ struct stream_registry_factory_t;
1111

1212
#include <thrust/device_vector.h>
1313

14+
#include <cuda/functional>
1415
#include <cuda/iterator>
1516

1617
#include "catch2_test_env_launch_helper.h"
@@ -34,6 +35,21 @@ struct is_greater_than_t
3435
}
3536
};
3637

38+
// A policy selector that forces a specific block size, so a test can verify the tuning was applied.
39+
template <int ThreadsPerBlock>
40+
struct find_tuning
41+
{
42+
_CCCL_HOST_DEVICE_API constexpr auto operator()(cuda::compute_capability) const -> cub::detail::find::find_policy
43+
{
44+
return {ThreadsPerBlock, 4, 4, cub::LOAD_LDG};
45+
}
46+
};
47+
48+
using block_size_extracting_predicate_t = block_size_extracting_op<::cuda::always_false>;
49+
50+
using block_sizes =
51+
c2h::type_list<cuda::std::integral_constant<unsigned int, 64>, cuda::std::integral_constant<unsigned int, 128>>;
52+
3753
#if TEST_LAUNCH == 0
3854

3955
TEST_CASE("Device FindIf works with default environment", "[find][device]")
@@ -186,3 +202,24 @@ C2H_TEST("Device UpperBound uses environment", "[find][device]")
186202
c2h::device_vector<int> expected = {1, 2, 3, 4};
187203
REQUIRE(d_output == expected);
188204
}
205+
206+
#if TEST_LAUNCH != 1
207+
C2H_TEST("Device FindIf can be tuned", "[find][device]", block_sizes)
208+
{
209+
constexpr unsigned int target_block_size = c2h::get<0, TestType>::value;
210+
211+
constexpr int num_items = 1024;
212+
auto d_in = c2h::device_vector<int>(num_items, 0);
213+
auto d_out = c2h::device_vector<int>(1, thrust::no_init);
214+
auto d_block_size = c2h::device_vector<unsigned int>(1, 0);
215+
216+
block_size_extracting_predicate_t predicate{thrust::raw_pointer_cast(d_block_size.data())};
217+
218+
auto env = cuda::execution::tune(find_tuning<static_cast<int>(target_block_size)>{});
219+
220+
device_find_if(d_in.begin(), d_out.begin(), predicate, num_items, env);
221+
222+
REQUIRE(d_out[0] == num_items); // predicate never matches
223+
REQUIRE(d_block_size[0] == target_block_size);
224+
}
225+
#endif // TEST_LAUNCH != 1

0 commit comments

Comments
 (0)