Skip to content

Commit ddecc7e

Browse files
committed
Use cooperative_groups::this_grid().sync() in CINN_GRID_REDUCE_IMPL. Add
CINN_GRID_REDUCE_FP16_MACRO.
1 parent 8c1abe6 commit ddecc7e

2 files changed

Lines changed: 30 additions & 8 deletions

File tree

backends/metax_gpu/cinn/compiler/compiler.cc

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ namespace metax {
3939
// ============================================================
4040
static const char* kMacaRuntimeSource = R"MACA_SOURCE(
4141
#pragma once
42+
#include <cooperative_groups.h>
4243
#include <cuda_fp16.h>
4344
#include <cuda_runtime.h>
4445
@@ -812,7 +813,7 @@ __device__ inline void __cinn_grid_sync() {
812813
}
813814
814815
#define CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, init_value, DTYPE) \
815-
__cinn_grid_sync(); \
816+
cooperative_groups::this_grid().sync(); \
816817
DTYPE tmp_val = init_value; \
817818
for (int y = 0; y < gridDim.y; y++) { \
818819
tmp_val = cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \
@@ -830,7 +831,28 @@ EXPAND_REDUCE_INT64_MACRO(CINN_GRID_REDUCE_MACRO)
830831
EXPAND_REDUCE_FP32_MACRO(CINN_GRID_REDUCE_MACRO)
831832
EXPAND_REDUCE_FP64_MACRO(CINN_GRID_REDUCE_MACRO)
832833
EXPAND_REDUCE_BOOL_MACRO(CINN_GRID_REDUCE_MACRO)
833-
EXPAND_REDUCE_FP16_MACRO(CINN_GRID_REDUCE_MACRO)
834+
835+
// FP16 grid reduce: accumulate in FP32 to avoid precision loss when summing
836+
// multiple FP16 block-level partial sums. Each partial sum can have magnitude
837+
// O(block_size * input_scale), and accumulating N such values in FP16 incurs
838+
// error proportional to N * magnitude * eps_fp16. Using FP32 for the inter-
839+
// block accumulation step keeps the error at FP16 quantization level only.
840+
#define CINN_GRID_REDUCE_FP16_MACRO(FP16_TYPE, FP32_FUNC, INIT_VAL) \
841+
__device__ inline float16 cinn_grid_reduce_##FP16_TYPE( \
842+
const float16 *mem, int spatial_size, int spatial_index) { \
843+
cooperative_groups::this_grid().sync(); \
844+
float tmp_val = (float)(INIT_VAL); \
845+
for (int y = 0; y < gridDim.y; y++) { \
846+
tmp_val = FP32_FUNC( \
847+
tmp_val, __half2float(mem[y * spatial_size + spatial_index])); \
848+
} \
849+
return __float2half(tmp_val); \
850+
}
851+
852+
CINN_GRID_REDUCE_FP16_MACRO(sum_fp16, cinn_sum_fp32, 0.0f)
853+
CINN_GRID_REDUCE_FP16_MACRO(prod_fp16, cinn_prod_fp32, 1.0f)
854+
CINN_GRID_REDUCE_FP16_MACRO(max_fp16, cinn_max_fp32, -65504.0f)
855+
CINN_GRID_REDUCE_FP16_MACRO(min_fp16, cinn_min_fp32, 65504.0f)
834856
835857
__device__ inline bool cinn_grid_reduce_update_semaphore(int *semaphores) {
836858
__shared__ bool done;

backends/metax_gpu/cinn/runtime/cinn_runtime.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) {
3434
return C_Status::C_FAILED;
3535
}
3636
*mod_out = reinterpret_cast<void*>(module);
37-
std::cerr << "[MetaxModuleLoad] OK path=" << path << " module=" << module
38-
<< std::endl;
37+
// std::cerr << "[MetaxModuleLoad] OK path=" << path << " module=" << module
38+
// << std::endl;
3939
return C_Status::C_SUCCESS;
4040
}
4141

4242
// Unload module
4343
C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle) {
4444
cuModuleUnload((CUmodule)module_handle);
45-
std::cout << "YUHAN!!! [MetaxModuleUnload] module_handle=" << module_handle << std::endl;
45+
// std::cout << "YUHAN!!! [MetaxModuleUnload] module_handle=" << module_handle << std::endl;
4646
return C_Status::C_SUCCESS;
4747
}
4848

@@ -59,8 +59,8 @@ C_Status MetaxGetKernelAddress(void* dev_ptr,
5959
return C_Status::C_FAILED;
6060
}
6161
*func_out = reinterpret_cast<void*>(func);
62-
std::cout << "YUHAN!!! [MetaxGetKernelAddress] OK func_name=" << func_name
63-
<< " func_ptr=" << func << " module_handle=" << module_handle << std::endl;
62+
// std::cout << "YUHAN!!! [MetaxGetKernelAddress] OK func_name=" << func_name
63+
// << " func_ptr=" << func << " module_handle=" << module_handle << std::endl;
6464
return C_Status::C_SUCCESS;
6565
}
6666

@@ -109,7 +109,7 @@ C_Status MetaxLaunchCooperativeKernel(void* dev_ptr,
109109
int bz,
110110
int shm,
111111
void* stream) {
112-
std::cout << "YUHAN!!! [MetaxLaunchCooperativeKernel] func_ptr=" << func_ptr;
112+
// std::cout << "YUHAN!!! [MetaxLaunchCooperativeKernel] func_ptr=" << func_ptr;
113113
CUmodule module;
114114
CUresult errModule = cuFuncGetModule(&module ,static_cast<CUfunction>(func_ptr));
115115
if (errModule != CUDA_SUCCESS) {

0 commit comments

Comments
 (0)