99#include " cuda_loader.h"
1010#include " cuda_helper.h"
1111#include " hash_map.h"
12+ #include " py.h"
1213#include " ref_ptr.h"
1314#include " stream_buffer.h"
1415#include " vec.h"
@@ -30,6 +31,9 @@ static PyObject* g_strides_pyunicode;
3031static PyObject* g___dlpack___pyunicode;
3132static PyObject* g_compile_pyunicode;
3233static PyObject* g_dynamic_shared_memory_bytes_pyunicode;
34+ static PyObject* g_cooperative_pyunicode;
35+ static PyObject* g_cluster_dim_pyunicode;
36+ static PyObject* g_preferred_cluster_dim_pyunicode;
3337
3438static PyTypeObject* g_torch_Tensor_type;
3539static PyTypeObject* g_torch_cuda_Stream_type;
@@ -78,7 +82,6 @@ static PyObject* get_signature_module() {
7882FOREACH_TORCH_DTYPE (DECLARE_TORCH_DTYPE_GLOBAL )
7983
8084
81- static PyTypeObject* g_cupy_ndarray_type;
8285static PyTypeObject* g_cupy_cuda_Stream_type;
8386
8487static PyTypeObject* g_numba_cuda_Stream_type;
@@ -2187,13 +2190,19 @@ static Result<PreparedLaunch> prepare_launch(
21872190 static_cast <unsigned >(dyn_smem_size)};
21882191}
21892192
2193+
2194+ static constexpr unsigned kMaxCUlaunchAttrs = /* CU_LAUNCH_ATTRIBUTE_MAX=*/ 17 ;
2195+
21902196static Status launch (const DriverApi* driver,
21912197 PyObject* dispatcher_pyobj,
21922198 Grid grid,
21932199 Grid block,
21942200 CUstream launch_stream,
2201+ CUlaunchAttribute launch_attrs[kMaxCUlaunchAttrs ],
2202+ unsigned num_attrs,
21952203 PyObject* const * pyargs,
2196- Py_ssize_t num_pyargs) {
2204+ Py_ssize_t num_pyargs
2205+ ) {
21972206 StreamBufferTransaction tx;
21982207 Result<PreparedLaunch> prep = prepare_launch (
21992208 driver, dispatcher_pyobj, launch_stream, pyargs, num_pyargs, tx);
@@ -2203,12 +2212,22 @@ static Status launch(const DriverApi* driver,
22032212 if (!maybe_switch_context (driver, prep->helper ->cuda_context , ctx_guard))
22042213 return ErrorRaised;
22052214
2206- CUresult res = driver->cuLaunchKernel (
2215+ CUlaunchConfig config = {
2216+ .gridDimX = grid.dims [0 ],
2217+ .gridDimY = grid.dims [1 ],
2218+ .gridDimZ = grid.dims [2 ],
2219+ .blockDimX = block.dims [0 ],
2220+ .blockDimY = block.dims [1 ],
2221+ .blockDimZ = block.dims [2 ],
2222+ .sharedMemBytes = prep->dynamic_smem_bytes ,
2223+ .hStream = launch_stream,
2224+ .attrs = launch_attrs,
2225+ .numAttrs = num_attrs,
2226+ };
2227+
2228+ CUresult res = driver->cuLaunchKernelEx (
2229+ &config,
22072230 reinterpret_cast <CUfunction>(prep->kernel ),
2208- grid.dims [0 ], grid.dims [1 ], grid.dims [2 ],
2209- block.dims [0 ], block.dims [1 ], block.dims [2 ],
2210- prep->dynamic_smem_bytes ,
2211- launch_stream,
22122231 reinterpret_cast <void **>(prep->helper ->cuarg_pointers .data ()),
22132232 nullptr );
22142233
@@ -2546,6 +2565,75 @@ struct LaunchArgs {
25462565 Py_ssize_t num_kernel_args;
25472566};
25482567
2568+ // Parse extra keyword arguments accepted by the extended launch api into
2569+ // launch attributes.
2570+ static Result<unsigned > parse_launch_kwargs (PyObject *const *args,
2571+ Py_ssize_t nargs, PyObject *kwargs,
2572+ CUlaunchAttribute launch_attrs[kMaxCUlaunchAttrs ]) {
2573+ if (kwargs == nullptr )
2574+ return 0 ;
2575+
2576+ CHECK (PyTuple_Check (kwargs) &&
2577+ " Keyword argument tuple is nonnull and not a tuple" );
2578+
2579+ const auto nkwargs = PyTuple_GET_SIZE (kwargs);
2580+ bool has_cluster_dim = false , has_preferred_cluster_dim = false ;
2581+ size_t num_attrs = 0 ;
2582+
2583+ for (Py_ssize_t i = 0 ; i < nkwargs; i++) {
2584+ PyObject *keyword = PyTuple_GET_ITEM (kwargs, i);
2585+ PyObject *kwarg = args[nargs + i];
2586+ CHECK (keyword && kwarg);
2587+ if (PyUnicode_Compare (keyword, g_cooperative_pyunicode) == 0 ) {
2588+ if (!PyBool_Check (kwarg))
2589+ return raise (PyExc_TypeError,
2590+ " expected argument %U to have type bool" , keyword);
2591+ CUlaunchAttribute *attr = &launch_attrs[num_attrs++];
2592+ attr->id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE ;
2593+ attr->value .cooperative = Py_IsTrue (kwarg);
2594+ } else if (PyUnicode_Compare (keyword, g_cluster_dim_pyunicode) == 0 ) {
2595+ if (Py_IsNone (kwarg))
2596+ continue ;
2597+ const auto grid = parse_grid (kwarg);
2598+ if (!grid.is_ok ())
2599+ return ErrorRaised;
2600+ const auto &dims = grid->dims ;
2601+ CUlaunchAttribute *attr = &launch_attrs[num_attrs++];
2602+ attr->id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION ;
2603+ attr->value .clusterDim = {.x = dims[0 ], .y = dims[1 ], .z = dims[2 ]};
2604+ has_cluster_dim = true ;
2605+ } else if (PyUnicode_Compare (keyword, g_preferred_cluster_dim_pyunicode) ==
2606+ 0 ) {
2607+ if (Py_IsNone (kwarg))
2608+ continue ;
2609+ const auto grid = parse_grid (kwarg);
2610+ if (!grid.is_ok ())
2611+ return ErrorRaised;
2612+ const auto &dims = grid->dims ;
2613+ CUlaunchAttribute *attr = &launch_attrs[num_attrs++];
2614+ attr->id = CU_LAUNCH_ATTRIBUTE_PREFERRED_CLUSTER_DIMENSION ;
2615+ attr->value .preferredClusterDim = {
2616+ .x = dims[0 ], .y = dims[1 ], .z = dims[2 ]};
2617+ has_preferred_cluster_dim = true ;
2618+ } else {
2619+ return raise (PyExc_RuntimeError, " Unexpected keyword argument %U" ,
2620+ keyword);
2621+ }
2622+ }
2623+
2624+ // ctk docs say: "This attribute will only take effect when a regular
2625+ // cluster dimension has been specified." We could technically allow it, but
2626+ // the user likely made a mistake if preferred dims were passed and
2627+ // "regular" dims were not.
2628+ if (has_preferred_cluster_dim && !has_cluster_dim)
2629+ return raise (PyExc_ValueError,
2630+ " Keyword argument %U requires that %U is also passed" ,
2631+ g_preferred_cluster_dim_pyunicode,
2632+ g_cluster_dim_pyunicode);
2633+
2634+ return num_attrs;
2635+ }
2636+
25492637static Status parse_launch_args (PyObject* const * args, Py_ssize_t nargs, const char * signature,
25502638 bool with_block, LaunchArgs* out) {
25512639 if (nargs != 4 + with_block)
@@ -2593,16 +2681,23 @@ static Status parse_launch_args(PyObject* const* args, Py_ssize_t nargs, const c
25932681}
25942682
25952683static PyObject* launch_impl (PyObject* const * args, Py_ssize_t nargs,
2596- const char * signature, bool with_block) {
2684+ PyObject* kwargs, const char * signature, bool with_block
2685+ ) {
25972686 LaunchArgs launch_args;
25982687 if (!parse_launch_args (args, nargs, signature, with_block, &launch_args))
25992688 return nullptr ;
26002689
2690+ CUlaunchAttribute launch_attrs[kMaxCUlaunchAttrs ];
2691+ const auto num_attrs = parse_launch_kwargs (args, nargs, kwargs, launch_attrs);
2692+ if (!num_attrs.is_ok ())
2693+ return nullptr ;
2694+
26012695 Result<const DriverApi*> driver = get_driver_api ();
26022696 if (!driver.is_ok ()) return nullptr ;
26032697
2604- if (!launch (*driver, launch_args.dispatcher , launch_args.grid , launch_args.block ,
2605- launch_args.stream , launch_args.kernel_args , launch_args.num_kernel_args ))
2698+ if (!launch (*driver, launch_args.dispatcher , launch_args.grid ,
2699+ launch_args.block , launch_args.stream , launch_attrs, *num_attrs,
2700+ launch_args.kernel_args , launch_args.num_kernel_args ))
26062701 return nullptr ;
26072702
26082703 return Py_NewRef (Py_None);
@@ -2611,13 +2706,18 @@ static PyObject* launch_impl(PyObject* const* args, Py_ssize_t nargs,
26112706#define LAUNCH_SIGNATURE " launch(stream, grid, kernel, kernel_args, /)"
26122707
26132708static PyObject* cuda_tile_launch (PyObject*, PyObject* const * args, Py_ssize_t nargs) {
2614- return launch_impl (args, nargs, LAUNCH_SIGNATURE , /* with_block=*/ false );
2709+ return launch_impl (args, nargs, nullptr , LAUNCH_SIGNATURE ,
2710+ /* with_block=*/ false );
26152711}
26162712
2617- #define LAUNCH_EXTENDED_SIGNATURE " launch(stream, grid, block, kernel, kernel_args, /)"
2713+ #define LAUNCH_EXTENDED_SIGNATURE \
2714+ " launch(stream, grid, block, kernel, kernel_args, /, *, " \
2715+ " cooperative=False, cluster_dim=None, preferred_cluster_dim=None)"
26182716
2619- static PyObject* launch_extended (PyObject*, PyObject* const * args, Py_ssize_t nargs) {
2620- return launch_impl (args, nargs, LAUNCH_EXTENDED_SIGNATURE , /* with_block=*/ true );
2717+ static PyObject *launch_extended (PyObject *, PyObject *const *args,
2718+ Py_ssize_t nargs, PyObject *kwargs) {
2719+ return launch_impl (args, nargs, kwargs, LAUNCH_EXTENDED_SIGNATURE ,
2720+ /* with_block=*/ true );
26212721}
26222722
26232723#define BENCHMARK_SIGNATURE " _benchmark(stream, grid, kernel, pyargs_tuples, /)"
@@ -2696,12 +2796,6 @@ static void try_get_cupy_globals() {
26962796 PyPtr cupy = try_import (" cupy" );
26972797 if (!cupy) return ;
26982798
2699- // Save a reference to cupy.ndarray
2700- if (PyPtr cupy_ndarray = try_getattr (cupy, " ndarray" )) {
2701- if (PyType_Check (cupy_ndarray.get ()))
2702- g_cupy_ndarray_type = reinterpret_cast <PyTypeObject*>(cupy_ndarray.release ());
2703- }
2704-
27052799 // Save references to cupy.cuda.Stream
27062800 if (PyPtr cupy_cuda = try_getattr (cupy, " cuda" )) {
27072801 if (PyPtr cupy_cuda_Stream = try_getattr (cupy_cuda, " Stream" )) {
@@ -2789,6 +2883,9 @@ Status tile_kernel_init(PyObject* m) {
27892883 INIT_STRING_CONSTANT (__dlpack__);
27902884 INIT_STRING_CONSTANT (compile);
27912885 INIT_STRING_CONSTANT (dynamic_shared_memory_bytes);
2886+ INIT_STRING_CONSTANT (cooperative);
2887+ INIT_STRING_CONSTANT (cluster_dim);
2888+ INIT_STRING_CONSTANT (preferred_cluster_dim);
27922889
27932890 g_stream_buffer_pool_by_ctx_id = new StreamBufferPoolMap ();
27942891
0 commit comments