diff --git a/backends/metax_gpu/cinn/cinn_interface.cc b/backends/metax_gpu/cinn/cinn_interface.cc index a01bd0e67e..332cc99674 100644 --- a/backends/metax_gpu/cinn/cinn_interface.cc +++ b/backends/metax_gpu/cinn/cinn_interface.cc @@ -67,6 +67,20 @@ extern C_Status MetaxLaunchKernel(void* dev_ptr, int shm, void* stream); +// Launches a cooperative kernel function (grid-level sync) +extern C_Status MetaxLaunchCooperativeKernel(void* dev_ptr, + void* func_ptr, + void** args, + int num_args, + int gx, + int gy, + int gz, + int bx, + int by, + int bz, + int shm, + void* stream); + // --- From passes/pass_manager.cc --- // Applies custom graph optimization passes extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module); @@ -99,6 +113,7 @@ void InitCinnInterface(C_DeviceInterface* device_interface) { metax_cinn_impl.module_unload = MetaxModuleUnload; metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress; metax_cinn_impl.launch_kernel = MetaxLaunchKernel; + metax_cinn_impl.launch_cooperative_kernel = MetaxLaunchCooperativeKernel; // 6. Register Compilation Strategy interface metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass; diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index b65f73e6e4..e9ee346c32 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -780,12 +780,43 @@ __device__ inline argidx_fp32_i64 cinn_discrete_reduce_min_argidx_fp32_i64( CINN_DISCRETE_REDUCE_IMPL(min_argidx_fp32_i64, value); } +// =============================================================== +// Grid-wide Barrier (emulates cooperative_groups::this_grid().sync()) +// Uses a sense-reversing barrier so it works correctly when called +// multiple times within the same kernel. +// REQUIREMENT: all thread blocks must be co-resident on the GPU. 
+// =============================================================== +__device__ unsigned int __cinn_grid_barrier_count[8192]; +__device__ unsigned int __cinn_grid_barrier_flag[8192]; + +__device__ inline void __cinn_grid_sync() { + __threadfence(); + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + unsigned int expected = + atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u); + unsigned int arrived = + atomicAdd(&__cinn_grid_barrier_count[blockIdx.x], 1u) + 1u; + if (arrived == (unsigned int)gridDim.y) { + atomicExch(&__cinn_grid_barrier_count[blockIdx.x], 0u); + __threadfence(); + atomicExch(&__cinn_grid_barrier_flag[blockIdx.x], 1u - expected); + __threadfence(); + } else { + while (atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u) == + expected) { + } + } + } + __syncthreads(); +} + #define CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, init_value, DTYPE) \ - DTYPE tmp_val = init_value; \ - for (int y = 0; y < gridDim.y; y++) { \ - tmp_val = \ - cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \ - } \ + __cinn_grid_sync(); \ + DTYPE tmp_val = init_value; \ + for (int y = 0; y < gridDim.y; y++) { \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \ + } \ return tmp_val; #define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \ diff --git a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc index 7f19db35e4..41d493e3df 100644 --- a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc +++ b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc @@ -28,9 +28,14 @@ namespace metax { C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) { CUmodule module; CUresult err = cuModuleLoad(&module, path); - if (err != CUDA_SUCCESS) return C_Status::C_FAILED; - + if (err != CUDA_SUCCESS) { + std::cerr << "[MetaxModuleLoad] FAILED to load module from: " << path + << ", error=" << err << std::endl; + return C_Status::C_FAILED; + } *mod_out = 
reinterpret_cast<void*>(module); + std::cerr << "[MetaxModuleLoad] OK path=" << path << " module=" << module + << std::endl; return C_Status::C_SUCCESS; } @@ -47,9 +52,14 @@ C_Status MetaxGetKernelAddress(void* dev_ptr, void** func_out) { CUfunction func; CUresult err = cuModuleGetFunction(&func, (CUmodule)module_handle, func_name); - if (err != CUDA_SUCCESS) return C_Status::C_FAILED; - + if (err != CUDA_SUCCESS) { + std::cerr << "[MetaxGetKernelAddress] FAILED func_name=" << func_name + << " module=" << module_handle << " error=" << err << std::endl; + return C_Status::C_FAILED; + } *func_out = reinterpret_cast<void*>(func); + std::cerr << "[MetaxGetKernelAddress] OK func_name=" << func_name + << " func_ptr=" << func << std::endl; return C_Status::C_SUCCESS; } @@ -82,6 +92,44 @@ C_Status MetaxLaunchKernel(void* dev_ptr, return C_Status::C_SUCCESS; } +// Launch cooperative kernel: uses cuLaunchCooperativeKernel (mapped to +// wcudaLaunchCooperativeKernel -> mcLaunchCooperativeKernel via cu-bridge) +// to guarantee all thread blocks are co-resident on the GPU, which is +// required by cross-block grid_reduce barriers (__cinn_grid_sync). +C_Status MetaxLaunchCooperativeKernel(void* dev_ptr, + void* func_ptr, + void** args, + int num_args, + int gx, + int gy, + int gz, + int bx, + int by, + int bz, + int shm, + void* stream) { + std::cerr << "[MetaxLaunchCooperativeKernel] func_ptr=" << func_ptr + << " grid=(" << gx << "," << gy << "," << gz << ")" + << " block=(" << bx << "," << by << "," << bz << ")" + << " shm=" << shm << std::endl; + CUresult err = cuLaunchCooperativeKernel(static_cast<CUfunction>(func_ptr), + gx, + gy, + gz, + bx, + by, + bz, + shm, + static_cast<CUstream>(stream), + args); + if (err != CUDA_SUCCESS) { + std::cerr << "[MetaxLaunchCooperativeKernel] FAILED error=" << err + << std::endl; + return C_Status::C_FAILED; + } + return C_Status::C_SUCCESS; +} + } // namespace metax } // namespace custom_device } // namespace paddle