Skip to content

Commit 09c436c

Browse files
committed
Metax: Add __cinn_grid_sync() into CINN_GRID_REDUCE_IMPL in compiler.cc
Add stdcout in cinn_runtime.cc.
1 parent 3965483 commit 09c436c

2 files changed

Lines changed: 65 additions & 13 deletions

File tree

backends/metax_gpu/cinn/compiler/compiler.cc

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -780,12 +780,43 @@ __device__ inline argidx_fp32_i64 cinn_discrete_reduce_min_argidx_fp32_i64(
780780
CINN_DISCRETE_REDUCE_IMPL(min_argidx_fp32_i64, value);
781781
}
782782
783+
// ===============================================================
784+
// Grid-wide Barrier (emulates cooperative_groups::this_grid().sync())
785+
// Uses a sense-reversing barrier so it works correctly when called
786+
// multiple times within the same kernel.
787+
// REQUIREMENT: all thread blocks must be co-resident on the GPU.
788+
// ===============================================================
789+
__device__ unsigned int __cinn_grid_barrier_count[8192];
790+
__device__ unsigned int __cinn_grid_barrier_flag[8192];
791+
792+
__device__ inline void __cinn_grid_sync() {
793+
__threadfence();
794+
__syncthreads();
795+
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
796+
unsigned int expected =
797+
atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u);
798+
unsigned int arrived =
799+
atomicAdd(&__cinn_grid_barrier_count[blockIdx.x], 1u) + 1u;
800+
if (arrived == (unsigned int)gridDim.y) {
801+
atomicExch(&__cinn_grid_barrier_count[blockIdx.x], 0u);
802+
__threadfence();
803+
atomicExch(&__cinn_grid_barrier_flag[blockIdx.x], 1u - expected);
804+
__threadfence();
805+
} else {
806+
while (atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u) ==
807+
expected) {
808+
}
809+
}
810+
}
811+
__syncthreads();
812+
}
813+
783814
#define CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, init_value, DTYPE) \
784-
DTYPE tmp_val = init_value; \
785-
for (int y = 0; y < gridDim.y; y++) { \
786-
tmp_val = \
787-
cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \
788-
} \
815+
__cinn_grid_sync(); \
816+
DTYPE tmp_val = init_value; \
817+
for (int y = 0; y < gridDim.y; y++) { \
818+
tmp_val = cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \
819+
} \
789820
return tmp_val;
790821
791822
#define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \

backends/metax_gpu/cinn/runtime/cinn_runtime.cc

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,14 @@ namespace metax {
2828
C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) {
2929
CUmodule module;
3030
CUresult err = cuModuleLoad(&module, path);
31-
if (err != CUDA_SUCCESS) return C_Status::C_FAILED;
32-
31+
if (err != CUDA_SUCCESS) {
32+
std::cerr << "[MetaxModuleLoad] FAILED to load module from: " << path
33+
<< ", error=" << err << std::endl;
34+
return C_Status::C_FAILED;
35+
}
3336
*mod_out = reinterpret_cast<void*>(module);
37+
std::cerr << "[MetaxModuleLoad] OK path=" << path << " module=" << module
38+
<< std::endl;
3439
return C_Status::C_SUCCESS;
3540
}
3641

@@ -47,9 +52,14 @@ C_Status MetaxGetKernelAddress(void* dev_ptr,
4752
void** func_out) {
4853
CUfunction func;
4954
CUresult err = cuModuleGetFunction(&func, (CUmodule)module_handle, func_name);
50-
if (err != CUDA_SUCCESS) return C_Status::C_FAILED;
51-
55+
if (err != CUDA_SUCCESS) {
56+
std::cerr << "[MetaxGetKernelAddress] FAILED func_name=" << func_name
57+
<< " module=" << module_handle << " error=" << err << std::endl;
58+
return C_Status::C_FAILED;
59+
}
5260
*func_out = reinterpret_cast<void*>(func);
61+
std::cerr << "[MetaxGetKernelAddress] OK func_name=" << func_name
62+
<< " func_ptr=" << func << std::endl;
5363
return C_Status::C_SUCCESS;
5464
}
5565

@@ -82,7 +92,10 @@ C_Status MetaxLaunchKernel(void* dev_ptr,
8292
return C_Status::C_SUCCESS;
8393
}
8494

85-
// Launch cooperative kernel: equivalent to cuLaunchCooperativeKernel
95+
// Launch cooperative kernel: uses cuLaunchCooperativeKernel (mapped to
96+
// wcudaLaunchCooperativeKernel -> mcLaunchCooperativeKernel via cu-bridge)
97+
// to guarantee all thread blocks are co-resident on the GPU, which is
98+
// required by cross-block grid_reduce barriers (__cinn_grid_sync).
8699
C_Status MetaxLaunchCooperativeKernel(void* dev_ptr,
87100
void* func_ptr,
88101
void** args,
@@ -95,17 +108,25 @@ C_Status MetaxLaunchCooperativeKernel(void* dev_ptr,
95108
int bz,
96109
int shm,
97110
void* stream) {
98-
CUresult err = cuLaunchCooperativeKernel((CUfunction)func_ptr,
111+
std::cout << "YUHAN!!! [MetaxLaunchCooperativeKernel] func_ptr=" << func_ptr
112+
<< " grid=(" << gx << "," << gy << "," << gz << ")"
113+
<< " block=(" << bx << "," << by << "," << bz << ")"
114+
<< " shm=" << shm << std::endl;
115+
CUresult err = cuLaunchCooperativeKernel(static_cast<CUfunction>(func_ptr),
99116
gx,
100117
gy,
101118
gz,
102119
bx,
103120
by,
104121
bz,
105122
shm,
106-
(CUstream)stream,
123+
static_cast<CUstream>(stream),
107124
args);
108-
if (err != CUDA_SUCCESS) return C_Status::C_FAILED;
125+
if (err != CUDA_SUCCESS) {
126+
std::cerr << "[MetaxLaunchCooperativeKernel] FAILED error=" << err
127+
<< std::endl;
128+
return C_Status::C_FAILED;
129+
}
109130
return C_Status::C_SUCCESS;
110131
}
111132

0 commit comments

Comments
 (0)