Skip to content

Commit 8e79272

Browse files
committed
make cuPythonInit reentrant + ensure GIL is released when calling underlying C APIs
1 parent 5f4125e commit 8e79272

7 files changed

Lines changed: 585 additions & 580 deletions

File tree

cuda_bindings/cuda/bindings/_bindings/cydriver.pyx.in

Lines changed: 567 additions & 561 deletions
Large diffs are not rendered by default.

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,10 @@ cdef bint __cuPythonInit = False
4242

4343
cdef int cuPythonInit() except -1 nogil:
4444
global __cuPythonInit
45+
if __cuPythonInit:
46+
return 0
4547

4648
with gil, __symbol_lock:
47-
if __cuPythonInit:
48-
return 0
49-
5049
{{if 'Windows' == platform.system()}}
5150
handle = load_nvidia_dynamic_lib("nvrtc")._handle_uint
5251

@@ -221,7 +220,7 @@ cdef int cuPythonInit() except -1 nogil:
221220
{{endif}}
222221

223222
{{else}}
224-
handle = <void*><uintptr_t>load_nvidia_dynamic_lib("nvrtc")._handle_uint
223+
handle = <void*><uintptr_t>(load_nvidia_dynamic_lib("nvrtc")._handle_uint)
225224

226225
# Load function
227226
{{if 'nvrtcGetErrorString' in found_functions}}

cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,13 @@ cdef void* load_library(const int driver_ver) except* with gil:
9292

9393
cdef int _check_or_init_cufile() except -1 nogil:
9494
global __py_cufile_init
95+
if __py_cufile_init:
96+
return 0
97+
9598
cdef void* handle = NULL
9699
cdef int err, driver_ver = 0
97100

98101
with gil, __symbol_lock:
99-
if __py_cufile_init:
100-
return 0
101-
102102
# Load driver to check version
103103
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
104104
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ cdef void* load_library(int driver_ver) except* with gil:
6262

6363
cdef int _check_or_init_nvjitlink() except -1 nogil:
6464
global __py_nvjitlink_init
65+
if __py_nvjitlink_init:
66+
return 0
67+
6568
cdef void* handle = NULL
6669
cdef int err, driver_ver = 0
6770

6871
with gil, __symbol_lock:
69-
if __py_nvjitlink_init:
70-
return 0
71-
7272
# Load driver to check version
7373
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
7474
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ cdef void* __nvJitLinkVersion = NULL
4343

4444
cdef int _check_or_init_nvjitlink() except -1 nogil:
4545
global __py_nvjitlink_init
46+
if __py_nvjitlink_init:
47+
return 0
48+
4649
cdef int err, driver_ver = 0
4750

4851
with gil, __symbol_lock:
49-
if __py_nvjitlink_init:
50-
return 0
51-
5252
# Load driver to check version
5353
try:
5454
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)

cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,13 @@ cdef void* load_library(const int driver_ver) except* with gil:
6161

6262
cdef int _check_or_init_nvvm() except -1 nogil:
6363
global __py_nvvm_init
64+
if __py_nvvm_init:
65+
return 0
66+
6467
cdef void* handle = NULL
6568
cdef int err, driver_ver = 0
6669

6770
with gil, __symbol_lock:
68-
if __py_nvvm_init:
69-
return 0
70-
7171
# Load driver to check version
7272
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
7373
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ cdef void* __nvvmGetProgramLog = NULL
4242

4343
cdef int _check_or_init_nvvm() except -1 nogil:
4444
global __py_nvvm_init
45+
if __py_nvvm_init:
46+
return 0
47+
4548
cdef int err, driver_ver = 0
4649

4750
with gil, __symbol_lock:
48-
if __py_nvvm_init:
49-
return 0
50-
5151
# Load driver to check version
5252
try:
5353
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)

0 commit comments

Comments
 (0)