@@ -285,8 +285,44 @@ struct HashVector {
285285 }
286286};
287287
288+ // X(Name, #Attrs, MinStack, StackEffect)
289+ #define FOREACH_SIZE_OPCODE (X ) \
290+ X (Const, 1 , 0 , 1 ) \
291+ X(KernelArgI32, 1 , 0 , 1 ) \
292+ X(Add, 0 , 2 , -1 ) \
293+ X(Mul, 0 , 2 , -1 )
294+
295+ #define SIZE_OPCODE_ENUM_ENTRY (name, _nattr, _min_st, _stack_eff ) \
296+ name,
297+
298+ enum class SizeOpcode : uint8_t {
299+ FOREACH_SIZE_OPCODE (SIZE_OPCODE_ENUM_ENTRY )
300+ };
301+
302+ #define SIZE_OPCODE_PARSE (name, nattr, min_st, stack_eff ) \
303+ if (!PyUnicode_CompareWithASCIIString(opcode_str, #name)) { \
304+ *num_attrs = nattr; \
305+ *min_stack = min_st; \
306+ *stack_effect = stack_eff; \
307+ return SizeOpcode::name; \
308+ }
309+
310+ static Result<SizeOpcode> size_opcode_parse (PyObject* opcode_str,
311+ int * num_attrs, int * min_stack, int * stack_effect) {
312+ FOREACH_SIZE_OPCODE (SIZE_OPCODE_PARSE );
313+ return raise (PyExc_ValueError, " Invalid opcode string %R" , opcode_str);
314+ }
315+
316+ namespace { struct SizeProgram {
317+ enum { kMaxStackDepth = 32 };
318+
319+ Vec<SizeOpcode> opcodes;
320+ Vec<int64_t > op_attrs;
321+ }; }
322+
288323struct TileKernel {
289324 CudaKernel cukernel;
325+ SizeProgram dyn_smem_size_prog;
290326};
291327
292328using KernelMap = HashMap<Vec<int64_t >, TileKernel>;
@@ -1410,6 +1446,73 @@ struct TileContextDispatcher {
14101446};
14111447
14121448
1449+ static int64_t size_program_eval (const SizeProgram& prog, const Vec<Word>& cuargs) {
1450+ int64_t stack[SizeProgram::kMaxStackDepth ];
1451+ int64_t * top = stack;
1452+ const int64_t * op_attrs = prog.op_attrs .data ();
1453+ for (SizeOpcode opcode : prog.opcodes ) {
1454+ switch (opcode) {
1455+ case SizeOpcode::Const: *top++ = *op_attrs++; break ;
1456+ case SizeOpcode::KernelArgI32: *top++ = cuargs[*op_attrs++].i32 ; break ;
1457+ case SizeOpcode::Add: top[-2 ] += top[-1 ]; --top; break ; // TODO: overflow check?
1458+ case SizeOpcode::Mul: top[-2 ] *= top[-1 ]; --top; break ; // TODO: overflow check?
1459+ }
1460+ }
1461+ return stack[0 ];
1462+ }
1463+
1464+ static Result<SizeProgram> size_program_parse (PyObject* opcodes_pylist, PyObject* attrs_pylist) {
1465+ if (opcodes_pylist == Py_None)
1466+ return SizeProgram{{SizeOpcode::Const}, {0 }};
1467+
1468+ Py_ssize_t num_opcodes = PyList_Size (opcodes_pylist);
1469+ Py_ssize_t num_attrs = PyList_Size (attrs_pylist);
1470+ if (PyErr_Occurred ()) return ErrorRaised;
1471+
1472+ SizeProgram prog;
1473+ Py_ssize_t remaining_attrs = num_attrs;
1474+ int depth = 0 ;
1475+ for (Py_ssize_t i = 0 ; i < num_opcodes; ++i) {
1476+ PyObject* py_opcode = PyList_GetItem (opcodes_pylist, i);
1477+ if (!py_opcode) return ErrorRaised;
1478+
1479+ int opcode_attrs, min_stack, stack_eff;
1480+ Result<SizeOpcode> opcode_res = size_opcode_parse (
1481+ py_opcode, &opcode_attrs, &min_stack, &stack_eff);
1482+ if (!opcode_res.is_ok ()) return ErrorRaised;
1483+
1484+ if (remaining_attrs < opcode_attrs)
1485+ return raise (PyExc_ValueError,
1486+ " Invalid size program (at op #%zd): not enough attributes"
1487+ " for opcode %u (need %d, have %zd)" ,
1488+ i, static_cast <unsigned >(*opcode_res), opcode_attrs, remaining_attrs);
1489+ remaining_attrs -= opcode_attrs;
1490+
1491+ if (depth < min_stack)
1492+ return raise (PyExc_ValueError, " Invalid size program: not enough values on stack" );
1493+ depth += stack_eff;
1494+ if (depth > SizeProgram::kMaxStackDepth )
1495+ return raise (PyExc_ValueError, " Invalid size program: stack overflow" );
1496+
1497+ prog.opcodes .push_back (*opcode_res);
1498+ }
1499+
1500+ if (remaining_attrs != 0 )
1501+ return raise (PyExc_ValueError, " Invalid size program: too many attributes" );
1502+ if (depth != 1 )
1503+ return raise (PyExc_ValueError, " Invalid size program: expected exactly 1 value on stack" );
1504+
1505+ for (Py_ssize_t i = 0 ; i < num_attrs; ++i) {
1506+ PyObject* py_attr = PyList_GetItem (attrs_pylist, i);
1507+ if (!py_attr) return ErrorRaised;
1508+
1509+ prog.op_attrs .push_back (pylong_as<int64_t >(py_attr));
1510+ if (PyErr_Occurred ()) return ErrorRaised;
1511+ }
1512+
1513+ return prog;
1514+ }
1515+
14131516namespace { struct TileDispatcher {
14141517 Vec<bool > constant_arg_flags;
14151518 Vec<bool > int64_index_flags;
@@ -1440,20 +1543,25 @@ static Result<TileKernel> compile(const DriverApi* driver,
14401543 return raise (PyExc_TypeError, " Expected compile() to return a tuple, got %s" ,
14411544 Py_TYPE (compile_result.get ())->tp_name );
14421545
1443- if (PyTuple_GET_SIZE (compile_result.get ()) != 2 )
1444- return raise (PyExc_TypeError, " Expected compile() to return a 2 -tuple, got length %zd" ,
1546+ if (PyTuple_GET_SIZE (compile_result.get ()) != 4 )
1547+ return raise (PyExc_TypeError, " Expected compile() to return a 4 -tuple, got length %zd" ,
14451548 PyTuple_GET_SIZE (compile_result.get ()));
14461549
14471550 PyObject* py_cubin_bytes = PyTuple_GET_ITEM (compile_result.get (), 0 );
14481551 PyObject* py_cufunc_name = PyTuple_GET_ITEM (compile_result.get (), 1 );
1552+ PyObject* py_dyn_smem_size_opcodes = PyTuple_GET_ITEM (compile_result.get (), 2 );
1553+ PyObject* py_dyn_smem_size_opattrs = PyTuple_GET_ITEM (compile_result.get (), 3 );
14491554
1450- if (!PyBytes_Check (py_cubin_bytes) || !PyUnicode_Check (py_cufunc_name))
1555+ if (!PyBytes_Check (py_cubin_bytes)
1556+ || !PyUnicode_Check (py_cufunc_name)
1557+ || (py_dyn_smem_size_opcodes != Py_None && !PyList_Check (py_dyn_smem_size_opcodes))
1558+ || (py_dyn_smem_size_opattrs != Py_None && !PyList_Check (py_dyn_smem_size_opattrs))) {
14511559 return raise (PyExc_TypeError,
1452- " Expected compile() to return (bytes, str),"
1560+ " Expected compile() to return (bytes, str, list|None, list|None ),"
14531561 " got %s, %s" ,
14541562 Py_TYPE (py_cubin_bytes)->tp_name ,
14551563 Py_TYPE (py_cufunc_name)->tp_name );
1456-
1564+ }
14571565
14581566 char * cubin_data;
14591567 Py_ssize_t cubin_size;
@@ -1466,7 +1574,11 @@ static Result<TileKernel> compile(const DriverApi* driver,
14661574 Result<CudaKernel> cukernel = load_cuda_kernel (driver, cubin_data, cubin_size, cufunc_name);
14671575 if (!cukernel.is_ok ()) return ErrorRaised;
14681576
1469- return TileKernel{std::move (*cukernel)};
1577+ Result<SizeProgram> dyn_smem_size_prog = size_program_parse (
1578+ py_dyn_smem_size_opcodes, py_dyn_smem_size_opattrs);
1579+ if (!dyn_smem_size_prog.is_ok ()) return ErrorRaised;
1580+
1581+ return TileKernel{std::move (*cukernel), std::move (*dyn_smem_size_prog)};
14701582}
14711583
14721584static inline bool has_torch_tensor_input (const Vec<PyTypeObject*>& pyarg_types) {
@@ -1645,6 +1757,7 @@ static bool try_clarify_invalid_value_error(const DriverApi* driver, const Grid&
16451757struct PreparedLaunch {
16461758 LaunchHelperPtr helper;
16471759 CUkernel kernel;
1760+ unsigned dynamic_smem_bytes;
16481761};
16491762
16501763static Result<CUcontext> get_stream_context (const DriverApi* driver, CUstream stream) {
@@ -1791,14 +1904,19 @@ static Result<PreparedLaunch> prepare_launch(
17911904 for (Word& arg : helper->cuargs )
17921905 helper->cuarg_pointers .push_back (&arg);
17931906
1794- return PreparedLaunch{std::move (helper), kernel_item->value .cukernel .kernel };
1907+ int64_t dyn_smem_size = size_program_eval (
1908+ kernel_item->value .dyn_smem_size_prog , helper->cuargs );
1909+ if (dyn_smem_size < 0 || dyn_smem_size > UINT_MAX )
1910+ return raise (PyExc_RuntimeError, " Invalid dynamic shared memory size" );
1911+
1912+ return PreparedLaunch{std::move (helper), kernel_item->value .cukernel .kernel ,
1913+ static_cast <unsigned >(dyn_smem_size)};
17951914}
17961915
17971916static Status launch (const DriverApi* driver,
17981917 PyObject* dispatcher_pyobj,
17991918 Grid grid,
18001919 Grid block,
1801- unsigned dynamic_shared_memory_bytes,
18021920 CUstream launch_stream,
18031921 PyObject* const * pyargs,
18041922 Py_ssize_t num_pyargs) {
@@ -1815,7 +1933,7 @@ static Status launch(const DriverApi* driver,
18151933 reinterpret_cast <CUfunction>(prep->kernel ),
18161934 grid.dims [0 ], grid.dims [1 ], grid.dims [2 ],
18171935 block.dims [0 ], block.dims [1 ], block.dims [2 ],
1818- dynamic_shared_memory_bytes ,
1936+ prep-> dynamic_smem_bytes ,
18191937 launch_stream,
18201938 prep->helper ->cuarg_pointers .data (),
18211939 nullptr );
@@ -1926,7 +2044,7 @@ static Result<double> benchmark(const DriverApi* driver,
19262044 kparams.blockDimX = 1 ;
19272045 kparams.blockDimY = 1 ;
19282046 kparams.blockDimZ = 1 ;
1929- kparams.sharedMemBytes = 0 ;
2047+ kparams.sharedMemBytes = pl. dynamic_smem_bytes ;
19302048 kparams.kernelParams = pl.helper ->cuarg_pointers .data ();
19312049 kparams.extra = nullptr ;
19322050 kparams.kern = pl.kernel ;
@@ -2201,8 +2319,7 @@ static Status parse_launch_args(PyObject* const* args, Py_ssize_t nargs, const c
22012319}
22022320
22032321static PyObject* launch_impl (PyObject* const * args, Py_ssize_t nargs,
2204- const char * signature, unsigned dynamic_shared_memory_bytes,
2205- bool with_block) {
2322+ const char * signature, bool with_block) {
22062323 LaunchArgs launch_args;
22072324 if (!parse_launch_args (args, nargs, signature, with_block, &launch_args))
22082325 return nullptr ;
@@ -2211,8 +2328,7 @@ static PyObject* launch_impl(PyObject* const* args, Py_ssize_t nargs,
22112328 if (!driver.is_ok ()) return nullptr ;
22122329
22132330 if (!launch (*driver, launch_args.dispatcher , launch_args.grid , launch_args.block ,
2214- dynamic_shared_memory_bytes, launch_args.stream , launch_args.kernel_args ,
2215- launch_args.num_kernel_args ))
2331+ launch_args.stream , launch_args.kernel_args , launch_args.num_kernel_args ))
22162332 return nullptr ;
22172333
22182334 return Py_NewRef (Py_None);
@@ -2221,47 +2337,13 @@ static PyObject* launch_impl(PyObject* const* args, Py_ssize_t nargs,
22212337#define LAUNCH_SIGNATURE " launch(stream, grid, kernel, kernel_args, /)"
22222338
22232339static PyObject* cuda_tile_launch (PyObject*, PyObject* const * args, Py_ssize_t nargs) {
2224- return launch_impl (args, nargs, LAUNCH_SIGNATURE ,
2225- /* dynamic_shared_memory_bytes=*/ 0 ,
2226- /* with_block=*/ false );
2227- }
2228-
2229- #define LAUNCH_EXTENDED_SIGNATURE " launch(stream, grid, block, kernel, kernel_args, /," \
2230- " *, dynamic_shared_memory_bytes=None)"
2231-
2232- static PyObject* launch_extended (PyObject*, PyObject* const * args, Py_ssize_t nargs,
2233- PyObject* kwnames) {
2234- unsigned dynamic_shared_memory_bytes = 0 ;
2235- if (kwnames) {
2236- Py_ssize_t num_kwargs = PyTuple_GET_SIZE (kwnames);
2237- if (num_kwargs > 1 ) {
2238- return PyErr_Format (PyExc_TypeError,
2239- " Too many keyword arguments to " LAUNCH_EXTENDED_SIGNATURE );
2240- }
2241- if (num_kwargs) {
2242- PyObject* keyword = PyTuple_GET_ITEM (kwnames, 0 );
2243- if (PyUnicode_Compare (keyword, g_dynamic_shared_memory_bytes_pyunicode)) {
2244- return PyErr_Format (
2245- PyExc_TypeError,
2246- " Unexpected keyword argument '%U' to " LAUNCH_EXTENDED_SIGNATURE ,
2247- keyword);
2248- }
2249- PyObject* py_smem_size = args[nargs];
2250- if (py_smem_size != Py_None) {
2251- unsigned long smem_size_ul = PyLong_AsUnsignedLong (py_smem_size);
2252- if (PyErr_Occurred ()) return nullptr ;
2253-
2254- if (smem_size_ul > UINT_MAX ) {
2255- PyErr_SetString (PyExc_OverflowError,
2256- " dynamic_shared_memory_bytes is out of range" );
2257- return nullptr ;
2258- }
2259- dynamic_shared_memory_bytes = static_cast <unsigned >(smem_size_ul);
2260- }
2261- }
2262- }
2263- return launch_impl (args, nargs, LAUNCH_EXTENDED_SIGNATURE , dynamic_shared_memory_bytes,
2264- /* with_block=*/ true );
2340+ return launch_impl (args, nargs, LAUNCH_SIGNATURE , /* with_block=*/ false );
2341+ }
2342+
2343+ #define LAUNCH_EXTENDED_SIGNATURE " launch(stream, grid, block, kernel, kernel_args, /)"
2344+
2345+ static PyObject* launch_extended (PyObject*, PyObject* const * args, Py_ssize_t nargs) {
2346+ return launch_impl (args, nargs, LAUNCH_EXTENDED_SIGNATURE , /* with_block=*/ true );
22652347}
22662348
22672349#define BENCHMARK_SIGNATURE " _benchmark(stream, grid, kernel, pyargs_tuples, /)"
0 commit comments