Skip to content

Commit 6211f97

Browse files
committed
Fix CINN compilation errors and incorrect reduction results on MetaX backend. Running test_elementwise_pow_op_metax.py now succeeds.
1 parent 76491c8 commit 6211f97

2 files changed

Lines changed: 65 additions & 100 deletions

File tree

backends/metax_gpu/cinn/compiler/compiler.cc

Lines changed: 65 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -168,90 +168,62 @@ __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right)
168168
169169
170170
// ===============================================================
171-
// 4. Warp Shuffle Wrappers
171+
// 4. Warp Shuffle Wrappers (Using Legacy API & Full Down Strategy)
172172
// ===============================================================
173173
174-
#define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \
174+
// 【核心修复】Warp Reduce 逻辑重写
175+
// 1. 弃用 XOR 模式:因为在 64-thread warp 下,跨 32 边界的 XOR 可能存在未定义行为或硬件 bug。
176+
// 2. 统一使用 DOWN 模式:__shfl_down 是单向规约,Lane 0 总是能收集到数据的,更加稳健。
177+
// 3. 严格的边界检查:确保 fetch 的来源线程在 Block 范围内,否则使用 INIT_VAL 填充。
178+
179+
#define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INIT_VAL, DTYPE) \
175180
__device__ inline DTYPE cinn_warp_shuffle_##REDUCE_TYPE##_internal( \
176181
const DTYPE value) { \
177-
DTYPE tmp_val = value, shfl_res; \
182+
DTYPE tmp_val = value; \
178183
unsigned int thread_id = threadIdx.x; \
179184
unsigned int block_dim = blockDim.x; \
180-
unsigned int last_warp_size = block_dim - (thread_id - (threadIdx.x % WARP_SIZE)); \
181-
if (last_warp_size < WARP_SIZE) { \
182-
for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \
183-
/* 使用通用的 shuffle down 实现 */ \
184-
shfl_res = cinn_warp_shuffle_down_##DTYPE##_wrapper(tmp_val, offset); \
185-
tmp_val = cinn_##REDUCE_TYPE(thread_id + offset < block_dim \
186-
? shfl_res \
187-
: (DTYPE)(INITIAL_VALUE), \
188-
tmp_val); \
189-
} \
190-
/* 这里的 __shfl 广播可以用 shfl_sync(0) 替代 */ \
191-
tmp_val = __shfl_sync(0xffffffff, tmp_val, 0); \
192-
} else { \
193-
for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \
194-
tmp_val = cinn_##REDUCE_TYPE(tmp_val, \
195-
cinn_warp_shuffle_xor_##DTYPE##_wrapper(tmp_val, offset)); \
196-
} \
185+
/* 始终使用 Down Shuffle 进行规约 (Log2 复杂度) */ \
186+
for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \
187+
DTYPE shfl_res = cinn_warp_shuffle_down_##DTYPE##_wrapper(tmp_val, offset); \
188+
/* 检查数据来源是否有效:当前线程+offset 必须还在 Block 范围内 */ \
189+
/* 如果 Block 大小不是 WARP_SIZE 的倍数,这一步至关重要 */ \
190+
DTYPE neighbor = (thread_id + offset < block_dim) ? shfl_res : (DTYPE)(INIT_VAL); \
191+
tmp_val = cinn_##REDUCE_TYPE(tmp_val, neighbor); \
197192
} \
198-
return tmp_val; \
193+
/* 广播:虽然 Down Shuffle 只有 Lane 0 结果正确,但这里为了兼容 XOR 语义 */ \
194+
/* 我们用 shfl 0 把 Lane 0 的结果广播给所有人 (CINN Block Reduce 需要) */ \
195+
return __shfl(tmp_val, 0); \
199196
}
200197
201-
// --- Warp Shuffle Primitives (Internal Helpers) ---
202-
// 为了适配宏展开,这里定义带后缀的 wrapper,统一 float16/double 处理
203-
204-
__device__ inline float cinn_warp_shuffle_down_float_wrapper(float v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); }
205-
__device__ inline float cinn_warp_shuffle_xor_float_wrapper(float v, int factor) { return __shfl_xor_sync(0xffffffff, v, factor); }
206-
207-
__device__ inline int cinn_warp_shuffle_down_int_wrapper(int v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); }
208-
__device__ inline int cinn_warp_shuffle_xor_int_wrapper(int v, int factor) { return __shfl_xor_sync(0xffffffff, v, factor); }
198+
// --- Warp Shuffle Primitives (Legacy API without mask) ---
209199
210-
__device__ inline bool cinn_warp_shuffle_down_bool_wrapper(bool v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); }
211-
__device__ inline bool cinn_warp_shuffle_xor_bool_wrapper(bool v, int factor) { return __shfl_xor_sync(0xffffffff, v, factor); }
200+
__device__ inline float cinn_warp_shuffle_down_float_wrapper(float v, int factor) { return __shfl_down(v, factor); }
201+
__device__ inline int cinn_warp_shuffle_down_int_wrapper(int v, int factor) { return __shfl_down(v, factor); }
202+
__device__ inline bool cinn_warp_shuffle_down_bool_wrapper(bool v, int factor) { return __shfl_down(v, factor); }
212203
213204
__device__ inline double cinn_warp_shuffle_down_double_wrapper(double v, int factor) {
214205
unsigned long long int val_u64 = *(unsigned long long int*)&v;
215206
int lo = (int)val_u64; int hi = (int)(val_u64 >> 32);
216-
lo = __shfl_down_sync(0xffffffff, lo, factor);
217-
hi = __shfl_down_sync(0xffffffff, hi, factor);
218-
unsigned long long int res_u64 = ((unsigned long long int)hi << 32) | (unsigned int)lo;
219-
return *(double*)&res_u64;
220-
}
221-
__device__ inline double cinn_warp_shuffle_xor_double_wrapper(double v, int factor) {
222-
unsigned long long int val_u64 = *(unsigned long long int*)&v;
223-
int lo = (int)val_u64; int hi = (int)(val_u64 >> 32);
224-
lo = __shfl_xor_sync(0xffffffff, lo, factor);
225-
hi = __shfl_xor_sync(0xffffffff, hi, factor);
207+
lo = __shfl_down(lo, factor);
208+
hi = __shfl_down(hi, factor);
226209
unsigned long long int res_u64 = ((unsigned long long int)hi << 32) | (unsigned int)lo;
227210
return *(double*)&res_u64;
228211
}
229212
230213
__device__ inline int64_t cinn_warp_shuffle_down_int64_t_wrapper(int64_t v, int factor) {
231214
int lo = (int)v; int hi = (int)(v >> 32);
232-
lo = __shfl_down_sync(0xffffffff, lo, factor);
233-
hi = __shfl_down_sync(0xffffffff, hi, factor);
234-
return ((int64_t)hi << 32) | (unsigned int)lo;
235-
}
236-
__device__ inline int64_t cinn_warp_shuffle_xor_int64_t_wrapper(int64_t v, int factor) {
237-
int lo = (int)v; int hi = (int)(v >> 32);
238-
lo = __shfl_xor_sync(0xffffffff, lo, factor);
239-
hi = __shfl_xor_sync(0xffffffff, hi, factor);
215+
lo = __shfl_down(lo, factor);
216+
hi = __shfl_down(hi, factor);
240217
return ((int64_t)hi << 32) | (unsigned int)lo;
241218
}
242219
243220
__device__ inline float16 cinn_warp_shuffle_down_float16_wrapper(float16 v, int factor) {
244221
unsigned short val = __half_as_ushort(v);
245-
unsigned short res = (unsigned short)__shfl_down_sync(0xffffffff, (int)val, factor);
246-
return __ushort_as_half(res);
247-
}
248-
__device__ inline float16 cinn_warp_shuffle_xor_float16_wrapper(float16 v, int factor) {
249-
unsigned short val = __half_as_ushort(v);
250-
unsigned short res = (unsigned short)__shfl_xor_sync(0xffffffff, (int)val, factor);
222+
unsigned short res = (unsigned short)__shfl_down((int)val, factor);
251223
return __ushort_as_half(res);
252224
}
253225
254-
// 展开 Internal Implementations
226+
// Expand Warp Shuffle
255227
EXPAND_REDUCE_INT32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
256228
EXPAND_REDUCE_INT64_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
257229
EXPAND_REDUCE_FP32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
@@ -263,48 +235,44 @@ EXPAND_REDUCE_FP16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
263235
// 5. Block Reduce & Discrete Reduce & Grid Reduce
264236
// ===============================================================
265237
266-
#define CINN_BLOCK_REDUCE_IMPL(DTYPE, INITIAL_VALUE, cinn_warp_shuffle_internal) \
267-
/* 1. Warp内规约 */ \
268-
DTYPE tmp_val = cinn_warp_shuffle_internal(value); \
269-
\
270-
/* 如果只有一个 warp,直接返回 */ \
271-
if (return_warp || blockDim.x <= WARP_SIZE) { \
272-
return tmp_val; \
273-
} \
274-
__syncthreads(); \
275-
\
276-
/* 2. 每个 Warp 的结果写入共享内存 (仅 Lane 0 写入) */ \
277-
if (threadIdx.x % WARP_SIZE == 0) { \
278-
shm[threadIdx.x / WARP_SIZE] = tmp_val; \
279-
} \
280-
__syncthreads(); \
281-
\
282-
/* 3. Warp 0 负责汇总 */ \
283-
if (threadIdx.x < WARP_SIZE) { \
284-
/* 计算有多少个 Warp */ \
285-
int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; \
286-
\
287-
/* 【核心修复】Lane >= num_warps 的线程必须加载 IDENTITY,否则后面 shuffle 会引入脏数据 */ \
288-
DTYPE reduce_val = (DTYPE)(INITIAL_VALUE); \
289-
if (threadIdx.x < num_warps) { \
290-
reduce_val = shm[threadIdx.x]; \
291-
} \
292-
\
293-
/* Warp 0 再次进行规约 (所有 64 个线程都参与) */ \
294-
reduce_val = cinn_warp_shuffle_internal(reduce_val); \
295-
\
296-
/* 结果写入 shm[0] */ \
297-
if (threadIdx.x == 0) { \
298-
shm[0] = reduce_val; \
299-
} \
300-
} \
301-
__syncthreads(); \
238+
// Block Reduce Implementation
239+
// 1. Warp Reduce -> SHM
240+
// 2. Warp 0 reads SHM and Pads with Identity
241+
// 3. Warp 0 Reduce
242+
// 4. Broadcast
243+
#define CINN_BLOCK_REDUCE_IMPL(DTYPE, INIT_VAL, cinn_warp_shuffle_internal) \
244+
/* 1. Warp Reduce */ \
245+
DTYPE tmp_val = cinn_warp_shuffle_internal(value); \
246+
if (return_warp || blockDim.x <= WARP_SIZE) { \
247+
return tmp_val; \
248+
} \
249+
__syncthreads(); \
250+
/* 2. Write Warp results to SHM (Lane 0 only) */ \
251+
if (threadIdx.x % WARP_SIZE == 0) { \
252+
shm[threadIdx.x / WARP_SIZE] = tmp_val; \
253+
} \
254+
__syncthreads(); \
255+
/* 3. Inter-Warp Reduce (Warp 0 only) */ \
256+
if (threadIdx.x < WARP_SIZE) { \
257+
int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; \
258+
/* Pad with Identity value for idle threads in Warp 0 */ \
259+
DTYPE reduce_val = (DTYPE)(INIT_VAL); \
260+
if (threadIdx.x < num_warps) { \
261+
reduce_val = shm[threadIdx.x]; \
262+
} \
263+
/* Reduce across all threads in Warp 0 */ \
264+
reduce_val = cinn_warp_shuffle_internal(reduce_val); \
265+
if (threadIdx.x == 0) { \
266+
shm[0] = reduce_val; \
267+
} \
268+
} \
269+
__syncthreads(); \
302270
return shm[0];
303271
304-
#define CINN_BLOCK_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \
272+
#define CINN_BLOCK_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \
305273
__device__ inline DTYPE cinn_block_reduce_##REDUCE_TYPE( \
306274
const DTYPE value, DTYPE *shm, bool return_warp = false) { \
307-
CINN_BLOCK_REDUCE_IMPL(DTYPE, INITIAL_VALUE, cinn_warp_shuffle_##REDUCE_TYPE##_internal); \
275+
CINN_BLOCK_REDUCE_IMPL(DTYPE, INIT_VAL, cinn_warp_shuffle_##REDUCE_TYPE##_internal); \
308276
}
309277
310278
EXPAND_REDUCE_INT32_MACRO(CINN_BLOCK_REDUCE_MACRO)
@@ -327,7 +295,7 @@ EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_MACRO)
327295
} \
328296
return shm[threadIdx.x];
329297
330-
#define CINN_DISCRETE_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \
298+
#define CINN_DISCRETE_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \
331299
__device__ inline DTYPE cinn_discrete_reduce_##REDUCE_TYPE( \
332300
const DTYPE value, DTYPE *shm) { \
333301
CINN_DISCRETE_REDUCE_IMPL(REDUCE_TYPE, value); \
@@ -348,10 +316,10 @@ EXPAND_REDUCE_FP16_MACRO(CINN_DISCRETE_REDUCE_MACRO)
348316
} \
349317
return tmp_val;
350318
351-
#define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \
319+
#define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \
352320
__device__ inline DTYPE cinn_grid_reduce_##REDUCE_TYPE( \
353321
const DTYPE *mem, int spatial_size, int spatial_index) { \
354-
CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, (DTYPE)(INITIAL_VALUE), DTYPE); \
322+
CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, (DTYPE)(INIT_VAL), DTYPE); \
355323
}
356324
357325
EXPAND_REDUCE_INT32_MACRO(CINN_GRID_REDUCE_MACRO)
@@ -372,7 +340,6 @@ __device__ inline bool cinn_grid_reduce_update_semaphore(int *semaphores) {
372340
__syncthreads();
373341
return done;
374342
}
375-
376343
// ===============================================================
377344
// 6. Standard Math Functions
378345
// ===============================================================

backends/metax_gpu/tests/unittest/test_elementwise_pow_op_metax.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def setUp(self):
8888
}
8989
self.outputs = {"Out": np.power(self.inputs["X"], self.inputs["Y"])}
9090

91-
'''
9291
class TestElementwisePowOp_ZeroDim3(TestElementwisePowOp):
9392
def setUp(self):
9493
self.op_type = "elementwise_pow"
@@ -455,7 +454,6 @@ def test_check_grad(self):
455454
only_check_prim=True,
456455
check_prim_pir=True,
457456
)
458-
'''
459457

460458
if __name__ == "__main__":
461459
unittest.main()

0 commit comments

Comments (0)