diff --git a/cub/benchmarks/bench/collectives/coop_collectives.cu b/cub/benchmarks/bench/collectives/coop_collectives.cu new file mode 100644 index 00000000000..d340a22271e --- /dev/null +++ b/cub/benchmarks/bench/collectives/coop_collectives.cu @@ -0,0 +1,443 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +using value_types = nvbench::type_list; + +template +struct warp_reduce_owner_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using warp_reduce_t = cub::WarpReduce; + __shared__ typename warp_reduce_t::TempStorage temp_storage[32]; + + const int warp_id = static_cast(threadIdx.x) / cub::detail::warp_threads; + return warp_reduce_t{temp_storage[warp_id]}.Sum(thread_data); + } +}; + +template +struct warp_reduce_manual_broadcast_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using warp_reduce_t = cub::WarpReduce; + __shared__ typename warp_reduce_t::TempStorage temp_storage[32]; + + const int warp_id = static_cast(threadIdx.x) / cub::detail::warp_threads; + T aggregate = warp_reduce_t{temp_storage[warp_id]}.Sum(thread_data); + const auto logical_warp_id = cub::detail::logical_warp_id(); + const auto member_mask = cub::WarpMask(logical_warp_id); + return cub::ShuffleIndex(aggregate, 0, member_mask); + } +}; + +template +struct warp_reduce_coop_broadcast_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using warp_reduce_t = cub::WarpReduceBroadcast; + __shared__ typename warp_reduce_t::TempStorage temp_storage[32]; + + const int warp_id = static_cast(threadIdx.x) / cub::detail::warp_threads; + return warp_reduce_t{temp_storage[warp_id]}.Sum(thread_data); + } +}; + +template +struct warp_allreduce4_manual_t +{ + 
_CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + constexpr int logical_warp_threads = 4; + const int lane_id = cub::detail::logical_lane_id(); + const int logical_warp_id = cub::detail::logical_warp_id(); + const auto member_mask = cub::WarpMask(logical_warp_id); + + _CCCL_PRAGMA_UNROLL_FULL() + for (int offset = 1; offset < logical_warp_threads; offset <<= 1) + { + const T peer = cub::ShuffleIndex(thread_data, lane_id ^ offset, member_mask); + thread_data += peer; + } + return thread_data; + } +}; + +template +struct warp_reduce_coop_broadcast4_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using warp_reduce_t = cub::WarpReduceBroadcast; + __shared__ typename warp_reduce_t::TempStorage temp_storage[32]; + + const int warp_id = static_cast(threadIdx.x) / cub::detail::warp_threads; + return warp_reduce_t{temp_storage[warp_id]}.Sum(thread_data); + } +}; + +template +struct warp_reduce_serial_batched_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using warp_reduce_t = cub::WarpReduce; + __shared__ typename warp_reduce_t::TempStorage temp_storage[Batches][32]; + + const int warp_id = static_cast(threadIdx.x) / cub::detail::warp_threads; + const int lane_id = static_cast(threadIdx.x) % cub::detail::warp_threads; + + T result = thread_data; + _CCCL_PRAGMA_UNROLL_FULL() + for (int batch = 0; batch < Batches; ++batch) + { + const T aggregate = warp_reduce_t{temp_storage[batch][warp_id]}.Sum(thread_data + static_cast(batch)); + if (lane_id == batch) + { + result = aggregate; + } + } + return result; + } +}; + +template +struct warp_reduce_batched_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using warp_reduce_t = cub::WarpReduceBatched; + __shared__ typename warp_reduce_t::TempStorage temp_storage[32]; + + const int warp_id = static_cast(threadIdx.x) / cub::detail::warp_threads; + const int lane_id = static_cast(threadIdx.x) % cub::detail::warp_threads; + + 
::cuda::std::array inputs{}; + _CCCL_PRAGMA_UNROLL_FULL() + for (int batch = 0; batch < Batches; ++batch) + { + inputs[batch] = thread_data + static_cast(batch); + } + + T result = warp_reduce_t{temp_storage[warp_id]}.Sum(inputs); + return lane_id < Batches ? result : thread_data; + } +}; + +template +struct warp_reduce_serial_batched_broadcast4_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using warp_reduce_t = cub::WarpReduceBroadcast; + __shared__ typename warp_reduce_t::TempStorage temp_storage[Batches][32]; + + const int warp_id = static_cast(threadIdx.x) / cub::detail::warp_threads; + + T result{}; + _CCCL_PRAGMA_UNROLL_FULL() + for (int batch = 0; batch < Batches; ++batch) + { + result += warp_reduce_t{temp_storage[batch][warp_id]}.Sum(thread_data + static_cast(batch)); + } + return result; + } +}; + +template +struct warp_reduce_batched_broadcast4_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using warp_reduce_t = cub::WarpReduceBatchedBroadcast; + __shared__ typename warp_reduce_t::TempStorage temp_storage[32]; + + const int warp_id = static_cast(threadIdx.x) / cub::detail::warp_threads; + + ::cuda::std::array inputs{}; + _CCCL_PRAGMA_UNROLL_FULL() + for (int batch = 0; batch < Batches; ++batch) + { + inputs[batch] = thread_data + static_cast(batch); + } + + const auto outputs = warp_reduce_t{temp_storage[warp_id]}.Sum(inputs); + return cub::ThreadReduce(outputs, ::cuda::std::plus<>{}); + } +}; + +template +struct block_reduce_manual_broadcast_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using block_reduce_t = cub::BlockReduce; + struct temp_storage_t + { + typename block_reduce_t::TempStorage reduce; + cub::Uninitialized aggregate; + }; + + __shared__ temp_storage_t temp_storage; + + T aggregate = block_reduce_t{temp_storage.reduce}.Sum(thread_data); + if (threadIdx.x == 0) + { + temp_storage.aggregate.Alias() = aggregate; + } + __syncthreads(); + + T result = 
temp_storage.aggregate.Alias(); + __syncthreads(); + return result; + } +}; + +template +struct block_reduce_coop_broadcast_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using block_reduce_t = cub::BlockReduceBroadcast; + __shared__ typename block_reduce_t::TempStorage temp_storage; + + return block_reduce_t{temp_storage}.Sum(thread_data); + } +}; + +template +struct block_row_reduce_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using row_reduce_t = cub::BlockRowReduce; + __shared__ typename row_reduce_t::TempStorage temp_storage; + + return row_reduce_t{temp_storage}.Sum(thread_data); + } +}; + +template +struct block_row_reduce_warp_broadcast_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using row_reduce_t = cub::BlockRowReduceWarpBroadcast; + __shared__ typename row_reduce_t::TempStorage temp_storage; + + return row_reduce_t{temp_storage}.Sum(thread_data); + } +}; + +template +struct warp_block_scan_t +{ + _CCCL_DEVICE _CCCL_FORCEINLINE T operator()(T thread_data) const + { + using warp_scan_t = cub::WarpScan; + using block_scan_t = cub::BlockScan; + + __shared__ typename warp_scan_t::TempStorage warp_storage[8]; + __shared__ typename block_scan_t::TempStorage block_storage; + + const int warp_id = static_cast(threadIdx.x) / cub::detail::warp_threads; + T broadcast = warp_scan_t{warp_storage[warp_id]}.Broadcast(thread_data, 0); + + T prefix{}; + block_scan_t{block_storage}.ExclusiveSum(static_cast(1), prefix); + return broadcast + prefix; + } +}; + +template class ActionT, typename T> +void bench_collective(nvbench::state& state, nvbench::type_list) +{ + using action_t = ActionT; + const auto& kernel = benchmark_kernel; + const int num_sms = state.get_device().value().get_number_of_sms(); + int max_blocks_per_sm = 0; + const std::size_t smem = 0; + const int block_threads = BlockSize; + const int unroll_factor = UnrollFactor; + + NVBENCH_CUDA_CALL_NOEXCEPT( + 
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, kernel, block_threads, smem)); + + const int grid_size = max_blocks_per_sm * num_sms; + state.add_element_count(grid_size * block_threads * unroll_factor, "Thread ops"); + state.add_element_count(grid_size * unroll_factor, "CTA ops"); + state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch&) { + kernel<<>>(action_t{}); + }); +} + +template +using warp_reduce_serial_batched_4_t = warp_reduce_serial_batched_t; + +template +using warp_reduce_batched_4_t = warp_reduce_batched_t; + +template +using warp_reduce_serial_batched_broadcast4_4_t = warp_reduce_serial_batched_broadcast4_t; + +template +using warp_reduce_batched_broadcast4_4_t = warp_reduce_batched_broadcast4_t; + +template +void warp_reduce_owner(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 128, warp_reduce_owner_t>(state, type); +} + +template +void warp_reduce_manual_broadcast(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 128, warp_reduce_manual_broadcast_t>(state, type); +} + +template +void warp_reduce_coop_broadcast(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 128, warp_reduce_coop_broadcast_t>(state, type); +} + +template +void warp_allreduce4_manual(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 128, warp_allreduce4_manual_t>(state, type); +} + +template +void warp_reduce_coop_broadcast4(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 128, warp_reduce_coop_broadcast4_t>(state, type); +} + +template +void warp_reduce_serial_batched_4(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 128, warp_reduce_serial_batched_4_t>(state, type); +} + +template +void warp_reduce_batched_4(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 128, warp_reduce_batched_4_t>(state, type); +} + +template +void 
warp_reduce_serial_batched_broadcast4_4(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 128, warp_reduce_serial_batched_broadcast4_4_t>(state, type); +} + +template +void warp_reduce_batched_broadcast4_4(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 128, warp_reduce_batched_broadcast4_4_t>(state, type); +} + +template +void block_reduce_manual_broadcast(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 64, block_reduce_manual_broadcast_t>(state, type); +} + +template +void block_reduce_coop_broadcast(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 64, block_reduce_coop_broadcast_t>(state, type); +} + +template +void block_row_reduce(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 64, block_row_reduce_t>(state, type); +} + +template +void block_row_reduce_warp_broadcast(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 64, block_row_reduce_warp_broadcast_t>(state, type); +} + +template +void warp_block_scan(nvbench::state& state, nvbench::type_list type) +{ + bench_collective<256, 64, warp_block_scan_t>(state, type); +} + +NVBENCH_BENCH_TYPES(warp_reduce_owner, NVBENCH_TYPE_AXES(value_types)) + .set_name("warp_reduce_owner") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(warp_reduce_manual_broadcast, NVBENCH_TYPE_AXES(value_types)) + .set_name("warp_reduce_manual_broadcast") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(warp_reduce_coop_broadcast, NVBENCH_TYPE_AXES(value_types)) + .set_name("warp_reduce_coop_broadcast") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(warp_allreduce4_manual, NVBENCH_TYPE_AXES(value_types)) + .set_name("warp_allreduce4_manual") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(warp_reduce_coop_broadcast4, NVBENCH_TYPE_AXES(value_types)) + .set_name("warp_reduce_coop_broadcast4") + .set_type_axes_names({"T{ct}"}); + 
+NVBENCH_BENCH_TYPES(warp_reduce_serial_batched_4, NVBENCH_TYPE_AXES(value_types)) + .set_name("warp_reduce_serial_batched_4") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(warp_reduce_batched_4, NVBENCH_TYPE_AXES(value_types)) + .set_name("warp_reduce_batched_4") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(warp_reduce_serial_batched_broadcast4_4, NVBENCH_TYPE_AXES(value_types)) + .set_name("warp_reduce_serial_batched_broadcast4_4") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(warp_reduce_batched_broadcast4_4, NVBENCH_TYPE_AXES(value_types)) + .set_name("warp_reduce_batched_broadcast4_4") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(block_reduce_manual_broadcast, NVBENCH_TYPE_AXES(value_types)) + .set_name("block_reduce_manual_broadcast") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(block_reduce_coop_broadcast, NVBENCH_TYPE_AXES(value_types)) + .set_name("block_reduce_coop_broadcast") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(block_row_reduce, NVBENCH_TYPE_AXES(value_types)) + .set_name("block_row_reduce") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(block_row_reduce_warp_broadcast, NVBENCH_TYPE_AXES(value_types)) + .set_name("block_row_reduce_warp_broadcast") + .set_type_axes_names({"T{ct}"}); + +NVBENCH_BENCH_TYPES(warp_block_scan, NVBENCH_TYPE_AXES(value_types)) + .set_name("warp_block_scan") + .set_type_axes_names({"T{ct}"}); diff --git a/cub/cub/block/block_reduce_broadcast.cuh b/cub/cub/block/block_reduce_broadcast.cuh new file mode 100644 index 00000000000..52f8c5a47b7 --- /dev/null +++ b/cub/cub/block/block_reduce_broadcast.cuh @@ -0,0 +1,109 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! @file +//! @rst +//! The ``cub::BlockReduceBroadcast`` class provides :ref:`collective ` methods for +//! 
computing block-wide reductions whose aggregate is returned to every thread in the block. +//! @endrst + +#pragma once + +#include + +#include +#include +#include + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +CUB_NAMESPACE_BEGIN + +//! @rst +//! Block-wide reduction adapter that broadcasts the aggregate to every thread in the block. +//! This keeps the usual CUB ``BlockReduce`` algorithm selection and stores the owner-lane result +//! in user-provided temporary storage before broadcasting it through shared memory. +//! @endrst +template +class BlockReduceBroadcast +{ + using BlockReduceT = cub::BlockReduce; + + struct _TempStorage + { + typename BlockReduceT::TempStorage reduce; + cub::Uninitialized aggregate; + }; + + _TempStorage& temp_storage; + unsigned int linear_tid; + + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T broadcast(T aggregate) + { + if (linear_tid == 0) + { + temp_storage.aggregate.Alias() = aggregate; + } + __syncthreads(); + + T result = temp_storage.aggregate.Alias(); + __syncthreads(); + return result; + } + +public: + /// @smemstorage{BlockReduceBroadcast} + struct TempStorage : cub::Uninitialized<_TempStorage> + {}; + + _CCCL_DEVICE_API _CCCL_FORCEINLINE explicit BlockReduceBroadcast(TempStorage& temp_storage) + : temp_storage(temp_storage.Alias()) + , linear_tid(RowMajorTid(BlockDimX, BlockDimY, BlockDimZ)) + {} + + template + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Reduce(T input, ReductionOp reduction_op) + { + return broadcast(BlockReduceT(temp_storage.reduce).Reduce(input, reduction_op)); + } + + template + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Reduce(T (&inputs)[ITEMS_PER_THREAD], ReductionOp reduction_op) + { + return broadcast(BlockReduceT(temp_storage.reduce).Reduce(inputs, reduction_op)); + 
} + + template + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Reduce(T input, ReductionOp reduction_op, int num_valid) + { + return broadcast(BlockReduceT(temp_storage.reduce).Reduce(input, reduction_op, num_valid)); + } + + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Sum(T input) + { + return broadcast(BlockReduceT(temp_storage.reduce).Sum(input)); + } + + template + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Sum(T (&inputs)[ITEMS_PER_THREAD]) + { + return broadcast(BlockReduceT(temp_storage.reduce).Sum(inputs)); + } + + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Sum(T input, int num_valid) + { + return broadcast(BlockReduceT(temp_storage.reduce).Sum(input, num_valid)); + } +}; + +CUB_NAMESPACE_END diff --git a/cub/cub/block/block_row_reduce.cuh b/cub/cub/block/block_row_reduce.cuh new file mode 100644 index 00000000000..b36340cbdd6 --- /dev/null +++ b/cub/cub/block/block_row_reduce.cuh @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! @file +//! @rst +//! The ``cub::BlockRowReduce`` and ``cub::BlockRowReduceWarpBroadcast`` classes provide +//! :ref:`collective ` methods for row-shaped block reductions. +//! @endrst + +#pragma once + +#include + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include +#include +#include + +#include + +CUB_NAMESPACE_BEGIN + +//! @rst +//! Row-shaped block reduction for fixed layouts where each row spans one or more full warps. +//! The row aggregate is returned to every thread in the corresponding row. This matches common +//! norm kernels where a CTA owns one or more rows and every lane needs the row statistic. +//! 
@endrst +template +class BlockRowReduce +{ + static_assert(RowsPerBlock > 0, "RowsPerBlock must be greater than zero"); + static_assert(WarpsPerRow > 0, "WarpsPerRow must be greater than zero"); + static_assert(WarpsPerRow <= detail::warp_threads, "WarpsPerRow must fit in one final warp reduction"); + + static constexpr int WARP_THREADS = detail::warp_threads; + static constexpr int BLOCK_THREADS = RowsPerBlock * WarpsPerRow * WARP_THREADS; + static constexpr int WARPS = RowsPerBlock * WarpsPerRow; + + using WarpReduceT = cub::WarpReduce; + + struct _TempStorage + { + typename WarpReduceT::TempStorage warp_reduce[WARPS]; + typename WarpReduceT::TempStorage final_reduce[RowsPerBlock]; + cub::Uninitialized partials[RowsPerBlock][WarpsPerRow]; + cub::Uninitialized totals[RowsPerBlock]; + }; + + _TempStorage& temp_storage; + int linear_tid; + +public: + /// @smemstorage{BlockRowReduce} + struct TempStorage : cub::Uninitialized<_TempStorage> + {}; + + _CCCL_DEVICE_API _CCCL_FORCEINLINE explicit BlockRowReduce(TempStorage& temp_storage) + : temp_storage(temp_storage.Alias()) + , linear_tid(static_cast(threadIdx.x)) + {} + + template + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Reduce(T input, ReductionOp reduction_op) + { + const int warp_id = linear_tid / WARP_THREADS; + const int lane_id = linear_tid % WARP_THREADS; + const int row_id = warp_id / WarpsPerRow; + const int row_warp_id = warp_id % WarpsPerRow; + + T warp_aggregate = WarpReduceT(temp_storage.warp_reduce[warp_id]).Reduce(input, reduction_op); + if (lane_id == 0) + { + temp_storage.partials[row_id][row_warp_id].Alias() = warp_aggregate; + } + __syncthreads(); + + if (row_warp_id == 0) + { + T partial = T{}; + if (lane_id < WarpsPerRow) + { + partial = temp_storage.partials[row_id][lane_id].Alias(); + } + + T row_aggregate = WarpReduceT(temp_storage.final_reduce[row_id]).Reduce(partial, reduction_op, WarpsPerRow); + if (lane_id == 0) + { + temp_storage.totals[row_id].Alias() = row_aggregate; + } + } + 
__syncthreads(); + + T result = temp_storage.totals[row_id].Alias(); + __syncthreads(); + return result; + } + + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Sum(T input) + { + return Reduce(input, ::cuda::std::plus<>{}); + } +}; + +//! @rst +//! Row-shaped block reduction that broadcasts the row aggregate by repeating the final +//! row-wide warp reduction in every warp of the row. +//! +//! This is intended for norm-style kernels where a CTA owns one or more rows, every +//! thread needs the row statistic, and ``WarpsPerRow`` fits in one warp. Compared to +//! ``BlockRowReduce``, this avoids storing the final row total and avoids the extra +//! CTA synchronizations needed to broadcast that stored total. +//! @endrst +template +class BlockRowReduceWarpBroadcast +{ + static_assert(RowsPerBlock > 0, "RowsPerBlock must be greater than zero"); + static_assert(WarpsPerRow > 0, "WarpsPerRow must be greater than zero"); + static_assert(WarpsPerRow <= detail::warp_threads, "WarpsPerRow must fit in one final warp reduction"); + + static constexpr int WARP_THREADS = detail::warp_threads; + + struct _TempStorage + { + cub::Uninitialized partials[RowsPerBlock][WarpsPerRow]; + }; + + _TempStorage& temp_storage; + int linear_tid; + + template + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T warp_all_reduce(T input, ReductionOp reduction_op) const + { + const int lane_id = linear_tid % WARP_THREADS; + + _CCCL_PRAGMA_UNROLL_FULL() + for (int offset = WARP_THREADS / 2; offset > 0; offset >>= 1) + { + const T peer = cub::ShuffleIndex(input, lane_id ^ offset, 0xFFFFFFFFu); + input = reduction_op(input, peer); + } + return input; + } + +public: + /// @smemstorage{BlockRowReduceWarpBroadcast} + struct TempStorage : cub::Uninitialized<_TempStorage> + {}; + + _CCCL_DEVICE_API _CCCL_FORCEINLINE explicit BlockRowReduceWarpBroadcast(TempStorage& temp_storage) + : temp_storage(temp_storage.Alias()) + , linear_tid(static_cast(threadIdx.x)) + {} + + template + [[nodiscard]] 
_CCCL_DEVICE_API _CCCL_FORCEINLINE T CommutativeReduce(T input, ReductionOp reduction_op, T identity) + { + const int warp_id = linear_tid / WARP_THREADS; + const int lane_id = linear_tid % WARP_THREADS; + const int row_id = warp_id / WarpsPerRow; + const int row_warp_id = warp_id % WarpsPerRow; + + T warp_aggregate = warp_all_reduce(input, reduction_op); + if (lane_id == 0) + { + temp_storage.partials[row_id][row_warp_id].Alias() = warp_aggregate; + } + __syncthreads(); + + T partial = identity; + if (lane_id < WarpsPerRow) + { + partial = temp_storage.partials[row_id][lane_id].Alias(); + } + return warp_all_reduce(partial, reduction_op); + } + + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Sum(T input) + { + return CommutativeReduce(input, ::cuda::std::plus<>{}, T{}); + } +}; + +CUB_NAMESPACE_END diff --git a/cub/cub/cub.cuh b/cub/cub/cub.cuh index 66d7ca6434c..65a8a837615 100644 --- a/cub/cub/cub.cuh +++ b/cub/cub/cub.cuh @@ -37,6 +37,8 @@ #include #include #include +#include +#include #include #include // #include @@ -79,6 +81,9 @@ #include #include #include +#include +#include +#include #include #include diff --git a/cub/cub/warp/warp_reduce_batched_broadcast.cuh b/cub/cub/warp/warp_reduce_batched_broadcast.cuh new file mode 100644 index 00000000000..5be066fd6bf --- /dev/null +++ b/cub/cub/warp/warp_reduce_batched_broadcast.cuh @@ -0,0 +1,118 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! @file +//! @rst +//! The ``cub::WarpReduceBatchedBroadcast`` class provides :ref:`collective ` methods for +//! performing batched warp-wide reductions whose aggregates are returned to every participating logical lane. +//! 
@endrst + +#pragma once + +#include + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include +#include + +#include +#include +#include + +CUB_NAMESPACE_BEGIN + +//! @rst +//! Batched warp-wide reduction adapter that broadcasts every batch aggregate to every participating logical lane. +//! ``Sum`` and ``CommutativeReduce`` use a shuffle all-reduce fast path and require commutative +//! reduction operators. Use ``WarpReduceBatched`` directly when only owner lanes need the batch outputs. +//! @endrst +template +class WarpReduceBatchedBroadcast +{ + static_assert(Batches > 0, "Batches must be greater than zero"); + static_assert(::cuda::is_power_of_two(LogicalWarpThreads), "LogicalWarpThreads must be a power of two"); + static_assert(LogicalWarpThreads > 0 && LogicalWarpThreads <= detail::warp_threads, + "LogicalWarpThreads must be in the range [1, 32]"); + + template + static _CCCL_DEVICE_API _CCCL_FORCEINLINE void check_constraints() + { + static_assert(detail::is_fixed_size_random_access_range_v, + "InputType must support operator[] and have a compile-time size"); + static_assert(detail::is_fixed_size_random_access_range_v, + "OutputType must support operator[] and have a compile-time size"); + static_assert(detail::static_size_v == Batches, "Input size must match Batches"); + static_assert(detail::static_size_v == Batches, "Output size must match Batches"); + static_assert(detail::has_binary_call_operator::value, + "ReductionOp must have the binary call operator: operator(T, T)"); + } + + template + _CCCL_DEVICE_API _CCCL_FORCEINLINE void commutative_all_reduce_batches(OutputType& outputs, ReductionOp reduction_op) + { + const auto lane_id = cub::detail::logical_lane_id(); + const auto logical_warp_id = cub::detail::logical_warp_id(); + const auto 
member_mask = + SyncPhysicalWarp ? 0xFFFFFFFFu : static_cast(cub::WarpMask(logical_warp_id)); + + _CCCL_PRAGMA_UNROLL_FULL() + for (int offset = LogicalWarpThreads / 2; offset > 0; offset >>= 1) + { + _CCCL_PRAGMA_UNROLL_FULL() + for (int batch = 0; batch < Batches; ++batch) + { + const T peer = cub::ShuffleIndex(outputs[batch], lane_id ^ offset, member_mask); + outputs[batch] = reduction_op(outputs[batch], peer); + } + } + } + +public: + /// @smemstorage{WarpReduceBatchedBroadcast} + using TempStorage = cub::NullType; + + _CCCL_DEVICE_API _CCCL_FORCEINLINE explicit WarpReduceBatchedBroadcast(TempStorage& /*temp_storage*/) {} + + template + _CCCL_DEVICE_API _CCCL_FORCEINLINE void + CommutativeReduce(const InputType& inputs, OutputType& outputs, ReductionOp reduction_op) + { + check_constraints(); + _CCCL_PRAGMA_UNROLL_FULL() + for (int batch = 0; batch < Batches; ++batch) + { + outputs[batch] = inputs[batch]; + } + commutative_all_reduce_batches(outputs, reduction_op); + } + + template + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE ::cuda::std::array + CommutativeReduce(const InputType& inputs, ReductionOp reduction_op) + { + ::cuda::std::array outputs{}; + CommutativeReduce(inputs, outputs, reduction_op); + return outputs; + } + + template + _CCCL_DEVICE_API _CCCL_FORCEINLINE void Sum(const InputType& inputs, OutputType& outputs) + { + CommutativeReduce(inputs, outputs, ::cuda::std::plus<>{}); + } + + template + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE ::cuda::std::array Sum(const InputType& inputs) + { + return CommutativeReduce(inputs, ::cuda::std::plus<>{}); + } +}; + +CUB_NAMESPACE_END diff --git a/cub/cub/warp/warp_reduce_broadcast.cuh b/cub/cub/warp/warp_reduce_broadcast.cuh new file mode 100644 index 00000000000..942097b411f --- /dev/null +++ b/cub/cub/warp/warp_reduce_broadcast.cuh @@ -0,0 +1,117 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: BSD-3-Clause + +//! @file +//! @rst +//! The ``cub::WarpReduceBroadcast`` class provides :ref:`collective ` methods for +//! computing warp-wide reductions whose aggregate is returned to every participating logical lane. +//! @endrst + +#pragma once + +#include + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include +#include +#include +#include +#include + +#include +#include + +CUB_NAMESPACE_BEGIN + +//! @rst +//! Warp-wide reduction adapter that broadcasts the aggregate to every participating logical lane. +//! ``Sum`` uses a shuffle all-reduce fast path. Generic ``Reduce`` preserves CUB's non-commutative +//! reduction semantics by using the owner-lane result and broadcasting it. +//! @endrst +template +class WarpReduceBroadcast +{ + static_assert(::cuda::is_power_of_two(LogicalWarpThreads), "LogicalWarpThreads must be a power of two"); + static_assert(LogicalWarpThreads > 0 && LogicalWarpThreads <= detail::warp_threads, + "LogicalWarpThreads must be in the range [1, 32]"); + + using WarpReduceT = cub::WarpReduce; + + typename WarpReduceT::TempStorage& temp_storage; + + template + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T commutative_all_reduce(T input, ReductionOp reduction_op) const + { + const auto lane_id = cub::detail::logical_lane_id(); + const auto logical_warp_id = cub::detail::logical_warp_id(); + const auto member_mask = cub::WarpMask(logical_warp_id); + + _CCCL_PRAGMA_UNROLL_FULL() + for (int offset = LogicalWarpThreads / 2; offset > 0; offset >>= 1) + { + const T peer = cub::ShuffleIndex(input, lane_id ^ offset, member_mask); + input = reduction_op(input, peer); + } + return input; + } + + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T broadcast_from_lane0(T aggregate) const + { + const auto 
logical_warp_id = cub::detail::logical_warp_id(); + const auto member_mask = cub::WarpMask(logical_warp_id); + return cub::ShuffleIndex(aggregate, 0, member_mask); + } + +public: + /// @smemstorage{WarpReduceBroadcast} + using TempStorage = typename WarpReduceT::TempStorage; + + _CCCL_DEVICE_API _CCCL_FORCEINLINE explicit WarpReduceBroadcast(TempStorage& temp_storage) + : temp_storage(temp_storage) + {} + + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Sum(T input) + { + return commutative_all_reduce(input, ::cuda::std::plus<>{}); + } + + _CCCL_TEMPLATE(typename InputType) + _CCCL_REQUIRES(detail::is_fixed_size_random_access_range_v) + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Sum(const InputType& input) + { + return commutative_all_reduce(cub::ThreadReduce(input, ::cuda::std::plus<>{}), ::cuda::std::plus<>{}); + } + + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Sum(T input, int valid_items) + { + return broadcast_from_lane0(WarpReduceT(temp_storage).Sum(input, valid_items)); + } + + template + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Reduce(T input, ReductionOp reduction_op) + { + return broadcast_from_lane0(WarpReduceT(temp_storage).Reduce(input, reduction_op)); + } + + _CCCL_TEMPLATE(typename InputType, typename ReductionOp) + _CCCL_REQUIRES(detail::is_fixed_size_random_access_range_v) + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Reduce(const InputType& input, ReductionOp reduction_op) + { + return broadcast_from_lane0(WarpReduceT(temp_storage).Reduce(input, reduction_op)); + } + + template + [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T Reduce(T input, ReductionOp reduction_op, int valid_items) + { + return broadcast_from_lane0(WarpReduceT(temp_storage).Reduce(input, reduction_op, valid_items)); + } +}; + +CUB_NAMESPACE_END diff --git a/cub/test/catch2_test_coop_collectives.cu b/cub/test/catch2_test_coop_collectives.cu new file mode 100644 index 00000000000..da9cdcd493a --- /dev/null +++ 
b/cub/test/catch2_test_coop_collectives.cu @@ -0,0 +1,359 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +/* Binary functor returning the larger of two ints; passed as the reduction operator by the max-reduction kernels in this file. */ +struct max_int_op +{ + _CCCL_DEVICE_API _CCCL_FORCEINLINE int operator()(int lhs, int rhs) const + { + return lhs > rhs ? lhs : rhs; + } +}; +/* Pins the size of this TempStorage type at exactly 16 bytes. */ +static_assert(sizeof(typename cub::BlockRowReduceWarpBroadcast::TempStorage) == 16); +/* Each thread passes its threadIdx.x to warp_reduce_t::Sum and stores the returned value in out[threadIdx.x]. temp_storage has 2 slots indexed by threadIdx.x / 32, so a launch may use at most 64 threads. */ +__global__ void WarpReduceBroadcastKernel(int* out) +{ + // example-begin warp-reduce-broadcast + using warp_reduce_t = cub::WarpReduceBroadcast; + __shared__ typename warp_reduce_t::TempStorage temp_storage[2]; + + const int warp_id = static_cast(threadIdx.x) / 32; + const int result = warp_reduce_t(temp_storage[warp_id]).Sum(static_cast(threadIdx.x)); + // example-end warp-reduce-broadcast + + out[threadIdx.x] = result; +} +/* Launches 64 threads and expects 496 (0 + ... + 31) in every entry of the first warp and 1520 (32 + ... + 63) in every entry of the second. */ +C2H_TEST("warp reduce broadcast returns aggregate to every lane", "[coop][warp][reduce]") +{ + c2h::device_vector d_out(64); + + WarpReduceBroadcastKernel<<<1, 64>>>(thrust::raw_pointer_cast(d_out.data())); + REQUIRE(cudaSuccess == cudaPeekAtLastError()); + REQUIRE(cudaSuccess == cudaDeviceSynchronize()); + + c2h::host_vector expected(64); + for (int i = 0; i < 64; ++i) + { + expected[i] = i < 32 ?
496 : 1520; + } + REQUIRE(expected == d_out); +} +/* Same structure as WarpReduceBroadcastKernel, with storage indexed by the physical warp id. NOTE(review): the expected values in the test below imply 4-thread logical warps, but the template arguments are not visible in this text - confirm. */ +__global__ void WarpReduceBroadcastLogicalKernel(int* out) +{ + // example-begin warp-reduce-broadcast-logical + using warp_reduce_t = cub::WarpReduceBroadcast; + __shared__ typename warp_reduce_t::TempStorage temp_storage[2]; + + const int physical_warp_id = static_cast(threadIdx.x) / 32; + const int result = warp_reduce_t(temp_storage[physical_warp_id]).Sum(static_cast(threadIdx.x)); + // example-end warp-reduce-broadcast-logical + + out[threadIdx.x] = result; +} +/* Expects each thread to hold the sum of the thread ids in its aligned group of four: g + (g + 1) + (g + 2) + (g + 3) = 4 * g + 6, where g is group_start. */ +C2H_TEST("warp reduce broadcast supports tiny logical warps", "[coop][warp][reduce]") +{ + c2h::device_vector d_out(64); + + WarpReduceBroadcastLogicalKernel<<<1, 64>>>(thrust::raw_pointer_cast(d_out.data())); + REQUIRE(cudaSuccess == cudaPeekAtLastError()); + REQUIRE(cudaSuccess == cudaDeviceSynchronize()); + + c2h::host_vector expected(64); + for (int i = 0; i < 64; ++i) + { + const int group_start = (i / 4) * 4; + expected[i] = group_start * 4 + 6; + } + REQUIRE(expected == d_out); +} +/* Each thread supplies three inputs {tid, tid + 1, tid + 2}; only lanes 0..2 of each warp write their result, to out[warp_id * batches + lane_id]. */ +__global__ void WarpReduceBatchedOwnerKernel(int* out) +{ + // example-begin warp-reduce-batched-owner + constexpr int batches = 3; + using warp_reduce_t = cub::WarpReduceBatched; + __shared__ typename warp_reduce_t::TempStorage temp_storage[2]; + + const int tid = static_cast(threadIdx.x); + const int warp_id = tid / 32; + const int lane_id = tid % 32; + + int inputs[batches] = {tid, tid + 1, tid + 2}; + int result = warp_reduce_t(temp_storage[warp_id]).Sum(inputs); + // example-end warp-reduce-batched-owner + + if (lane_id < batches) + { + out[warp_id * batches + lane_id] = result; + } +} +/* Expects lane b of each warp to hold the warp-wide sum of batch b: {496, 528, 560} for the first warp and {1520, 1552, 1584} for the second. */ +C2H_TEST("warp reduce batched owner-lane layout returns one batch per lane", "[coop][warp][reduce]") +{ + c2h::device_vector d_out(6); + + WarpReduceBatchedOwnerKernel<<<1, 64>>>(thrust::raw_pointer_cast(d_out.data())); + REQUIRE(cudaSuccess == cudaPeekAtLastError()); + REQUIRE(cudaSuccess == cudaDeviceSynchronize()); + + c2h::host_vector expected{496, 528, 560, 1520, 1552,
1584}; + REQUIRE(expected == d_out); +} +/* Each thread supplies five inputs batch * 10 + logical_lane (logical_lane = threadIdx.x % 4) and copies every element of the returned outputs to out[threadIdx.x * batches + batch]. */ +__global__ void WarpReduceBatchedBroadcastKernel(int* out) +{ + // example-begin warp-reduce-batched-broadcast + constexpr int batches = 5; + constexpr int logical_warp_threads = 4; + using warp_reduce_t = cub::WarpReduceBatchedBroadcast; + __shared__ typename warp_reduce_t::TempStorage temp_storage[2]; + + const int physical_warp_id = static_cast(threadIdx.x) / 32; + const int logical_lane = static_cast(threadIdx.x) % logical_warp_threads; + + ::cuda::std::array inputs{}; + for (int batch = 0; batch < batches; ++batch) + { + inputs[batch] = batch * 10 + logical_lane; + } + + const auto outputs = warp_reduce_t(temp_storage[physical_warp_id]).Sum(inputs); + // example-end warp-reduce-batched-broadcast + + for (int batch = 0; batch < batches; ++batch) + { + out[threadIdx.x * batches + batch] = outputs[batch]; + } +} +/* Expects every thread to see batch * 40 + 6 for each batch: four lanes contribute batch * 10 + {0, 1, 2, 3}. */ +C2H_TEST("warp reduce batched broadcast returns every batch to every lane", "[coop][warp][reduce]") +{ + constexpr int threads = 64; + constexpr int batches = 5; + c2h::device_vector d_out(threads * batches); + + WarpReduceBatchedBroadcastKernel<<<1, threads>>>(thrust::raw_pointer_cast(d_out.data())); + REQUIRE(cudaSuccess == cudaPeekAtLastError()); + REQUIRE(cudaSuccess == cudaDeviceSynchronize()); + + c2h::host_vector expected(threads * batches); + for (int i = 0; i < threads; ++i) + { + for (int batch = 0; batch < batches; ++batch) + { + expected[i * batches + batch] = batch * 40 + 6; + } + } + REQUIRE(expected == d_out); +} +/* Every thread passes its threadIdx.x to block_reduce_t::Sum and stores the returned value in out[threadIdx.x]. */ +__global__ void BlockReduceBroadcastKernel(int* out) +{ + // example-begin block-reduce-broadcast + using block_reduce_t = cub::BlockReduceBroadcast; + __shared__ typename block_reduce_t::TempStorage temp_storage; + + const int result = block_reduce_t(temp_storage).Sum(static_cast(threadIdx.x)); + // example-end block-reduce-broadcast + + out[threadIdx.x] = result; +} +/* Launches 128 threads and expects all of them to read 8128 (0 + ... + 127). */ +C2H_TEST("block reduce broadcast returns aggregate to every thread", "[coop][block][reduce]") +{ +
c2h::device_vector d_out(128); + + BlockReduceBroadcastKernel<<<1, 128>>>(thrust::raw_pointer_cast(d_out.data())); + REQUIRE(cudaSuccess == cudaPeekAtLastError()); + REQUIRE(cudaSuccess == cudaDeviceSynchronize()); + + c2h::host_vector expected(128, 8128); + REQUIRE(expected == d_out); +} +/* Every thread passes its threadIdx.x to row_reduce_t::Sum and stores the returned value in out[threadIdx.x]. */ +__global__ void BlockRowReduceKernel(int* out) +{ + // example-begin block-row-reduce + using row_reduce_t = cub::BlockRowReduce; + __shared__ typename row_reduce_t::TempStorage temp_storage; + + const int result = row_reduce_t(temp_storage).Sum(static_cast(threadIdx.x)); + // example-end block-row-reduce + + out[threadIdx.x] = result; +} +/* Expects 2016 (0 + ... + 63) for threads 0..63 and 6112 (64 + ... + 127) for threads 64..127. */ +C2H_TEST("block row reduce returns one aggregate per row", "[coop][block][reduce]") +{ + c2h::device_vector d_out(128); + + BlockRowReduceKernel<<<1, 128>>>(thrust::raw_pointer_cast(d_out.data())); + REQUIRE(cudaSuccess == cudaPeekAtLastError()); + REQUIRE(cudaSuccess == cudaDeviceSynchronize()); + + c2h::host_vector expected(128); + for (int i = 0; i < 128; ++i) + { + expected[i] = i < 64 ? 2016 : 6112; + } + REQUIRE(expected == d_out); +} +/* Same as BlockRowReduceKernel but using the BlockRowReduceWarpBroadcast collective. */ +__global__ void BlockRowReduceWarpBroadcastKernel(int* out) +{ + // example-begin block-row-reduce-warp-broadcast + using row_reduce_t = cub::BlockRowReduceWarpBroadcast; + __shared__ typename row_reduce_t::TempStorage temp_storage; + + const int result = row_reduce_t(temp_storage).Sum(static_cast(threadIdx.x)); + // example-end block-row-reduce-warp-broadcast + + out[threadIdx.x] = result; +} +/* Expects the same per-row sums as the BlockRowReduce test: 2016 for threads 0..63 and 6112 for threads 64..127. */ +C2H_TEST("block row reduce warp broadcast returns one aggregate per row", "[coop][block][reduce]") +{ + c2h::device_vector d_out(128); + + BlockRowReduceWarpBroadcastKernel<<<1, 128>>>(thrust::raw_pointer_cast(d_out.data())); + REQUIRE(cudaSuccess == cudaPeekAtLastError()); + REQUIRE(cudaSuccess == cudaDeviceSynchronize()); + + c2h::host_vector expected(128); + for (int i = 0; i < 128; ++i) + { + expected[i] = i < 64 ?
2016 : 6112; + } + REQUIRE(expected == d_out); +} +/* Reduces threadIdx.x - 200 with max_int_op through row_reduce_t::Reduce; every input is negative for the 128-thread launch used by the test. */ +__global__ void BlockRowReduceMaxKernel(int* out) +{ + using row_reduce_t = cub::BlockRowReduce; + __shared__ typename row_reduce_t::TempStorage temp_storage; + + const int value = static_cast(threadIdx.x) - 200; + const int result = row_reduce_t(temp_storage).Reduce(value, max_int_op{}); + + out[threadIdx.x] = result; +} +/* Same input as BlockRowReduceMaxKernel, but calls CommutativeReduce with -10000 as third argument. NOTE(review): presumably an identity value lower than every input - confirm against the API. */ +__global__ void BlockRowReduceWarpBroadcastMaxKernel(int* out) +{ + using row_reduce_t = cub::BlockRowReduceWarpBroadcast; + __shared__ typename row_reduce_t::TempStorage temp_storage; + + const int value = static_cast(threadIdx.x) - 200; + const int result = row_reduce_t(temp_storage).CommutativeReduce(value, max_int_op{}, -10000); + + out[threadIdx.x] = result; +} +/* Runs both max kernels and expects each to yield -137 (63 - 200) for threads 0..63 and -73 (127 - 200) for threads 64..127. */ +C2H_TEST("block row reduce supports custom row reductions", "[coop][block][reduce]") +{ + c2h::device_vector d_row_out(128); + c2h::device_vector d_warp_broadcast_out(128); + + BlockRowReduceMaxKernel<<<1, 128>>>(thrust::raw_pointer_cast(d_row_out.data())); + REQUIRE(cudaSuccess == cudaPeekAtLastError()); + BlockRowReduceWarpBroadcastMaxKernel<<<1, 128>>>(thrust::raw_pointer_cast(d_warp_broadcast_out.data())); + REQUIRE(cudaSuccess == cudaPeekAtLastError()); + REQUIRE(cudaSuccess == cudaDeviceSynchronize()); + + c2h::host_vector expected(128); + for (int i = 0; i < 128; ++i) + { + expected[i] = i < 64 ?
-137 : -73; + } + REQUIRE(expected == d_row_out); + REQUIRE(expected == d_warp_broadcast_out); +} +/* Threads whose id is a multiple of 8 are flagged as segment heads; each head writes its HeadSegmentedSum result to out[lane_id / 8]. A single TempStorage is declared, and the test launches one 32-thread block. */ +__global__ void WarpSegmentedReduceKernel(int* out) +{ + // example-begin warp-segmented-row-reduce + using warp_reduce_t = cub::WarpReduce; + __shared__ typename warp_reduce_t::TempStorage temp_storage; + + const int lane_id = static_cast(threadIdx.x); + const int head_flag = (lane_id % 8) == 0; + const int result = warp_reduce_t(temp_storage).HeadSegmentedSum(lane_id, head_flag); + // example-end warp-segmented-row-reduce + + if (head_flag) + { + out[lane_id / 8] = result; + } +} +/* Launches 32 threads and expects the four 8-lane sums {28, 92, 156, 220}. */ +C2H_TEST("warp segmented reduce maps fixed row segments to segment heads", "[coop][warp][reduce]") +{ + c2h::device_vector d_out(4); + + WarpSegmentedReduceKernel<<<1, 32>>>(thrust::raw_pointer_cast(d_out.data())); + REQUIRE(cudaSuccess == cudaPeekAtLastError()); + REQUIRE(cudaSuccess == cudaDeviceSynchronize()); + + c2h::host_vector expected{28, 92, 156, 220}; + REQUIRE(expected == d_out); +} +/* Stores warp_scan_t::Broadcast(tid, 0) in warp_out[tid] and the ExclusiveSum of the constant 1 in block_out[tid]. warp_storage has 2 slots indexed by tid / 32, so a launch may use at most 64 threads. */ +__global__ void WarpAndBlockScanKernel(int* warp_out, int* block_out) +{ + // example-begin warp-block-scan + using warp_scan_t = cub::WarpScan; + using block_scan_t = cub::BlockScan; + + __shared__ typename warp_scan_t::TempStorage warp_storage[2]; + __shared__ typename block_scan_t::TempStorage block_storage; + + const int tid = static_cast(threadIdx.x); + const int warp_id = tid / 32; + + warp_out[tid] = warp_scan_t(warp_storage[warp_id]).Broadcast(tid, 0); + + int prefix{}; + block_scan_t(block_storage).ExclusiveSum(1, prefix); + block_out[tid] = prefix; + // example-end warp-block-scan +} +/* Expects warp_out to hold 0 for threads 0..31 and 32 for threads 32..63 (the tid of each warp's lane 0), and block_out[i] == i. */ +C2H_TEST("warp broadcast and block scan cover scan-style cooperative primitives", "[coop][scan]") +{ + c2h::device_vector d_warp_out(64); + c2h::device_vector d_block_out(64); + + WarpAndBlockScanKernel<<<1, 64>>>( + thrust::raw_pointer_cast(d_warp_out.data()), thrust::raw_pointer_cast(d_block_out.data())); + REQUIRE(cudaSuccess == cudaPeekAtLastError()); + REQUIRE(cudaSuccess ==
cudaDeviceSynchronize()); + + c2h::host_vector expected_warp(64); + c2h::host_vector expected_block(64); + for (int i = 0; i < 64; ++i) + { + expected_warp[i] = i < 32 ? 0 : 32; + expected_block[i] = i; + } + + REQUIRE(expected_warp == d_warp_out); + REQUIRE(expected_block == d_block_out); +}