@@ -28,9 +28,14 @@ namespace metax {
2828C_Status MetaxModuleLoad (void * dev_ptr, const char * path, void ** mod_out) {
2929 CUmodule module ;
3030 CUresult err = cuModuleLoad (&module , path);
31- if (err != CUDA_SUCCESS) return C_Status::C_FAILED;
32-
31+ if (err != CUDA_SUCCESS) {
32+ std::cerr << " [MetaxModuleLoad] FAILED to load module from: " << path
33+ << " , error=" << err << std::endl;
34+ return C_Status::C_FAILED;
35+ }
3336 *mod_out = reinterpret_cast <void *>(module );
37+ std::cerr << " [MetaxModuleLoad] OK path=" << path << " module=" << module
38+ << std::endl;
3439 return C_Status::C_SUCCESS;
3540}
3641
@@ -47,9 +52,14 @@ C_Status MetaxGetKernelAddress(void* dev_ptr,
4752 void ** func_out) {
4853 CUfunction func;
4954 CUresult err = cuModuleGetFunction (&func, (CUmodule)module_handle, func_name);
50- if (err != CUDA_SUCCESS) return C_Status::C_FAILED;
51-
55+ if (err != CUDA_SUCCESS) {
56+ std::cerr << " [MetaxGetKernelAddress] FAILED func_name=" << func_name
57+ << " module=" << module_handle << " error=" << err << std::endl;
58+ return C_Status::C_FAILED;
59+ }
5260 *func_out = reinterpret_cast <void *>(func);
61+ std::cerr << " [MetaxGetKernelAddress] OK func_name=" << func_name
62+ << " func_ptr=" << func << std::endl;
5363 return C_Status::C_SUCCESS;
5464}
5565
@@ -82,7 +92,10 @@ C_Status MetaxLaunchKernel(void* dev_ptr,
8292 return C_Status::C_SUCCESS;
8393}
8494
85- // Launch cooperative kernel: equivalent to cuLaunchCooperativeKernel
95+ // Launch cooperative kernel: uses cuLaunchCooperativeKernel (mapped to
96+ // wcudaLaunchCooperativeKernel -> mcLaunchCooperativeKernel via cu-bridge)
97+ // to guarantee all thread blocks are co-resident on the GPU, which is
98+ // required by cross-block grid_reduce barriers (__cinn_grid_sync).
8699C_Status MetaxLaunchCooperativeKernel (void * dev_ptr,
87100 void * func_ptr,
88101 void ** args,
@@ -95,17 +108,25 @@ C_Status MetaxLaunchCooperativeKernel(void* dev_ptr,
95108 int bz,
96109 int shm,
97110 void * stream) {
98- CUresult err = cuLaunchCooperativeKernel ((CUfunction)func_ptr,
111+ std::cout << " YUHAN!!! [MetaxLaunchCooperativeKernel] func_ptr=" << func_ptr
112+ << " grid=(" << gx << " ," << gy << " ," << gz << " )"
113+ << " block=(" << bx << " ," << by << " ," << bz << " )"
114+ << " shm=" << shm << std::endl;
115+ CUresult err = cuLaunchCooperativeKernel (static_cast <CUfunction>(func_ptr),
99116 gx,
100117 gy,
101118 gz,
102119 bx,
103120 by,
104121 bz,
105122 shm,
106- ( CUstream) stream,
123+ static_cast < CUstream>( stream) ,
107124 args);
108- if (err != CUDA_SUCCESS) return C_Status::C_FAILED;
125+ if (err != CUDA_SUCCESS) {
126+ std::cerr << " [MetaxLaunchCooperativeKernel] FAILED error=" << err
127+ << std::endl;
128+ return C_Status::C_FAILED;
129+ }
109130 return C_Status::C_SUCCESS;
110131}
111132
0 commit comments