@@ -168,90 +168,62 @@ __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right)
168168
169169
170170// ===============================================================
171- // 4. Warp Shuffle Wrappers
171+ // 4. Warp Shuffle Wrappers (Using Legacy API & Full Down Strategy)
172172// ===============================================================
173173
174- #define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \
174+ // 【核心修复】Warp Reduce 逻辑重写
175+ // 1. 弃用 XOR 模式:因为在 64-thread warp 下,跨 32 边界的 XOR 可能存在未定义行为或硬件 bug。
176+ // 2. 统一使用 DOWN 模式:__shfl_down 是单向规约,Lane 0 总是能收集到数据的,更加稳健。
177+ // 3. 严格的边界检查:确保 fetch 的来源线程在 Block 范围内,否则使用 INIT_VAL 填充。
178+
179+ #define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INIT_VAL, DTYPE) \
175180 __device__ inline DTYPE cinn_warp_shuffle_##REDUCE_TYPE##_internal( \
176181 const DTYPE value) { \
177- DTYPE tmp_val = value, shfl_res; \
182+ DTYPE tmp_val = value; \
178183 unsigned int thread_id = threadIdx.x; \
179184 unsigned int block_dim = blockDim.x; \
180- unsigned int last_warp_size = block_dim - (thread_id - (threadIdx.x % WARP_SIZE)); \
181- if (last_warp_size < WARP_SIZE) { \
182- for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \
183- /* 使用通用的 shuffle down 实现 */ \
184- shfl_res = cinn_warp_shuffle_down_##DTYPE##_wrapper(tmp_val, offset); \
185- tmp_val = cinn_##REDUCE_TYPE(thread_id + offset < block_dim \
186- ? shfl_res \
187- : (DTYPE)(INITIAL_VALUE), \
188- tmp_val); \
189- } \
190- /* 这里的 __shfl 广播可以用 shfl_sync(0) 替代 */ \
191- tmp_val = __shfl_sync(0xffffffff, tmp_val, 0); \
192- } else { \
193- for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \
194- tmp_val = cinn_##REDUCE_TYPE(tmp_val, \
195- cinn_warp_shuffle_xor_##DTYPE##_wrapper(tmp_val, offset)); \
196- } \
185+ /* 始终使用 Down Shuffle 进行规约 (Log2 复杂度) */ \
186+ for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \
187+ DTYPE shfl_res = cinn_warp_shuffle_down_##DTYPE##_wrapper(tmp_val, offset); \
188+ /* 检查数据来源是否有效:当前线程+offset 必须还在 Block 范围内 */ \
189+ /* 如果 Block 大小不是 WARP_SIZE 的倍数,这一步至关重要 */ \
190+ DTYPE neighbor = (thread_id + offset < block_dim) ? shfl_res : (DTYPE)(INIT_VAL); \
191+ tmp_val = cinn_##REDUCE_TYPE(tmp_val, neighbor); \
197192 } \
198- return tmp_val; \
193+ /* 广播:虽然 Down Shuffle 只有 Lane 0 结果正确,但这里为了兼容 XOR 语义 */ \
194+ /* 我们用 shfl 0 把 Lane 0 的结果广播给所有人 (CINN Block Reduce 需要) */ \
195+ return __shfl(tmp_val, 0); \
199196 }
200197
201- // --- Warp Shuffle Primitives (Internal Helpers) ---
202- // 为了适配宏展开,这里定义带后缀的 wrapper,统一 float16/double 处理
203-
204- __device__ inline float cinn_warp_shuffle_down_float_wrapper(float v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); }
205- __device__ inline float cinn_warp_shuffle_xor_float_wrapper(float v, int factor) { return __shfl_xor_sync(0xffffffff, v, factor); }
206-
207- __device__ inline int cinn_warp_shuffle_down_int_wrapper(int v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); }
208- __device__ inline int cinn_warp_shuffle_xor_int_wrapper(int v, int factor) { return __shfl_xor_sync(0xffffffff, v, factor); }
198+ // --- Warp Shuffle Primitives (Legacy API without mask) ---
209199
210- __device__ inline bool cinn_warp_shuffle_down_bool_wrapper(bool v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); }
211- __device__ inline bool cinn_warp_shuffle_xor_bool_wrapper(bool v, int factor) { return __shfl_xor_sync(0xffffffff, v, factor); }
200+ __device__ inline float cinn_warp_shuffle_down_float_wrapper(float v, int factor) { return __shfl_down(v, factor); }
201+ __device__ inline int cinn_warp_shuffle_down_int_wrapper(int v, int factor) { return __shfl_down(v, factor); }
202+ __device__ inline bool cinn_warp_shuffle_down_bool_wrapper(bool v, int factor) { return __shfl_down(v, factor); }
212203
213204__device__ inline double cinn_warp_shuffle_down_double_wrapper(double v, int factor) {
214205 unsigned long long int val_u64 = *(unsigned long long int*)&v;
215206 int lo = (int)val_u64; int hi = (int)(val_u64 >> 32);
216- lo = __shfl_down_sync(0xffffffff, lo, factor);
217- hi = __shfl_down_sync(0xffffffff, hi, factor);
218- unsigned long long int res_u64 = ((unsigned long long int)hi << 32) | (unsigned int)lo;
219- return *(double*)&res_u64;
220- }
221- __device__ inline double cinn_warp_shuffle_xor_double_wrapper(double v, int factor) {
222- unsigned long long int val_u64 = *(unsigned long long int*)&v;
223- int lo = (int)val_u64; int hi = (int)(val_u64 >> 32);
224- lo = __shfl_xor_sync(0xffffffff, lo, factor);
225- hi = __shfl_xor_sync(0xffffffff, hi, factor);
207+ lo = __shfl_down(lo, factor);
208+ hi = __shfl_down(hi, factor);
226209 unsigned long long int res_u64 = ((unsigned long long int)hi << 32) | (unsigned int)lo;
227210 return *(double*)&res_u64;
228211}
229212
230213__device__ inline int64_t cinn_warp_shuffle_down_int64_t_wrapper(int64_t v, int factor) {
231214 int lo = (int)v; int hi = (int)(v >> 32);
232- lo = __shfl_down_sync(0xffffffff, lo, factor);
233- hi = __shfl_down_sync(0xffffffff, hi, factor);
234- return ((int64_t)hi << 32) | (unsigned int)lo;
235- }
236- __device__ inline int64_t cinn_warp_shuffle_xor_int64_t_wrapper(int64_t v, int factor) {
237- int lo = (int)v; int hi = (int)(v >> 32);
238- lo = __shfl_xor_sync(0xffffffff, lo, factor);
239- hi = __shfl_xor_sync(0xffffffff, hi, factor);
215+ lo = __shfl_down(lo, factor);
216+ hi = __shfl_down(hi, factor);
240217 return ((int64_t)hi << 32) | (unsigned int)lo;
241218}
242219
243220__device__ inline float16 cinn_warp_shuffle_down_float16_wrapper(float16 v, int factor) {
244221 unsigned short val = __half_as_ushort(v);
245- unsigned short res = (unsigned short)__shfl_down_sync(0xffffffff, (int)val, factor);
246- return __ushort_as_half(res);
247- }
248- __device__ inline float16 cinn_warp_shuffle_xor_float16_wrapper(float16 v, int factor) {
249- unsigned short val = __half_as_ushort(v);
250- unsigned short res = (unsigned short)__shfl_xor_sync(0xffffffff, (int)val, factor);
222+ unsigned short res = (unsigned short)__shfl_down((int)val, factor);
251223 return __ushort_as_half(res);
252224}
253225
254- // 展开 Internal Implementations
226+ // Expand Warp Shuffle
255227EXPAND_REDUCE_INT32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
256228EXPAND_REDUCE_INT64_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
257229EXPAND_REDUCE_FP32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
@@ -263,48 +235,44 @@ EXPAND_REDUCE_FP16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
263235// 5. Block Reduce & Discrete Reduce & Grid Reduce
264236// ===============================================================
265237
266- #define CINN_BLOCK_REDUCE_IMPL(DTYPE, INITIAL_VALUE, cinn_warp_shuffle_internal) \
267- /* 1. Warp内规约 */ \
268- DTYPE tmp_val = cinn_warp_shuffle_internal(value); \
269- \
270- /* 如果只有一个 warp,直接返回 */ \
271- if (return_warp || blockDim.x <= WARP_SIZE) { \
272- return tmp_val; \
273- } \
274- __syncthreads(); \
275- \
276- /* 2. 每个 Warp 的结果写入共享内存 (仅 Lane 0 写入) */ \
277- if (threadIdx.x % WARP_SIZE == 0) { \
278- shm[threadIdx.x / WARP_SIZE] = tmp_val; \
279- } \
280- __syncthreads(); \
281- \
282- /* 3. Warp 0 负责汇总 */ \
283- if (threadIdx.x < WARP_SIZE) { \
284- /* 计算有多少个 Warp */ \
285- int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; \
286- \
287- /* 【核心修复】Lane >= num_warps 的线程必须加载 IDENTITY,否则后面 shuffle 会引入脏数据 */ \
288- DTYPE reduce_val = (DTYPE)(INITIAL_VALUE); \
289- if (threadIdx.x < num_warps) { \
290- reduce_val = shm[threadIdx.x]; \
291- } \
292- \
293- /* Warp 0 再次进行规约 (所有 64 个线程都参与) */ \
294- reduce_val = cinn_warp_shuffle_internal(reduce_val); \
295- \
296- /* 结果写入 shm[0] */ \
297- if (threadIdx.x == 0) { \
298- shm[0] = reduce_val; \
299- } \
300- } \
301- __syncthreads(); \
238+ // Block Reduce Implementation
239+ // 1. Warp Reduce -> SHM
240+ // 2. Warp 0 reads SHM and Pads with Identity
241+ // 3. Warp 0 Reduce
242+ // 4. Broadcast
243+ #define CINN_BLOCK_REDUCE_IMPL(DTYPE, INIT_VAL, cinn_warp_shuffle_internal) \
244+ /* 1. Warp Reduce */ \
245+ DTYPE tmp_val = cinn_warp_shuffle_internal(value); \
246+ if (return_warp || blockDim.x <= WARP_SIZE) { \
247+ return tmp_val; \
248+ } \
249+ __syncthreads(); \
250+ /* 2. Write Warp results to SHM (Lane 0 only) */ \
251+ if (threadIdx.x % WARP_SIZE == 0) { \
252+ shm[threadIdx.x / WARP_SIZE] = tmp_val; \
253+ } \
254+ __syncthreads(); \
255+ /* 3. Inter-Warp Reduce (Warp 0 only) */ \
256+ if (threadIdx.x < WARP_SIZE) { \
257+ int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; \
258+ /* Pad with Identity value for idle threads in Warp 0 */ \
259+ DTYPE reduce_val = (DTYPE)(INIT_VAL); \
260+ if (threadIdx.x < num_warps) { \
261+ reduce_val = shm[threadIdx.x]; \
262+ } \
263+ /* Reduce across all threads in Warp 0 */ \
264+ reduce_val = cinn_warp_shuffle_internal(reduce_val); \
265+ if (threadIdx.x == 0) { \
266+ shm[0] = reduce_val; \
267+ } \
268+ } \
269+ __syncthreads(); \
302270 return shm[0];
303271
304- #define CINN_BLOCK_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE , DTYPE) \
272+ #define CINN_BLOCK_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL , DTYPE) \
305273 __device__ inline DTYPE cinn_block_reduce_##REDUCE_TYPE( \
306274 const DTYPE value, DTYPE *shm, bool return_warp = false) { \
307- CINN_BLOCK_REDUCE_IMPL(DTYPE, INITIAL_VALUE , cinn_warp_shuffle_##REDUCE_TYPE##_internal); \
275+ CINN_BLOCK_REDUCE_IMPL(DTYPE, INIT_VAL , cinn_warp_shuffle_##REDUCE_TYPE##_internal); \
308276 }
309277
310278EXPAND_REDUCE_INT32_MACRO(CINN_BLOCK_REDUCE_MACRO)
@@ -327,7 +295,7 @@ EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_MACRO)
327295 } \
328296 return shm[threadIdx.x];
329297
330- #define CINN_DISCRETE_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE , DTYPE) \
298+ #define CINN_DISCRETE_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL , DTYPE) \
331299 __device__ inline DTYPE cinn_discrete_reduce_##REDUCE_TYPE( \
332300 const DTYPE value, DTYPE *shm) { \
333301 CINN_DISCRETE_REDUCE_IMPL(REDUCE_TYPE, value); \
@@ -348,10 +316,10 @@ EXPAND_REDUCE_FP16_MACRO(CINN_DISCRETE_REDUCE_MACRO)
348316 } \
349317 return tmp_val;
350318
351- #define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE , DTYPE) \
319+ #define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL , DTYPE) \
352320 __device__ inline DTYPE cinn_grid_reduce_##REDUCE_TYPE( \
353321 const DTYPE *mem, int spatial_size, int spatial_index) { \
354- CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, (DTYPE)(INITIAL_VALUE ), DTYPE); \
322+ CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, (DTYPE)(INIT_VAL ), DTYPE); \
355323 }
356324
357325EXPAND_REDUCE_INT32_MACRO(CINN_GRID_REDUCE_MACRO)
@@ -372,7 +340,6 @@ __device__ inline bool cinn_grid_reduce_update_semaphore(int *semaphores) {
372340 __syncthreads();
373341 return done;
374342}
375-
376343// ===============================================================
377344// 6. Standard Math Functions
378345// ===============================================================
0 commit comments