Skip to content

Commit ff12318

Browse files
committed
Delay cuda driver load until launch
Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent bb2ebac commit ff12318

11 files changed

Lines changed: 214 additions & 159 deletions

cext/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ target_link_options(_cext_shared PUBLIC ${cext_link_flags} ${nostdlib_flags} -Wl
7474
add_executable(test_stream_buffer
7575
test/test_stream_buffer.cpp
7676
cuda_loader.cpp
77+
cuda_helper.cpp
7778
memory.cpp
7879
)
7980
target_compile_options(test_stream_buffer PUBLIC ${cext_compile_flags} ${test_coverage_options})

cext/cuda_helper.cpp

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,52 @@
66
#include "cuda_loader.h"
77

88

9-
const char* get_cuda_error(CUresult res) {
9+
const char* get_cuda_error(const DriverApi* driver, CUresult res) {
1010
const char* str = nullptr;
11-
g_cuGetErrorString(res, &str);
11+
driver->cuGetErrorString(res, &str);
1212
return str ? str : "Unknown error";
1313
}
1414

15-
void try_init_cuda() {
16-
ErrorGuard guard;
17-
CUresult res = g_cuInit(0);
15+
Status check_driver_version(const DriverApi* driver, int minimum_version) {
16+
int version;
17+
CUresult res = driver->cuDriverGetVersion(&version);
1818
if (res != CUDA_SUCCESS) {
19-
raise(PyExc_RuntimeError, "cuInit: %s", get_cuda_error(res));
20-
SavedException exc = save_raised_exception();
21-
LOG_PYTHON_ERROR("warning", exc, "Failed to initialized CUDA");
19+
PyErr_Format(PyExc_RuntimeError, "cuDriverGetVersion: %s", get_cuda_error(driver, res));
20+
return ErrorRaised;
21+
}
22+
if (version < minimum_version) {
23+
int major = version / 1000;
24+
int minor = (version % 1000) / 10;
25+
int required_major = minimum_version / 1000;
26+
PyErr_Format(PyExc_RuntimeError,
27+
"Minimum driver version required is %d.0, got %d.%d",
28+
required_major, major, minor);
29+
return ErrorRaised;
2230
}
31+
return OK;
2332
}
2433

2534
PyObject* get_max_grid_size(PyObject *self, PyObject *args) {
2635
int device_id;
2736
if (!PyArg_ParseTuple(args, "i", &device_id))
2837
return NULL;
2938

39+
Result<const DriverApi*> driver = get_driver_api();
40+
if (!driver.is_ok()) return NULL;
41+
3042
CUdevice dev;
31-
CUresult res = g_cuDeviceGet(&dev, device_id);
43+
CUresult res = (*driver)->cuDeviceGet(&dev, device_id);
3244
if (res != CUDA_SUCCESS)
33-
return PyErr_Format(PyExc_RuntimeError, "cuDeviceGet: %s", get_cuda_error(res));
45+
return PyErr_Format(PyExc_RuntimeError, "cuDeviceGet: %s", get_cuda_error(*driver, res));
3446

3547
int max_grid_size[3];
3648
for (int i = 0; i < 3; ++i) {
37-
res = g_cuDeviceGetAttribute(&max_grid_size[i],
49+
res = (*driver)->cuDeviceGetAttribute(&max_grid_size[i],
3850
static_cast<CUdevice_attribute>(CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X + i),
3951
dev);
4052
if (res != CUDA_SUCCESS) {
4153
return PyErr_Format(PyExc_RuntimeError,
42-
"cuDeviceGetAttribute: %s", get_cuda_error(res));
54+
"cuDeviceGetAttribute: %s", get_cuda_error(*driver, res));
4355
}
4456
}
4557
return Py_BuildValue("(iii)", max_grid_size[0], max_grid_size[1], max_grid_size[2]);
@@ -48,26 +60,36 @@ PyObject* get_max_grid_size(PyObject *self, PyObject *args) {
4860
PyObject* get_compute_capability(PyObject *self, PyObject *Py_UNUSED(ignored)) {
4961
int major, minor;
5062
CUdevice dev;
51-
CUresult res = g_cuDeviceGet(&dev, 0);
63+
64+
Result<const DriverApi*> driver_result = get_driver_api();
65+
if (!driver_result.is_ok()) return NULL;
66+
const DriverApi* d = *driver_result;
67+
68+
CUresult res = d->cuDeviceGet(&dev, 0);
5269
if (res != CUDA_SUCCESS) {
53-
return PyErr_Format(PyExc_RuntimeError, "cuDeviceGet: %s", get_cuda_error(res));
70+
return PyErr_Format(PyExc_RuntimeError, "cuDeviceGet: %s", get_cuda_error(d, res));
5471
}
55-
res = g_cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, dev);
72+
res = d->cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, dev);
5673
if (res != CUDA_SUCCESS) {
57-
return PyErr_Format(PyExc_RuntimeError, "cuDeviceGetAttribute: %s", get_cuda_error(res));
74+
return PyErr_Format(PyExc_RuntimeError, "cuDeviceGetAttribute: %s", get_cuda_error(d, res));
5875
}
59-
res = g_cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, dev);
76+
res = d->cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, dev);
6077
if (res != CUDA_SUCCESS) {
61-
return PyErr_Format(PyExc_RuntimeError, "cuDeviceGetAttribute: %s", get_cuda_error(res));
78+
return PyErr_Format(PyExc_RuntimeError, "cuDeviceGetAttribute: %s", get_cuda_error(d, res));
6279
}
6380
return Py_BuildValue("(ii)", major, minor);
6481
}
6582

6683
PyObject* get_driver_version(PyObject *self, PyObject *Py_UNUSED(ignored)) {
6784
int major, minor;
68-
CUresult res = g_cuDriverGetVersion(&major);
85+
86+
Result<const DriverApi*> driver_result = get_driver_api();
87+
if (!driver_result.is_ok()) return NULL;
88+
const DriverApi* d = *driver_result;
89+
90+
CUresult res = d->cuDriverGetVersion(&major);
6991
if (res != CUDA_SUCCESS) {
70-
return PyErr_Format(PyExc_RuntimeError, "cuDriverGetVersion: %s", get_cuda_error(res));
92+
return PyErr_Format(PyExc_RuntimeError, "cuDriverGetVersion: %s", get_cuda_error(d, res));
7193
}
7294
minor = (major % 1000) / 10;
7395
major = major / 1000;

cext/cuda_helper.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@
99
#include "py.h"
1010
#include <cuda.h>
1111

12+
struct DriverApi;
13+
1214
Status cuda_helper_init(PyObject* m);
1315

14-
const char* get_cuda_error(CUresult res);
16+
const char* get_cuda_error(const DriverApi*, CUresult res);
17+
18+
void try_cuInit(const DriverApi*);
1519

16-
void try_init_cuda();
20+
Status check_driver_version(const DriverApi*, int minimum_version);

cext/cuda_loader.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,12 @@ F get_proc_address(cuGetProcAddress_v2_t getter,
5050
FOREACH_CUDA_FUNCTION_TO_LOAD(DEFINE_CUDA_FUNCTION_GLOBAL)
5151

5252
#define GET_PROC_ADDRESS(name, cuda_ver) \
53-
if (!(g_##name = get_proc_address<decltype(name)*>(_cuGetProcAddress, #name, cuda_ver))) \
53+
if (!(driver_api.name = \
54+
get_proc_address<decltype(name)*>(_cuGetProcAddress, #name, cuda_ver))) \
5455
return ErrorRaised;
5556

5657

57-
Status cuda_loader_init() {
58+
static Status cuda_loader_init(DriverApi& driver_api) {
5859
PyPtr load_libcuda_mod = steal(PyImport_ImportModule("cuda.tile._load_libcuda"));
5960
if (!load_libcuda_mod) return ErrorRaised;
6061

@@ -72,3 +73,22 @@ Status cuda_loader_init() {
7273

7374
return OK;
7475
}
76+
77+
78+
static constexpr int MIN_DRIVER_VERSION = 13000;
79+
80+
Result<const DriverApi*> get_driver_api() {
81+
static bool initialized;
82+
static DriverApi instance;
83+
if (!initialized) {
84+
if (!cuda_loader_init(instance))
85+
return ErrorRaised;
86+
CUresult res = instance.cuInit(0);
87+
if (res != CUDA_SUCCESS)
88+
return raise(PyExc_RuntimeError, "cuInit: %s", get_cuda_error(&instance, res));
89+
if (!check_driver_version(&instance, MIN_DRIVER_VERSION))
90+
return ErrorRaised;
91+
initialized = true;
92+
}
93+
return &instance;
94+
}

cext/cuda_loader.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
#include "py.h"
1010
#include <cuda.h>
1111

12-
Status cuda_loader_init();
13-
1412
#define FOREACH_CUDA_FUNCTION_TO_LOAD(X) \
1513
X(cuInit, 2000) \
1614
X(cuLibraryLoadFromFile, 12000) \
@@ -47,8 +45,10 @@ Status cuda_loader_init();
4745

4846

4947
#define DECLARE_CUDA_FUNC_EXTERN(name, _cuda_version) \
50-
extern decltype(name)* g_##name;
51-
52-
FOREACH_CUDA_FUNCTION_TO_LOAD(DECLARE_CUDA_FUNC_EXTERN)
48+
decltype(::name)* name;
5349

50+
struct DriverApi {
51+
FOREACH_CUDA_FUNCTION_TO_LOAD(DECLARE_CUDA_FUNC_EXTERN)
52+
};
5453

54+
Result<const DriverApi*> get_driver_api();

cext/module.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
#include "py.h"
66

7-
#include "cuda_loader.h"
87
#include "tile_kernel.h"
98
#include "cuda_helper.h"
109

@@ -20,11 +19,6 @@ static PyModuleDef module_def = {
2019
};
2120

2221
PyMODINIT_FUNC PyInit__cext() {
23-
if (!cuda_loader_init())
24-
return nullptr;
25-
26-
try_init_cuda();
27-
2822
PyPtr m = steal(PyModule_Create(&module_def));
2923
if (!m) return nullptr;
3024

0 commit comments

Comments
 (0)