Skip to content

Commit 88d6c6b

Browse files
authored
Revert PR #580 streaming workaround (CCCL #1422 resolved) (#810)
This PR reverts the streaming workaround introduced in #580, because CUB now supports large size types.
1 parent f3c5102 commit 88d6c6b

1 file changed

Lines changed: 34 additions & 46 deletions

File tree

include/cuco/detail/open_addressing/open_addressing_impl.cuh

Lines changed: 34 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -830,60 +830,48 @@ class open_addressing_impl : private open_addressing_compatible<Key, Value, Prob
830830
using temp_allocator_type =
831831
typename std::allocator_traits<allocator_type>::template rebind_alloc<char>;
832832

833-
cuco::detail::index_type constexpr stride = std::numeric_limits<std::int32_t>::max();
834-
835-
cuco::detail::index_type h_num_out{0};
836833
auto temp_allocator = temp_allocator_type{this->allocator()};
837834
auto d_num_out =
838835
reinterpret_cast<size_type*>(temp_allocator.allocate(sizeof(size_type), stream));
839836

840-
// TODO: PR #580 to be reverted once https://github.com/NVIDIA/cccl/issues/1422 is resolved
841-
for (cuco::detail::index_type offset = 0;
842-
offset < static_cast<cuco::detail::index_type>(this->capacity());
843-
offset += stride) {
844-
auto const num_items =
845-
std::min(static_cast<cuco::detail::index_type>(this->capacity()) - offset, stride);
846-
auto const begin = cuda::make_transform_iterator(
847-
cuda::counting_iterator{static_cast<size_type>(offset)},
848-
detail::open_addressing_ns::get_slot<has_payload, storage_ref_type>(this->storage_ref()));
849-
auto const is_filled = detail::open_addressing_ns::slot_is_filled<has_payload, key_type>{
850-
this->empty_key_sentinel(), this->erased_key_sentinel()};
851-
852-
std::size_t temp_storage_bytes = 0;
853-
854-
CUCO_CUDA_TRY(cub::DeviceSelect::If(nullptr,
855-
temp_storage_bytes,
856-
begin,
857-
output_begin + h_num_out,
858-
d_num_out,
859-
static_cast<std::int32_t>(num_items),
860-
is_filled,
861-
stream.get()));
862-
863-
// Allocate temporary storage
864-
auto d_temp_storage = temp_allocator.allocate(temp_storage_bytes, stream);
865-
866-
CUCO_CUDA_TRY(cub::DeviceSelect::If(d_temp_storage,
867-
temp_storage_bytes,
868-
begin,
869-
output_begin + h_num_out,
870-
d_num_out,
871-
static_cast<std::int32_t>(num_items),
872-
is_filled,
873-
stream.get()));
874-
875-
size_type temp_count;
876-
CUCO_CUDA_TRY(cuco::detail::memcpy_async(
877-
&temp_count, d_num_out, sizeof(size_type), cudaMemcpyDeviceToHost, stream));
837+
auto const begin = cuda::make_transform_iterator(
838+
cuda::counting_iterator{size_type{0}},
839+
detail::open_addressing_ns::get_slot<has_payload, storage_ref_type>(this->storage_ref()));
840+
auto const is_filled = detail::open_addressing_ns::slot_is_filled<has_payload, key_type>{
841+
this->empty_key_sentinel(), this->erased_key_sentinel()};
842+
843+
std::size_t temp_storage_bytes = 0;
844+
845+
CUCO_CUDA_TRY(cub::DeviceSelect::If(nullptr,
846+
temp_storage_bytes,
847+
begin,
848+
output_begin,
849+
d_num_out,
850+
this->capacity(),
851+
is_filled,
852+
stream.get()));
853+
854+
auto d_temp_storage = temp_allocator.allocate(temp_storage_bytes, stream);
855+
856+
CUCO_CUDA_TRY(cub::DeviceSelect::If(d_temp_storage,
857+
temp_storage_bytes,
858+
begin,
859+
output_begin,
860+
d_num_out,
861+
this->capacity(),
862+
is_filled,
863+
stream.get()));
864+
865+
size_type h_num_out;
866+
CUCO_CUDA_TRY(cuco::detail::memcpy_async(
867+
&h_num_out, d_num_out, sizeof(size_type), cudaMemcpyDeviceToHost, stream));
878868
#if CCCL_MAJOR_VERSION > 3 || (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1)
879-
stream.sync();
869+
stream.sync();
880870
#else
881-
stream.wait();
871+
stream.wait();
882872
#endif
883-
h_num_out += temp_count;
884-
temp_allocator.deallocate(d_temp_storage, temp_storage_bytes, stream);
885-
}
886873

874+
temp_allocator.deallocate(d_temp_storage, temp_storage_bytes, stream);
887875
temp_allocator.deallocate(reinterpret_cast<char*>(d_num_out), sizeof(size_type), stream);
888876

889877
return output_begin + h_num_out;

0 commit comments

Comments
 (0)