@@ -39,6 +39,7 @@ namespace metax {
3939// ============================================================
4040static const char * kMacaRuntimeSource = R"MACA_SOURCE(
4141#pragma once
42+ #include <cooperative_groups.h>
4243#include <cuda_fp16.h>
4344#include <cuda_runtime.h>
4445
@@ -812,7 +813,7 @@ __device__ inline void __cinn_grid_sync() {
812813}
813814
814815#define CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, init_value, DTYPE) \
815- __cinn_grid_sync(); \
816+ cooperative_groups::this_grid().sync(); \
816817 DTYPE tmp_val = init_value; \
817818 for (int y = 0; y < gridDim.y; y++) { \
818819 tmp_val = cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \
@@ -830,7 +831,28 @@ EXPAND_REDUCE_INT64_MACRO(CINN_GRID_REDUCE_MACRO)
830831EXPAND_REDUCE_FP32_MACRO(CINN_GRID_REDUCE_MACRO)
831832EXPAND_REDUCE_FP64_MACRO(CINN_GRID_REDUCE_MACRO)
832833EXPAND_REDUCE_BOOL_MACRO(CINN_GRID_REDUCE_MACRO)
833- EXPAND_REDUCE_FP16_MACRO(CINN_GRID_REDUCE_MACRO)
834+
// Grid-level reduction for float16 data, accumulated in float.
//
// CINN_GRID_REDUCE_FP16_MACRO(FP16_TYPE, FP32_FUNC, INIT_VAL) defines
//   float16 cinn_grid_reduce_<FP16_TYPE>(mem, spatial_size, spatial_index)
// with
//   FP16_TYPE  suffix pasted onto "cinn_grid_reduce_" (e.g. sum_fp16),
//   FP32_FUNC  binary float combiner called as FP32_FUNC(accumulator, value),
//   INIT_VAL   starting value of the accumulator (converted to float).
//
// The generated function waits on a grid-wide barrier and then folds the
// gridDim.y entries mem[y * spatial_size + spatial_index], y = 0 .. gridDim.y-1.
// Each entry is widened to float before it is combined, and the running value
// is rounded back to float16 only once, on return. Folding directly in float16
// would instead round after every step, so the error would grow with the
// number of entries and with their magnitude.
//
// NOTE(review): on CUDA, this_grid().sync() is only valid for kernels started
// with a cooperative launch -- confirm the launch path used here satisfies it.
#define CINN_GRID_REDUCE_FP16_MACRO(FP16_TYPE, FP32_FUNC, INIT_VAL)        \
  __device__ inline float16 cinn_grid_reduce_##FP16_TYPE(                  \
      const float16 *mem, int spatial_size, int spatial_index) {           \
    /* Grid-wide barrier before any entry of mem is read. */               \
    cooperative_groups::this_grid().sync();                                \
    float acc = static_cast<float>(INIT_VAL);                              \
    for (int row = 0; row < gridDim.y; ++row) {                            \
      const int slot = row * spatial_size + spatial_index;                 \
      const float partial = __half2float(mem[slot]);                       \
      acc = FP32_FUNC(acc, partial);                                       \
    }                                                                      \
    /* Single rounding step back to float16. */                            \
    return __float2half(acc);                                              \
  }

// Instantiations: sum, product, max and min. The max/min start values are
// -65504 / 65504, the largest-magnitude finite float16 values.
CINN_GRID_REDUCE_FP16_MACRO(sum_fp16, cinn_sum_fp32, 0.0f)
CINN_GRID_REDUCE_FP16_MACRO(prod_fp16, cinn_prod_fp32, 1.0f)
CINN_GRID_REDUCE_FP16_MACRO(max_fp16, cinn_max_fp32, -65504.0f)
CINN_GRID_REDUCE_FP16_MACRO(min_fp16, cinn_min_fp32, 65504.0f)
834856
835857__device__ inline bool cinn_grid_reduce_update_semaphore(int *semaphores) {
836858 __shared__ bool done;
0 commit comments