|
31 | 31 | #include <cuco/utility/traits.hpp> |
32 | 32 |
|
33 | 33 | #include <cub/device/device_for.cuh> |
| 34 | +#include <cub/device/device_reduce.cuh> |
34 | 35 | #include <cub/device/device_select.cuh> |
35 | 36 | #include <cuda/atomic> |
36 | 37 | #include <cuda/iterator> |
@@ -958,21 +959,56 @@ class open_addressing_impl : private open_addressing_compatible<Key, Value, Prob |
958 | 959 | */ |
959 | 960 | [[nodiscard]] size_type size(cuda::stream_ref stream) const |
960 | 961 | { |
961 | | - auto counter = |
962 | | - detail::counter_storage<size_type, thread_scope, allocator_type>{this->allocator(), stream}; |
963 | | - counter.reset(stream); |
| 962 | + using temp_allocator_type = |
| 963 | + typename std::allocator_traits<allocator_type>::template rebind_alloc<char>; |
| 964 | + auto temp_allocator = temp_allocator_type{this->allocator()}; |
| 965 | + |
| 966 | + auto* d_count = |
| 967 | + reinterpret_cast<size_type*>(temp_allocator.allocate(sizeof(size_type), stream)); |
964 | 968 |
|
965 | | - auto const grid_size = cuco::detail::grid_size(this->capacity()); |
966 | 969 | auto const is_filled = detail::open_addressing_ns::slot_is_filled<has_payload, key_type>{ |
967 | 970 | this->empty_key_sentinel(), this->erased_key_sentinel()}; |
| 971 | + auto const slot_begin = cuda::make_transform_iterator( |
| 972 | + cuda::counting_iterator{size_type{0}}, |
| 973 | + detail::open_addressing_ns::get_slot<has_payload, storage_ref_type>(this->storage_ref())); |
| 974 | + |
| 975 | + std::size_t temp_storage_bytes = 0; |
| 976 | + |
| 977 | + CUCO_CUDA_TRY(cub::DeviceReduce::TransformReduce(nullptr, |
| 978 | + temp_storage_bytes, |
| 979 | + slot_begin, |
| 980 | + d_count, |
| 981 | + this->capacity(), |
| 982 | + cuda::std::plus<size_type>{}, |
| 983 | + is_filled, |
| 984 | + size_type{0}, |
| 985 | + stream.get())); |
| 986 | + |
| 987 | + auto d_temp_storage = temp_allocator.allocate(temp_storage_bytes, stream); |
| 988 | + |
| 989 | + CUCO_CUDA_TRY(cub::DeviceReduce::TransformReduce(d_temp_storage, |
| 990 | + temp_storage_bytes, |
| 991 | + slot_begin, |
| 992 | + d_count, |
| 993 | + this->capacity(), |
| 994 | + cuda::std::plus<size_type>{}, |
| 995 | + is_filled, |
| 996 | + size_type{0}, |
| 997 | + stream.get())); |
| 998 | + |
| 999 | + size_type h_count; |
| 1000 | + CUCO_CUDA_TRY(cuco::detail::memcpy_async( |
| 1001 | + &h_count, d_count, sizeof(size_type), cudaMemcpyDeviceToHost, stream)); |
| 1002 | +#if CCCL_MAJOR_VERSION > 3 || (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) |
| 1003 | + stream.sync(); |
| 1004 | +#else |
| 1005 | + stream.wait(); |
| 1006 | +#endif |
968 | 1007 |
|
969 | | - // TODO: custom kernel to be replaced by cub::DeviceReduce::Sum when cub version is bumped to |
970 | | - // v2.1.0 |
971 | | - detail::open_addressing_ns::size<cuco::detail::default_block_size()> |
972 | | - <<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>( |
973 | | - storage_.ref(), is_filled, counter.data()); |
| 1008 | + temp_allocator.deallocate(d_temp_storage, temp_storage_bytes, stream); |
| 1009 | + temp_allocator.deallocate(reinterpret_cast<char*>(d_count), sizeof(size_type), stream); |
974 | 1010 |
|
975 | | - return counter.load_to_host(stream); |
| 1011 | + return h_count; |
976 | 1012 | } |
977 | 1013 |
|
978 | 1014 | /** |
|
0 commit comments