Skip to content

Commit ee6bda7

Browse files
[lang] Expose cluster dims and coop launch to extended launch api
This is where I got the API version; please let me know if this is not right:

    rg PFN_cuLaunchKernelEx /proj/cuda/13.1/Linux_x86_64/include/cudaTypedefs.h
    *PFN_cuLaunchKernelEx_v11060

Signed-off-by: Asher Mancinelli <amancinelli@nvidia.com>
1 parent cfd9f5f commit ee6bda7

10 files changed

Lines changed: 348 additions & 85 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ docs/source/stubs
2525
/internal
2626
.cursor/rules/
2727
*.pyd
28+
/compile_commands.json

cext/cuda_helper.cpp

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -143,62 +143,54 @@ PyObject* destroy_stream(PyObject* self, PyObject* arg) {
143143
Py_RETURN_NONE;
144144
}
145145

146-
static decltype(cuLaunchKernel)* g_real_cuLaunchKernel;
147-
static PyObject* g_cuLaunchKernel_spy_callback;
146+
static decltype(cuLaunchKernelEx)* g_real_cuLaunchKernelEx;
147+
static PyObject* g_cuLaunchKernelEx_spy_callback;
148148

149-
static CUresult shim_cuLaunchKernel(
149+
static CUresult shim_cuLaunchKernelEx(
150+
const CUlaunchConfig *config,
150151
CUfunction f,
151-
unsigned int gridDimX,
152-
unsigned int gridDimY,
153-
unsigned int gridDimZ,
154-
unsigned int blockDimX,
155-
unsigned int blockDimY,
156-
unsigned int blockDimZ,
157-
unsigned int sharedMemBytes,
158-
CUstream hStream,
159152
void** kernelParams,
160153
void** extra) {
161154

162155
PyPtr res = steal(PyObject_CallFunction(
163-
g_cuLaunchKernel_spy_callback,
156+
g_cuLaunchKernelEx_spy_callback,
164157
"(K III III I K)",
165158
reinterpret_cast<unsigned long long>(f),
166-
gridDimX, gridDimY, gridDimZ,
167-
blockDimX, blockDimY, blockDimZ,
168-
sharedMemBytes,
169-
reinterpret_cast<unsigned long long>(hStream)
159+
config->gridDimX, config->gridDimY, config->gridDimZ,
160+
config->blockDimX, config->blockDimY, config->blockDimZ,
161+
config->sharedMemBytes,
162+
reinterpret_cast<unsigned long long>(config->hStream)
170163
));
171164
if (!res) return CUDA_ERROR_LAUNCH_FAILED;
172165

173-
return g_real_cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ,
174-
sharedMemBytes, hStream, kernelParams, extra);
166+
return g_real_cuLaunchKernelEx(config, f, kernelParams, extra);
175167
}
176168

177169
static PyObject* spy_on_cuLaunchKernel_begin(PyObject* self, PyObject* arg) {
178-
if (g_real_cuLaunchKernel)
170+
if (g_real_cuLaunchKernelEx)
179171
return PyErr_Format(PyExc_RuntimeError, "Already spying");
180172

181173
Result<const DriverApi*> driver_result = get_driver_api();
182174
if (!driver_result.is_ok()) return nullptr;
183175

184176
DriverApi* api = const_cast<DriverApi*>(*driver_result);
185-
g_real_cuLaunchKernel = api->cuLaunchKernel;
186-
g_cuLaunchKernel_spy_callback = Py_NewRef(arg);
187-
api->cuLaunchKernel = shim_cuLaunchKernel;
177+
g_real_cuLaunchKernelEx = api->cuLaunchKernelEx;
178+
g_cuLaunchKernelEx_spy_callback = Py_NewRef(arg);
179+
api->cuLaunchKernelEx = shim_cuLaunchKernelEx;
188180
return Py_NewRef(Py_None);
189181
}
190182

191183
static PyObject* spy_on_cuLaunchKernel_end(PyObject* self, PyObject* arg) {
192-
if (!g_real_cuLaunchKernel)
184+
if (!g_real_cuLaunchKernelEx)
193185
return PyErr_Format(PyExc_RuntimeError, "Not spying");
194186

195187
Result<const DriverApi*> driver_result = get_driver_api();
196188
if (!driver_result.is_ok()) return nullptr;
197189

198190
DriverApi* api = const_cast<DriverApi*>(*driver_result);
199-
api->cuLaunchKernel = g_real_cuLaunchKernel;
200-
g_real_cuLaunchKernel = nullptr;
201-
Py_CLEAR(g_cuLaunchKernel_spy_callback);
191+
api->cuLaunchKernelEx = g_real_cuLaunchKernelEx;
192+
g_real_cuLaunchKernelEx = nullptr;
193+
Py_CLEAR(g_cuLaunchKernelEx_spy_callback);
202194
return Py_NewRef(Py_None);
203195
}
204196

cext/cuda_loader.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
X(cuLibraryGetKernel, 12000) \
1717
X(cuGetErrorString, 6000) \
1818
X(cuLaunchKernel, 7000) \
19+
X(cuLaunchKernelEx, 11060) \
1920
X(cuPointerGetAttribute, 4000) \
2021
X(cuCtxSynchronize, 2000) \
2122
X(cuCtxPushCurrent, 4000) \

cext/tile_kernel.cpp

Lines changed: 117 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "cuda_loader.h"
1010
#include "cuda_helper.h"
1111
#include "hash_map.h"
12+
#include "py.h"
1213
#include "ref_ptr.h"
1314
#include "stream_buffer.h"
1415
#include "vec.h"
@@ -30,6 +31,9 @@ static PyObject* g_strides_pyunicode;
3031
static PyObject* g___dlpack___pyunicode;
3132
static PyObject* g_compile_pyunicode;
3233
static PyObject* g_dynamic_shared_memory_bytes_pyunicode;
34+
static PyObject* g_cooperative_pyunicode;
35+
static PyObject* g_cluster_dim_pyunicode;
36+
static PyObject* g_preferred_cluster_dim_pyunicode;
3337

3438
static PyTypeObject* g_torch_Tensor_type;
3539
static PyTypeObject* g_torch_cuda_Stream_type;
@@ -78,7 +82,6 @@ static PyObject* get_signature_module() {
7882
FOREACH_TORCH_DTYPE(DECLARE_TORCH_DTYPE_GLOBAL)
7983

8084

81-
static PyTypeObject* g_cupy_ndarray_type;
8285
static PyTypeObject* g_cupy_cuda_Stream_type;
8386

8487
static PyTypeObject* g_numba_cuda_Stream_type;
@@ -2187,13 +2190,19 @@ static Result<PreparedLaunch> prepare_launch(
21872190
static_cast<unsigned>(dyn_smem_size)};
21882191
}
21892192

2193+
2194+
static constexpr unsigned kMaxCUlaunchAttrs = /*CU_LAUNCH_ATTRIBUTE_MAX=*/17;
2195+
21902196
static Status launch(const DriverApi* driver,
21912197
PyObject* dispatcher_pyobj,
21922198
Grid grid,
21932199
Grid block,
21942200
CUstream launch_stream,
2201+
CUlaunchAttribute launch_attrs[kMaxCUlaunchAttrs],
2202+
unsigned num_attrs,
21952203
PyObject* const* pyargs,
2196-
Py_ssize_t num_pyargs) {
2204+
Py_ssize_t num_pyargs
2205+
) {
21972206
StreamBufferTransaction tx;
21982207
Result<PreparedLaunch> prep = prepare_launch(
21992208
driver, dispatcher_pyobj, launch_stream, pyargs, num_pyargs, tx);
@@ -2203,12 +2212,22 @@ static Status launch(const DriverApi* driver,
22032212
if (!maybe_switch_context(driver, prep->helper->cuda_context, ctx_guard))
22042213
return ErrorRaised;
22052214

2206-
CUresult res = driver->cuLaunchKernel(
2215+
CUlaunchConfig config = {
2216+
.gridDimX = grid.dims[0],
2217+
.gridDimY = grid.dims[1],
2218+
.gridDimZ = grid.dims[2],
2219+
.blockDimX = block.dims[0],
2220+
.blockDimY = block.dims[1],
2221+
.blockDimZ = block.dims[2],
2222+
.sharedMemBytes = prep->dynamic_smem_bytes,
2223+
.hStream = launch_stream,
2224+
.attrs = launch_attrs,
2225+
.numAttrs = num_attrs,
2226+
};
2227+
2228+
CUresult res = driver->cuLaunchKernelEx(
2229+
&config,
22072230
reinterpret_cast<CUfunction>(prep->kernel),
2208-
grid.dims[0], grid.dims[1], grid.dims[2],
2209-
block.dims[0], block.dims[1], block.dims[2],
2210-
prep->dynamic_smem_bytes,
2211-
launch_stream,
22122231
reinterpret_cast<void**>(prep->helper->cuarg_pointers.data()),
22132232
nullptr);
22142233

@@ -2546,6 +2565,75 @@ struct LaunchArgs {
25462565
Py_ssize_t num_kernel_args;
25472566
};
25482567

2568+
// Parse extra keyword arguments accepted by the extended launch api into
2569+
// launch attributes.
2570+
static Result<unsigned> parse_launch_kwargs(PyObject *const *args,
2571+
Py_ssize_t nargs, PyObject *kwargs,
2572+
CUlaunchAttribute launch_attrs[kMaxCUlaunchAttrs]) {
2573+
if (kwargs == nullptr)
2574+
return 0;
2575+
2576+
CHECK(PyTuple_Check(kwargs) &&
2577+
"Keyword argument tuple is nonnull and not a tuple");
2578+
2579+
const auto nkwargs = PyTuple_GET_SIZE(kwargs);
2580+
bool has_cluster_dim = false, has_preferred_cluster_dim = false;
2581+
size_t num_attrs = 0;
2582+
2583+
for (Py_ssize_t i = 0; i < nkwargs; i++) {
2584+
PyObject *keyword = PyTuple_GET_ITEM(kwargs, i);
2585+
PyObject *kwarg = args[nargs + i];
2586+
CHECK(keyword && kwarg);
2587+
if (PyUnicode_Compare(keyword, g_cooperative_pyunicode) == 0) {
2588+
if (!PyBool_Check(kwarg))
2589+
return raise(PyExc_TypeError,
2590+
"expected argument %U to have type bool", keyword);
2591+
CUlaunchAttribute *attr = &launch_attrs[num_attrs++];
2592+
attr->id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
2593+
attr->value.cooperative = Py_IsTrue(kwarg);
2594+
} else if (PyUnicode_Compare(keyword, g_cluster_dim_pyunicode) == 0) {
2595+
if (Py_IsNone(kwarg))
2596+
continue;
2597+
const auto grid = parse_grid(kwarg);
2598+
if (!grid.is_ok())
2599+
return ErrorRaised;
2600+
const auto &dims = grid->dims;
2601+
CUlaunchAttribute *attr = &launch_attrs[num_attrs++];
2602+
attr->id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
2603+
attr->value.clusterDim = {.x = dims[0], .y = dims[1], .z = dims[2]};
2604+
has_cluster_dim = true;
2605+
} else if (PyUnicode_Compare(keyword, g_preferred_cluster_dim_pyunicode) ==
2606+
0) {
2607+
if (Py_IsNone(kwarg))
2608+
continue;
2609+
const auto grid = parse_grid(kwarg);
2610+
if (!grid.is_ok())
2611+
return ErrorRaised;
2612+
const auto &dims = grid->dims;
2613+
CUlaunchAttribute *attr = &launch_attrs[num_attrs++];
2614+
attr->id = CU_LAUNCH_ATTRIBUTE_PREFERRED_CLUSTER_DIMENSION;
2615+
attr->value.preferredClusterDim = {
2616+
.x = dims[0], .y = dims[1], .z = dims[2]};
2617+
has_preferred_cluster_dim = true;
2618+
} else {
2619+
return raise(PyExc_RuntimeError, "Unexpected keyword argument %U",
2620+
keyword);
2621+
}
2622+
}
2623+
2624+
// ctk docs say: "This attribute will only take effect when a regular
2625+
// cluster dimension has been specified." We could technically allow it, but
2626+
// the user likely made a mistake if preferred dims were passed and
2627+
// "regular" dims were not.
2628+
if (has_preferred_cluster_dim && !has_cluster_dim)
2629+
return raise(PyExc_ValueError,
2630+
"Keyword argument %U requires that %U is also passed",
2631+
g_preferred_cluster_dim_pyunicode,
2632+
g_cluster_dim_pyunicode);
2633+
2634+
return num_attrs;
2635+
}
2636+
25492637
static Status parse_launch_args(PyObject* const* args, Py_ssize_t nargs, const char* signature,
25502638
bool with_block, LaunchArgs* out) {
25512639
if (nargs != 4 + with_block)
@@ -2593,16 +2681,23 @@ static Status parse_launch_args(PyObject* const* args, Py_ssize_t nargs, const c
25932681
}
25942682

25952683
static PyObject* launch_impl(PyObject* const* args, Py_ssize_t nargs,
2596-
const char* signature, bool with_block) {
2684+
PyObject* kwargs, const char* signature, bool with_block
2685+
) {
25972686
LaunchArgs launch_args;
25982687
if (!parse_launch_args(args, nargs, signature, with_block, &launch_args))
25992688
return nullptr;
26002689

2690+
CUlaunchAttribute launch_attrs[kMaxCUlaunchAttrs];
2691+
const auto num_attrs = parse_launch_kwargs(args, nargs, kwargs, launch_attrs);
2692+
if (!num_attrs.is_ok())
2693+
return nullptr;
2694+
26012695
Result<const DriverApi*> driver = get_driver_api();
26022696
if (!driver.is_ok()) return nullptr;
26032697

2604-
if (!launch(*driver, launch_args.dispatcher, launch_args.grid, launch_args.block,
2605-
launch_args.stream, launch_args.kernel_args, launch_args.num_kernel_args))
2698+
if (!launch(*driver, launch_args.dispatcher, launch_args.grid,
2699+
launch_args.block, launch_args.stream, launch_attrs, *num_attrs,
2700+
launch_args.kernel_args, launch_args.num_kernel_args))
26062701
return nullptr;
26072702

26082703
return Py_NewRef(Py_None);
@@ -2611,13 +2706,18 @@ static PyObject* launch_impl(PyObject* const* args, Py_ssize_t nargs,
26112706
#define LAUNCH_SIGNATURE "launch(stream, grid, kernel, kernel_args, /)"
26122707

26132708
static PyObject* cuda_tile_launch(PyObject*, PyObject* const* args, Py_ssize_t nargs) {
2614-
return launch_impl(args, nargs, LAUNCH_SIGNATURE, /*with_block=*/ false);
2709+
return launch_impl(args, nargs, nullptr, LAUNCH_SIGNATURE,
2710+
/*with_block=*/false);
26152711
}
26162712

2617-
#define LAUNCH_EXTENDED_SIGNATURE "launch(stream, grid, block, kernel, kernel_args, /)"
2713+
#define LAUNCH_EXTENDED_SIGNATURE \
2714+
"launch(stream, grid, block, kernel, kernel_args, /, *, " \
2715+
"cooperative=False, cluster_dim=None, preferred_cluster_dim=None)"
26182716

2619-
static PyObject* launch_extended(PyObject*, PyObject* const* args, Py_ssize_t nargs) {
2620-
return launch_impl(args, nargs, LAUNCH_EXTENDED_SIGNATURE, /*with_block=*/ true);
2717+
static PyObject *launch_extended(PyObject *, PyObject *const *args,
2718+
Py_ssize_t nargs, PyObject *kwargs) {
2719+
return launch_impl(args, nargs, kwargs, LAUNCH_EXTENDED_SIGNATURE,
2720+
/*with_block=*/true);
26212721
}
26222722

26232723
#define BENCHMARK_SIGNATURE "_benchmark(stream, grid, kernel, pyargs_tuples, /)"
@@ -2696,12 +2796,6 @@ static void try_get_cupy_globals() {
26962796
PyPtr cupy = try_import("cupy");
26972797
if (!cupy) return;
26982798

2699-
// Save a reference to cupy.ndarray
2700-
if (PyPtr cupy_ndarray = try_getattr(cupy, "ndarray")) {
2701-
if (PyType_Check(cupy_ndarray.get()))
2702-
g_cupy_ndarray_type = reinterpret_cast<PyTypeObject*>(cupy_ndarray.release());
2703-
}
2704-
27052799
// Save references to cupy.cuda.Stream
27062800
if (PyPtr cupy_cuda = try_getattr(cupy, "cuda")) {
27072801
if (PyPtr cupy_cuda_Stream = try_getattr(cupy_cuda, "Stream")) {
@@ -2789,6 +2883,9 @@ Status tile_kernel_init(PyObject* m) {
27892883
INIT_STRING_CONSTANT(__dlpack__);
27902884
INIT_STRING_CONSTANT(compile);
27912885
INIT_STRING_CONSTANT(dynamic_shared_memory_bytes);
2886+
INIT_STRING_CONSTANT(cooperative);
2887+
INIT_STRING_CONSTANT(cluster_dim);
2888+
INIT_STRING_CONSTANT(preferred_cluster_dim);
27922889

27932890
g_stream_buffer_pool_by_ctx_id = new StreamBufferPoolMap();
27942891

experimental/cuda-lang/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def run(self):
3838
build_dir = os.getenv("CUDA_TILE_CEXT_BUILD_DIR")
3939
if build_dir is None:
4040
build_dir = guess_cuda_tile_build_dir()
41-
self.spawn(["make", "-C", build_dir])
41+
self.spawn(["cmake", "--build", build_dir])
4242

4343
binary_name = "mlir2cubin"
4444
src_path = os.path.join(build_dir, "internal", "mlir2cubin", binary_name)

0 commit comments

Comments
 (0)