Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions backends/metax_gpu/cinn/cinn_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ extern C_Status MetaxLaunchKernel(void* dev_ptr,
int shm,
void* stream);

// Launches a cooperative kernel (all blocks co-resident, enabling
// grid-level synchronization such as __cinn_grid_sync).
// func_ptr: opaque kernel handle (CUfunction) from the module loader.
// args: array of num_args pointers to kernel argument values.
// gx/gy/gz, bx/by/bz: grid and block dimensions.
// shm: dynamic shared memory bytes per block; stream: launch stream.
// NOTE(review): num_args appears unused by the driver launch path
// (cuLaunchCooperativeKernel takes a NULL-terminated args array) —
// presumably kept for interface symmetry; confirm against the impl.
extern C_Status MetaxLaunchCooperativeKernel(void* dev_ptr,
void* func_ptr,
void** args,
int num_args,
int gx,
int gy,
int gz,
int bx,
int by,
int bz,
int shm,
void* stream);

// --- From passes/pass_manager.cc ---
// Applies custom graph optimization passes
extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module);
Expand Down Expand Up @@ -99,6 +113,7 @@ void InitCinnInterface(C_DeviceInterface* device_interface) {
metax_cinn_impl.module_unload = MetaxModuleUnload;
metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress;
metax_cinn_impl.launch_kernel = MetaxLaunchKernel;
metax_cinn_impl.launch_cooperative_kernel = MetaxLaunchCooperativeKernel;

// 6. Register Compilation Strategy interface
metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass;
Expand Down
41 changes: 36 additions & 5 deletions backends/metax_gpu/cinn/compiler/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -780,12 +780,43 @@ __device__ inline argidx_fp32_i64 cinn_discrete_reduce_min_argidx_fp32_i64(
CINN_DISCRETE_REDUCE_IMPL(min_argidx_fp32_i64, value);
}

// ===============================================================
// Cross-block barrier (emulates cooperative_groups sync).
// Uses a sense-reversing barrier so it works correctly when called
// multiple times within the same kernel.
// REQUIREMENT: all thread blocks must be co-resident on the GPU
// (i.e. the kernel must be started via a cooperative launch).
// NOTE(review): despite the original "grid-wide" description, the
// barrier state is indexed by blockIdx.x and the arrival target is
// gridDim.y — so this synchronizes only the gridDim.y blocks that
// share a blockIdx.x, not the whole grid when gridDim.x > 1.
// Confirm this per-x-slice scope is intended by the grid-reduce
// lowering.
// ===============================================================
// One counter/flag slot per blockIdx.x. __device__ globals are
// zero-initialized, so count starts at 0 and the initial sense is 0.
// NOTE(review): assumes gridDim.x <= 8192 — no bounds check; confirm
// the compiler never emits larger grids.
__device__ unsigned int __cinn_grid_barrier_count[8192];
__device__ unsigned int __cinn_grid_barrier_flag[8192];

__device__ inline void __cinn_grid_sync() {
// Make this block's prior global-memory writes visible device-wide
// before any block is allowed past the barrier.
__threadfence();
__syncthreads();
// One representative thread per block participates in the barrier.
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
// atomicAdd(ptr, 0u) is used as an atomic read of the current sense.
unsigned int expected =
atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u);
unsigned int arrived =
atomicAdd(&__cinn_grid_barrier_count[blockIdx.x], 1u) + 1u;
if (arrived == (unsigned int)gridDim.y) {
// Last block to arrive: reset the counter for the next use, then
// flip the sense flag to release the waiters. The fence between
// the two stores orders the reset before the release.
atomicExch(&__cinn_grid_barrier_count[blockIdx.x], 0u);
__threadfence();
atomicExch(&__cinn_grid_barrier_flag[blockIdx.x], 1u - expected);
__threadfence();
} else {
// Spin until the sense flag flips away from the value read on
// entry (sense reversal makes back-to-back barriers safe).
while (atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u) ==
expected) {
}
}
}
// Hold the block's remaining threads until the representative thread
// has cleared the barrier.
__syncthreads();
}

// Folds the per-block partial results stored in `mem` (laid out as
// [gridDim.y][spatial_size]) into a single value for `spatial_index`.
// The grid barrier runs first so every block's partial write to `mem`
// is visible before any block starts reading.
// Expects `mem`, `spatial_size`, and `spatial_index` to be in scope at
// the expansion site.
// NOTE: the captured diff had fused the old and new macro bodies,
// declaring `tmp_val` twice (a redefinition error); only the new body
// (barrier + single reduce loop) is kept here.
#define CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, init_value, DTYPE)              \
  __cinn_grid_sync();                                                      \
  DTYPE tmp_val = init_value;                                              \
  for (int y = 0; y < gridDim.y; y++) {                                    \
    tmp_val =                                                              \
        cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \
  }                                                                         \
  return tmp_val;

#define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \
Expand Down
56 changes: 52 additions & 4 deletions backends/metax_gpu/cinn/runtime/cinn_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,14 @@ namespace metax {
// Loads a driver-API module (cubin/ptx/fatbin) from `path`.
// On success stores the CUmodule handle in *mod_out, logs the load to
// stderr, and returns C_SUCCESS; on failure logs the driver error code
// and returns C_FAILED. `dev_ptr` is unused but required by the
// C device-interface callback signature.
C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) {
  CUmodule module;
  CUresult err = cuModuleLoad(&module, path);
  // Single failure check: the captured diff had retained a stale
  // pre-logging `return C_FAILED;` above this branch, which made the
  // error log unreachable and swallowed failures silently.
  if (err != CUDA_SUCCESS) {
    std::cerr << "[MetaxModuleLoad] FAILED to load module from: " << path
              << ", error=" << err << std::endl;
    return C_Status::C_FAILED;
  }
  *mod_out = reinterpret_cast<void*>(module);
  std::cerr << "[MetaxModuleLoad] OK path=" << path << " module=" << module
            << std::endl;
  return C_Status::C_SUCCESS;
}

Expand All @@ -47,9 +52,14 @@ C_Status MetaxGetKernelAddress(void* dev_ptr,
void** func_out) {
CUfunction func;
CUresult err = cuModuleGetFunction(&func, (CUmodule)module_handle, func_name);
if (err != CUDA_SUCCESS) return C_Status::C_FAILED;

if (err != CUDA_SUCCESS) {
std::cerr << "[MetaxGetKernelAddress] FAILED func_name=" << func_name
<< " module=" << module_handle << " error=" << err << std::endl;
return C_Status::C_FAILED;
}
*func_out = reinterpret_cast<void*>(func);
std::cerr << "[MetaxGetKernelAddress] OK func_name=" << func_name
<< " func_ptr=" << func << std::endl;
return C_Status::C_SUCCESS;
}

Expand Down Expand Up @@ -82,6 +92,44 @@ C_Status MetaxLaunchKernel(void* dev_ptr,
return C_Status::C_SUCCESS;
}

// Launch cooperative kernel: uses cuLaunchCooperativeKernel (mapped to
// wcudaLaunchCooperativeKernel -> mcLaunchCooperativeKernel via cu-bridge)
// to guarantee all thread blocks are co-resident on the GPU, which is
// required by cross-block grid_reduce barriers (__cinn_grid_sync).
//
// func_ptr: CUfunction handle from MetaxGetKernelAddress.
// args: array of num_args pointers to kernel argument values (num_args
// itself is not forwarded; the driver consumes the args array directly).
// gx/gy/gz, bx/by/bz: grid and block dimensions; shm: dynamic shared
// memory bytes per block; stream: CUstream the launch is enqueued on.
// Returns C_FAILED (with the driver error code logged) on launch failure.
C_Status MetaxLaunchCooperativeKernel(void* dev_ptr,
                                      void* func_ptr,
                                      void** args,
                                      int num_args,
                                      int gx,
                                      int gy,
                                      int gz,
                                      int bx,
                                      int by,
                                      int bz,
                                      int shm,
                                      void* stream) {
  // Launch trace goes to stderr like the rest of this file's logging.
  // (The personal "YUHAN!!!" stdout debug marker left in the diff has
  // been removed.)
  std::cerr << "[MetaxLaunchCooperativeKernel] func_ptr=" << func_ptr
            << " grid=(" << gx << "," << gy << "," << gz << ")"
            << " block=(" << bx << "," << by << "," << bz << ")"
            << " shm=" << shm << std::endl;
  CUresult err = cuLaunchCooperativeKernel(static_cast<CUfunction>(func_ptr),
                                           gx,
                                           gy,
                                           gz,
                                           bx,
                                           by,
                                           bz,
                                           shm,
                                           static_cast<CUstream>(stream),
                                           args);
  if (err != CUDA_SUCCESS) {
    std::cerr << "[MetaxLaunchCooperativeKernel] FAILED error=" << err
              << std::endl;
    return C_Status::C_FAILED;
  }
  return C_Status::C_SUCCESS;
}

} // namespace metax
} // namespace custom_device
} // namespace paddle
Loading