Skip to content

Commit 3965483

Browse files
committed
[MetaX] Implement cooperative kernel launch for CINN plugin
1 parent 879edf4 commit 3965483

2 files changed

Lines changed: 42 additions & 0 deletions

File tree

backends/metax_gpu/cinn/cinn_interface.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,20 @@ extern C_Status MetaxLaunchKernel(void* dev_ptr,
6767
int shm,
6868
void* stream);
6969

70+
// Launches a cooperative kernel function (grid-level sync)
71+
extern C_Status MetaxLaunchCooperativeKernel(void* dev_ptr,
72+
void* func_ptr,
73+
void** args,
74+
int num_args,
75+
int gx,
76+
int gy,
77+
int gz,
78+
int bx,
79+
int by,
80+
int bz,
81+
int shm,
82+
void* stream);
83+
7084
// --- From passes/pass_manager.cc ---
7185
// Applies custom graph optimization passes.
// Both arguments are passed as opaque void* handles and the outcome is
// reported through the returned C_Status. This hook is what
// InitCinnInterface installs as `apply_custom_pass`.
// NOTE(review): the concrete types behind dev_ptr / ir_module are not visible
// from this declaration - confirm against passes/pass_manager.cc.
7286
extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module);
@@ -99,6 +113,7 @@ void InitCinnInterface(C_DeviceInterface* device_interface) {
99113
metax_cinn_impl.module_unload = MetaxModuleUnload;
100114
metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress;
101115
metax_cinn_impl.launch_kernel = MetaxLaunchKernel;
116+
metax_cinn_impl.launch_cooperative_kernel = MetaxLaunchCooperativeKernel;
102117

103118
// 6. Register Compilation Strategy interface
104119
metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass;

backends/metax_gpu/cinn/runtime/cinn_runtime.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,33 @@ C_Status MetaxLaunchKernel(void* dev_ptr,
8282
return C_Status::C_SUCCESS;
8383
}
8484

85+
// Launch cooperative kernel: equivalent to cuLaunchCooperativeKernel
86+
C_Status MetaxLaunchCooperativeKernel(void* dev_ptr,
87+
void* func_ptr,
88+
void** args,
89+
int num_args,
90+
int gx,
91+
int gy,
92+
int gz,
93+
int bx,
94+
int by,
95+
int bz,
96+
int shm,
97+
void* stream) {
98+
CUresult err = cuLaunchCooperativeKernel((CUfunction)func_ptr,
99+
gx,
100+
gy,
101+
gz,
102+
bx,
103+
by,
104+
bz,
105+
shm,
106+
(CUstream)stream,
107+
args);
108+
if (err != CUDA_SUCCESS) return C_Status::C_FAILED;
109+
return C_Status::C_SUCCESS;
110+
}
111+
85112
} // namespace metax
86113
} // namespace custom_device
87114
} // namespace paddle

0 commit comments

Comments
 (0)