Skip to content

Commit 9784de2

Browse files
committed
Compute dynamic shared memory size automatically
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 761053c commit 9784de2

6 files changed

Lines changed: 208 additions & 62 deletions

File tree

cext/cuda_helper.cpp

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -143,6 +143,65 @@ PyObject* destroy_stream(PyObject* self, PyObject* arg) {
143143
Py_RETURN_NONE;
144144
}
145145

146+
static decltype(cuLaunchKernel)* g_real_cuLaunchKernel;
147+
static PyObject* g_cuLaunchKernel_spy_callback;
148+
149+
// Stand-in for the driver's cuLaunchKernel that spy_on_cuLaunchKernel_begin() installs.
//
// Reports the launch configuration to the registered Python callback as
// (function handle, grid x/y/z, block x/y/z, dynamic shared memory bytes, stream handle),
// with both handles converted to integers, and then forwards the call unchanged to the
// saved real entry point.
//
// Returns whatever the real cuLaunchKernel returns. If calling the callback fails, the
// launch is not forwarded and CUDA_ERROR_LAUNCH_FAILED is returned instead.
// NOTE(review): this calls into the Python C API, so it presumably must run with the GIL
// held -- confirm against the call sites of the driver table's cuLaunchKernel.
static CUresult shim_cuLaunchKernel(
        CUfunction f,
        unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
        unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
        unsigned int sharedMemBytes,
        CUstream hStream,
        void** kernelParams,
        void** extra) {
    // Opaque handles are handed to Python as plain unsigned integers.
    const unsigned long long function_handle = reinterpret_cast<unsigned long long>(f);
    const unsigned long long stream_handle = reinterpret_cast<unsigned long long>(hStream);

    PyPtr callback_result = steal(PyObject_CallFunction(
        g_cuLaunchKernel_spy_callback,
        "(K III III I K)",
        function_handle,
        gridDimX, gridDimY, gridDimZ,
        blockDimX, blockDimY, blockDimZ,
        sharedMemBytes,
        stream_handle));
    if (!callback_result) {
        return CUDA_ERROR_LAUNCH_FAILED;
    }

    return g_real_cuLaunchKernel(
        f,
        gridDimX, gridDimY, gridDimZ,
        blockDimX, blockDimY, blockDimZ,
        sharedMemBytes, hStream, kernelParams, extra);
}
176+
177+
// Python-callable hook (registered as METH_O): start intercepting cuLaunchKernel.
//
// `arg` is the callback that the shim invokes before every forwarded launch; it must be
// callable. `self` is unused.
//
// Returns None on success. Raises RuntimeError if a spy is already installed and
// TypeError if `arg` is not callable. Returns nullptr if the driver API cannot be
// obtained (presumably get_driver_api() has set the Python exception in that case).
static PyObject* spy_on_cuLaunchKernel_begin(PyObject* self, PyObject* arg) {
    if (g_real_cuLaunchKernel)
        return PyErr_Format(PyExc_RuntimeError, "Already spying");

    // Reject non-callables here: otherwise the mistake would only surface later, inside
    // the shim, as a failed kernel launch.
    if (!PyCallable_Check(arg))
        return PyErr_Format(PyExc_TypeError,
                            "Expected a callable spy callback, got %s",
                            Py_TYPE(arg)->tp_name);

    Result<const DriverApi*> driver_result = get_driver_api();
    if (!driver_result.is_ok()) return nullptr;

    // The driver table is handed out as const; the spy deliberately patches one entry.
    DriverApi* api = const_cast<DriverApi*>(*driver_result);

    // Save the real entry point and the callback before publishing the shim, so the shim
    // never observes a half-initialized state.
    g_real_cuLaunchKernel = api->cuLaunchKernel;
    g_cuLaunchKernel_spy_callback = Py_NewRef(arg);
    api->cuLaunchKernel = shim_cuLaunchKernel;
    return Py_NewRef(Py_None);
}
190+
191+
// Python-callable hook (registered as METH_NOARGS): stop intercepting cuLaunchKernel.
//
// Puts the saved real entry point back into the driver table, marks the spy as inactive
// and drops the reference to the callback. Both parameters are unused.
//
// Returns None on success. Raises RuntimeError if no spy is installed; returns nullptr if
// the driver API cannot be obtained.
static PyObject* spy_on_cuLaunchKernel_end(PyObject* self, PyObject* arg) {
    if (g_real_cuLaunchKernel == nullptr) {
        return PyErr_Format(PyExc_RuntimeError, "Not spying");
    }

    Result<const DriverApi*> driver = get_driver_api();
    if (!driver.is_ok()) {
        return nullptr;
    }

    DriverApi* mutable_api = const_cast<DriverApi*>(*driver);
    mutable_api->cuLaunchKernel = g_real_cuLaunchKernel;
    g_real_cuLaunchKernel = nullptr;

    // Released last: the shim is already uninstalled when the callback reference goes away.
    Py_CLEAR(g_cuLaunchKernel_spy_callback);
    Py_RETURN_NONE;
}
204+
146205
static PyMethodDef functions[] = {
147206
{"get_compute_capability", get_compute_capability, METH_NOARGS,
148207
"Get compute capability of the default CUDA device"},
@@ -156,6 +215,8 @@ static PyMethodDef functions[] = {
156215
"Create a non-blocking CUDA stream. Returns int handle."},
157216
{"_destroy_stream", destroy_stream, METH_O,
158217
"Destroy a CUDA stream given its int handle."},
218+
{"_spy_on_cuLaunchKernel_begin", spy_on_cuLaunchKernel_begin, METH_O, nullptr},
219+
{"_spy_on_cuLaunchKernel_end", spy_on_cuLaunchKernel_end, METH_NOARGS, nullptr},
159220
NULL
160221
};
161222

cext/tile_kernel.cpp

Lines changed: 137 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -285,8 +285,44 @@ struct HashVector {
285285
}
286286
};
287287

288+
// Table of size-program opcodes, expanded through an X-macro.
// Each row is X(Name, #Attrs, MinStack, StackEffect):
//   Name        - enumerator name (also stringized for matching when parsing)
//   #Attrs      - number of attribute values the opcode consumes
//   MinStack    - number of values that must already be on the stack
//   StackEffect - net change in stack depth after the opcode runs
#define FOREACH_SIZE_OPCODE(X) \
    X(Const,        1, 0,  1)  \
    X(KernelArgI32, 1, 0,  1)  \
    X(Add,          0, 2, -1)  \
    X(Mul,          0, 2, -1)

// Turns one table row into its enumerator; the numeric columns are not used here.
#define SIZE_OPCODE_ENUM_ENTRY(name, _num_attrs, _min_stack, _stack_effect) name,

// Opcode identifiers, in table order.
enum class SizeOpcode : uint8_t {
    FOREACH_SIZE_OPCODE(SIZE_OPCODE_ENUM_ENTRY)
};
301+
302+
#define SIZE_OPCODE_PARSE(name, nattr, min_st, stack_eff) \
303+
if (!PyUnicode_CompareWithASCIIString(opcode_str, #name)) { \
304+
*num_attrs = nattr; \
305+
*min_stack = min_st; \
306+
*stack_effect = stack_eff; \
307+
return SizeOpcode::name; \
308+
}
309+
310+
static Result<SizeOpcode> size_opcode_parse(PyObject* opcode_str,
311+
int* num_attrs, int* min_stack, int* stack_effect) {
312+
FOREACH_SIZE_OPCODE(SIZE_OPCODE_PARSE);
313+
return raise(PyExc_ValueError, "Invalid opcode string %R", opcode_str);
314+
}
315+
316+
namespace { struct SizeProgram {
317+
enum { kMaxStackDepth = 32 };
318+
319+
Vec<SizeOpcode> opcodes;
320+
Vec<int64_t> op_attrs;
321+
}; }
322+
288323
struct TileKernel {
289324
CudaKernel cukernel;
325+
SizeProgram dyn_smem_size_prog;
290326
};
291327

292328
using KernelMap = HashMap<Vec<int64_t>, TileKernel>;
@@ -1410,6 +1446,73 @@ struct TileContextDispatcher {
14101446
};
14111447

14121448

1449+
static int64_t size_program_eval(const SizeProgram& prog, const Vec<Word>& cuargs) {
1450+
int64_t stack[SizeProgram::kMaxStackDepth];
1451+
int64_t* top = stack;
1452+
const int64_t* op_attrs = prog.op_attrs.data();
1453+
for (SizeOpcode opcode : prog.opcodes) {
1454+
switch (opcode) {
1455+
case SizeOpcode::Const: *top++ = *op_attrs++; break;
1456+
case SizeOpcode::KernelArgI32: *top++ = cuargs[*op_attrs++].i32; break;
1457+
case SizeOpcode::Add: top[-2] += top[-1]; --top; break; // TODO: overflow check?
1458+
case SizeOpcode::Mul: top[-2] *= top[-1]; --top; break; // TODO: overflow check?
1459+
}
1460+
}
1461+
return stack[0];
1462+
}
1463+
1464+
// Builds a validated SizeProgram from two Python lists: opcode names and attribute values.
//
// If `opcodes_pylist` is None, returns the program "Const 0" and ignores `attrs_pylist`.
// Otherwise both arguments must be lists. The opcode list is checked so that every opcode
// has its attributes available, never finds fewer stack values than it needs, never
// pushes the stack past SizeProgram::kMaxStackDepth, and leaves exactly one value at the
// end. All attributes must be consumed.
//
// Raises TypeError if either argument is not a list (including an attribute list of None
// paired with a real opcode list) and ValueError for a malformed program. Errors raised
// while parsing an opcode or converting an attribute to int64_t are propagated.
static Result<SizeProgram> size_program_parse(PyObject* opcodes_pylist, PyObject* attrs_pylist) {
    if (opcodes_pylist == Py_None)
        return SizeProgram{{SizeOpcode::Const}, {0}};

    // Without this check a None attribute list would only fail inside PyList_Size() with
    // an opaque SystemError.
    if (!PyList_Check(opcodes_pylist) || !PyList_Check(attrs_pylist))
        return raise(PyExc_TypeError,
                     "Invalid size program: expected opcode and attribute lists, got %s and %s",
                     Py_TYPE(opcodes_pylist)->tp_name, Py_TYPE(attrs_pylist)->tp_name);

    Py_ssize_t num_opcodes = PyList_Size(opcodes_pylist);
    Py_ssize_t num_attrs = PyList_Size(attrs_pylist);
    if (PyErr_Occurred()) return ErrorRaised;

    SizeProgram prog;
    Py_ssize_t remaining_attrs = num_attrs;
    int depth = 0;  // simulated stack depth after the opcodes processed so far
    for (Py_ssize_t i = 0; i < num_opcodes; ++i) {
        PyObject* py_opcode = PyList_GetItem(opcodes_pylist, i);  // borrowed reference
        if (!py_opcode) return ErrorRaised;

        int opcode_attrs, min_stack, stack_eff;
        Result<SizeOpcode> opcode_res = size_opcode_parse(
            py_opcode, &opcode_attrs, &min_stack, &stack_eff);
        if (!opcode_res.is_ok()) return ErrorRaised;

        if (remaining_attrs < opcode_attrs)
            return raise(PyExc_ValueError,
                         "Invalid size program (at op #%zd): not enough attributes"
                         " for opcode %u (need %d, have %zd)",
                         i, static_cast<unsigned>(*opcode_res), opcode_attrs, remaining_attrs);
        remaining_attrs -= opcode_attrs;

        if (depth < min_stack)
            return raise(PyExc_ValueError, "Invalid size program: not enough values on stack");
        depth += stack_eff;
        if (depth > SizeProgram::kMaxStackDepth)
            return raise(PyExc_ValueError, "Invalid size program: stack overflow");

        prog.opcodes.push_back(*opcode_res);
    }

    if (remaining_attrs != 0)
        return raise(PyExc_ValueError, "Invalid size program: too many attributes");
    if (depth != 1)
        return raise(PyExc_ValueError, "Invalid size program: expected exactly 1 value on stack");

    // Attributes are converted only after the opcode stream has been fully validated.
    for (Py_ssize_t i = 0; i < num_attrs; ++i) {
        PyObject* py_attr = PyList_GetItem(attrs_pylist, i);  // borrowed reference
        if (!py_attr) return ErrorRaised;

        prog.op_attrs.push_back(pylong_as<int64_t>(py_attr));
        if (PyErr_Occurred()) return ErrorRaised;
    }

    return prog;
}
1515+
14131516
namespace { struct TileDispatcher {
14141517
Vec<bool> constant_arg_flags;
14151518
Vec<bool> int64_index_flags;
@@ -1440,20 +1543,25 @@ static Result<TileKernel> compile(const DriverApi* driver,
14401543
return raise(PyExc_TypeError, "Expected compile() to return a tuple, got %s",
14411544
Py_TYPE(compile_result.get())->tp_name);
14421545

1443-
if (PyTuple_GET_SIZE(compile_result.get()) != 2)
1444-
return raise(PyExc_TypeError, "Expected compile() to return a 2-tuple, got length %zd",
1546+
if (PyTuple_GET_SIZE(compile_result.get()) != 4)
1547+
return raise(PyExc_TypeError, "Expected compile() to return a 4-tuple, got length %zd",
14451548
PyTuple_GET_SIZE(compile_result.get()));
14461549

14471550
PyObject* py_cubin_bytes = PyTuple_GET_ITEM(compile_result.get(), 0);
14481551
PyObject* py_cufunc_name = PyTuple_GET_ITEM(compile_result.get(), 1);
1552+
PyObject* py_dyn_smem_size_opcodes = PyTuple_GET_ITEM(compile_result.get(), 2);
1553+
PyObject* py_dyn_smem_size_opattrs = PyTuple_GET_ITEM(compile_result.get(), 3);
14491554

1450-
if (!PyBytes_Check(py_cubin_bytes) || !PyUnicode_Check(py_cufunc_name))
1555+
if (!PyBytes_Check(py_cubin_bytes)
1556+
|| !PyUnicode_Check(py_cufunc_name)
1557+
|| (py_dyn_smem_size_opcodes != Py_None && !PyList_Check(py_dyn_smem_size_opcodes))
1558+
|| (py_dyn_smem_size_opattrs != Py_None && !PyList_Check(py_dyn_smem_size_opattrs))) {
14511559
return raise(PyExc_TypeError,
1452-
"Expected compile() to return (bytes, str),"
1560+
"Expected compile() to return (bytes, str, list|None, list|None),"
14531561
" got %s, %s",
14541562
Py_TYPE(py_cubin_bytes)->tp_name,
14551563
Py_TYPE(py_cufunc_name)->tp_name);
1456-
1564+
}
14571565

14581566
char* cubin_data;
14591567
Py_ssize_t cubin_size;
@@ -1466,7 +1574,11 @@ static Result<TileKernel> compile(const DriverApi* driver,
14661574
Result<CudaKernel> cukernel = load_cuda_kernel(driver, cubin_data, cubin_size, cufunc_name);
14671575
if (!cukernel.is_ok()) return ErrorRaised;
14681576

1469-
return TileKernel{std::move(*cukernel)};
1577+
Result<SizeProgram> dyn_smem_size_prog = size_program_parse(
1578+
py_dyn_smem_size_opcodes, py_dyn_smem_size_opattrs);
1579+
if (!dyn_smem_size_prog.is_ok()) return ErrorRaised;
1580+
1581+
return TileKernel{std::move(*cukernel), std::move(*dyn_smem_size_prog)};
14701582
}
14711583

14721584
static inline bool has_torch_tensor_input(const Vec<PyTypeObject*>& pyarg_types) {
@@ -1645,6 +1757,7 @@ static bool try_clarify_invalid_value_error(const DriverApi* driver, const Grid&
16451757
struct PreparedLaunch {
16461758
LaunchHelperPtr helper;
16471759
CUkernel kernel;
1760+
unsigned dynamic_smem_bytes;
16481761
};
16491762

16501763
static Result<CUcontext> get_stream_context(const DriverApi* driver, CUstream stream) {
@@ -1791,14 +1904,19 @@ static Result<PreparedLaunch> prepare_launch(
17911904
for (Word& arg : helper->cuargs)
17921905
helper->cuarg_pointers.push_back(&arg);
17931906

1794-
return PreparedLaunch{std::move(helper), kernel_item->value.cukernel.kernel};
1907+
int64_t dyn_smem_size = size_program_eval(
1908+
kernel_item->value.dyn_smem_size_prog, helper->cuargs);
1909+
if (dyn_smem_size < 0 || dyn_smem_size > UINT_MAX)
1910+
return raise(PyExc_RuntimeError, "Invalid dynamic shared memory size");
1911+
1912+
return PreparedLaunch{std::move(helper), kernel_item->value.cukernel.kernel,
1913+
static_cast<unsigned>(dyn_smem_size)};
17951914
}
17961915

17971916
static Status launch(const DriverApi* driver,
17981917
PyObject* dispatcher_pyobj,
17991918
Grid grid,
18001919
Grid block,
1801-
unsigned dynamic_shared_memory_bytes,
18021920
CUstream launch_stream,
18031921
PyObject* const* pyargs,
18041922
Py_ssize_t num_pyargs) {
@@ -1815,7 +1933,7 @@ static Status launch(const DriverApi* driver,
18151933
reinterpret_cast<CUfunction>(prep->kernel),
18161934
grid.dims[0], grid.dims[1], grid.dims[2],
18171935
block.dims[0], block.dims[1], block.dims[2],
1818-
dynamic_shared_memory_bytes,
1936+
prep->dynamic_smem_bytes,
18191937
launch_stream,
18201938
prep->helper->cuarg_pointers.data(),
18211939
nullptr);
@@ -1926,7 +2044,7 @@ static Result<double> benchmark(const DriverApi* driver,
19262044
kparams.blockDimX = 1;
19272045
kparams.blockDimY = 1;
19282046
kparams.blockDimZ = 1;
1929-
kparams.sharedMemBytes = 0;
2047+
kparams.sharedMemBytes = pl.dynamic_smem_bytes;
19302048
kparams.kernelParams = pl.helper->cuarg_pointers.data();
19312049
kparams.extra = nullptr;
19322050
kparams.kern = pl.kernel;
@@ -2201,8 +2319,7 @@ static Status parse_launch_args(PyObject* const* args, Py_ssize_t nargs, const c
22012319
}
22022320

22032321
static PyObject* launch_impl(PyObject* const* args, Py_ssize_t nargs,
2204-
const char* signature, unsigned dynamic_shared_memory_bytes,
2205-
bool with_block) {
2322+
const char* signature, bool with_block) {
22062323
LaunchArgs launch_args;
22072324
if (!parse_launch_args(args, nargs, signature, with_block, &launch_args))
22082325
return nullptr;
@@ -2211,8 +2328,7 @@ static PyObject* launch_impl(PyObject* const* args, Py_ssize_t nargs,
22112328
if (!driver.is_ok()) return nullptr;
22122329

22132330
if (!launch(*driver, launch_args.dispatcher, launch_args.grid, launch_args.block,
2214-
dynamic_shared_memory_bytes, launch_args.stream, launch_args.kernel_args,
2215-
launch_args.num_kernel_args))
2331+
launch_args.stream, launch_args.kernel_args, launch_args.num_kernel_args))
22162332
return nullptr;
22172333

22182334
return Py_NewRef(Py_None);
@@ -2221,47 +2337,13 @@ static PyObject* launch_impl(PyObject* const* args, Py_ssize_t nargs,
22212337
#define LAUNCH_SIGNATURE "launch(stream, grid, kernel, kernel_args, /)"
22222338

22232339
static PyObject* cuda_tile_launch(PyObject*, PyObject* const* args, Py_ssize_t nargs) {
2224-
return launch_impl(args, nargs, LAUNCH_SIGNATURE,
2225-
/*dynamic_shared_memory_bytes=*/ 0,
2226-
/*with_block=*/ false);
2227-
}
2228-
2229-
#define LAUNCH_EXTENDED_SIGNATURE "launch(stream, grid, block, kernel, kernel_args, /,"\
2230-
" *, dynamic_shared_memory_bytes=None)"
2231-
2232-
static PyObject* launch_extended(PyObject*, PyObject* const* args, Py_ssize_t nargs,
2233-
PyObject* kwnames) {
2234-
unsigned dynamic_shared_memory_bytes = 0;
2235-
if (kwnames) {
2236-
Py_ssize_t num_kwargs = PyTuple_GET_SIZE(kwnames);
2237-
if (num_kwargs > 1) {
2238-
return PyErr_Format(PyExc_TypeError,
2239-
"Too many keyword arguments to " LAUNCH_EXTENDED_SIGNATURE);
2240-
}
2241-
if (num_kwargs) {
2242-
PyObject* keyword = PyTuple_GET_ITEM(kwnames, 0);
2243-
if (PyUnicode_Compare(keyword, g_dynamic_shared_memory_bytes_pyunicode)) {
2244-
return PyErr_Format(
2245-
PyExc_TypeError,
2246-
"Unexpected keyword argument '%U' to " LAUNCH_EXTENDED_SIGNATURE,
2247-
keyword);
2248-
}
2249-
PyObject* py_smem_size = args[nargs];
2250-
if (py_smem_size != Py_None) {
2251-
unsigned long smem_size_ul = PyLong_AsUnsignedLong(py_smem_size);
2252-
if (PyErr_Occurred()) return nullptr;
2253-
2254-
if (smem_size_ul > UINT_MAX) {
2255-
PyErr_SetString(PyExc_OverflowError,
2256-
"dynamic_shared_memory_bytes is out of range");
2257-
return nullptr;
2258-
}
2259-
dynamic_shared_memory_bytes = static_cast<unsigned>(smem_size_ul);
2260-
}
2261-
}
2262-
}
2263-
return launch_impl(args, nargs, LAUNCH_EXTENDED_SIGNATURE, dynamic_shared_memory_bytes,
2264-
/*with_block=*/ true);
2340+
return launch_impl(args, nargs, LAUNCH_SIGNATURE, /*with_block=*/ false);
2341+
}
2342+
2343+
#define LAUNCH_EXTENDED_SIGNATURE "launch(stream, grid, block, kernel, kernel_args, /)"
2344+
2345+
static PyObject* launch_extended(PyObject*, PyObject* const* args, Py_ssize_t nargs) {
2346+
return launch_impl(args, nargs, LAUNCH_EXTENDED_SIGNATURE, /*with_block=*/ true);
22652347
}
22662348

22672349
#define BENCHMARK_SIGNATURE "_benchmark(stream, grid, kernel, pyargs_tuples, /)"

0 commit comments

Comments
 (0)