Skip to content

Commit d46ac16

Browse files
committed
feat(cuda): add high-performance warp and block reduction primitives
- Implement warp-level parallel reductions using `__shfl_xor_sync` (Sum, Max, Min). - Implement a `warpBroadcast` helper utility targeting lane 0. - Introduce skeleton for a shared-memory backed `blockReduceSum`. - Guard all device-specific primitives behind `__CUDACC__` flags.
1 parent 3d6bc62 commit d46ac16

1 file changed

Lines changed: 66 additions & 0 deletions

File tree

cuda/includes/reduce.cuh

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#pragma once
2+
#include "common.h"
3+
4+
#ifdef __CUDACC__
5+
6+
static constexpr unsigned FULL_MASK = 0xffffffff;
7+
// Warp reductions
8+
__device__ QX_INLINE float warpReduceSum(float v)
9+
{
10+
v += __shfl_xor_sync(FULL_MASK, v, 16);
11+
v += __shfl_xor_sync(FULL_MASK, v, 8);
12+
v += __shfl_xor_sync(FULL_MASK, v, 4);
13+
v += __shfl_xor_sync(FULL_MASK, v, 2);
14+
v += __shfl_xor_sync(FULL_MASK, v, 1);
15+
return v;
16+
}
17+
__device__ QX_INLINE float warpReduceMax(float v)
18+
{
19+
v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, 16));
20+
v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, 8));
21+
v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, 4));
22+
v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, 2));
23+
v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, 1));
24+
return v;
25+
}
26+
__device__ QX_INLINE float warpReduceMin(float v)
27+
{
28+
v = fminf(v, __shfl_xor_sync(FULL_MASK, v, 16));
29+
v = fminf(v, __shfl_xor_sync(FULL_MASK, v, 8));
30+
v = fminf(v, __shfl_xor_sync(FULL_MASK, v, 4));
31+
v = fminf(v, __shfl_xor_sync(FULL_MASK, v, 2));
32+
v = fminf(v, __shfl_xor_sync(FULL_MASK, v, 1));
33+
return v;
34+
}
35+
__device__ QX_INLINE float warpBroadcast(float v)
36+
{
37+
return __shfl_sync(FULL_MASK, v, 0);
38+
}
39+
__device__ QX_INLINE float blockReduceSum(float v, float *smem)
40+
{
41+
int lane = threadIdx.x % QX_WARP_SIZE;
42+
int wid = threadIdx.x / QX_WARP_SIZE;
43+
v = warpReduceSum(v);
44+
if (lane == 0)
45+
smem[wid] = v;
46+
__syncthreads();
47+
v = (threadIdx.x < blockDim.x / QX_WARP_SIZE) ? smem[lane] : 0.f;
48+
if (wid == 0)
49+
v = warpReduceSum(v);
50+
return v;
51+
}
52+
__device__ QX_INLINE float blockReduceMax(float v, float *smem)
53+
{
54+
int lane = threadIdx.x % QX_WARP_SIZE;
55+
int wid = threadIdx.x / QX_WARP_SIZE;
56+
v = warpReduceMax(v);
57+
if (lane == 0)
58+
smem[wid] = v;
59+
__syncthreads();
60+
v = (threadIdx.x < blockDim.x / QX_WARP_SIZE) ? smem[lane] : QX_NEG_INF_F32;
61+
if (wid == 0)
62+
v = warpReduceMax(v);
63+
return v;
64+
}
65+
66+
#endif // __CUDACC__

0 commit comments

Comments
 (0)