Skip to content

Commit cd45744

Browse files
committed
make cuPythonInit reentrant + ensure GIL is released when calling underlying C APIs
1 parent 1be81a0 commit cd45744

7 files changed

Lines changed: 591 additions & 583 deletions

File tree

cuda_bindings/cuda/bindings/_bindings/cydriver.pyx.in

Lines changed: 570 additions & 564 deletions
Large diffs are not rendered by default.

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,11 @@ cdef bint __cuPythonInit = False
4343

4444
cdef int cuPythonInit() except -1 nogil:
4545
global __cuPythonInit
46+
if __cuPythonInit:
47+
return 0
4648

4749
# Load library
4850
with gil, __symbol_lock:
49-
if __cuPythonInit:
50-
return 0
51-
5251
{{if 'Windows' == platform.system()}}
5352
handle = load_nvidia_dynamic_lib("nvrtc")._handle_uint
5453

@@ -237,7 +236,7 @@ cdef int cuPythonInit() except -1 nogil:
237236
{{endif}}
238237

239238
{{else}}
240-
handle = <void*><uintptr_t>load_nvidia_dynamic_lib("nvrtc")._handle_uint
239+
handle = <void*><uintptr_t>(load_nvidia_dynamic_lib("nvrtc")._handle_uint)
241240

242241
# Load function
243242
{{if 'nvrtcGetErrorString' in found_functions}}

cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,13 @@ cdef void* load_library(const int driver_ver) except* with gil:
7979

8080
cdef int _check_or_init_cufile() except -1 nogil:
8181
global __py_cufile_init
82+
if __py_cufile_init:
83+
return 0
84+
85+
cdef void* handle = NULL
8286
cdef int err, driver_ver = 0
8387

8488
with gil, __symbol_lock:
85-
if __py_cufile_init:
86-
return 0
87-
8889
# Load driver to check version
8990
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
9091
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,13 @@ cdef void* load_library(int driver_ver) except* with gil:
6161

6262
cdef int _check_or_init_nvjitlink() except -1 nogil:
6363
global __py_nvjitlink_init
64+
if __py_nvjitlink_init:
65+
return 0
66+
67+
cdef void* handle = NULL
6468
cdef int err, driver_ver = 0
6569

6670
with gil, __symbol_lock:
67-
if __py_nvjitlink_init:
68-
return 0
69-
7071
# Load driver to check version
7172
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
7273
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ cdef void* __nvJitLinkVersion = NULL
4343

4444
cdef int _check_or_init_nvjitlink() except -1 nogil:
4545
global __py_nvjitlink_init
46+
if __py_nvjitlink_init:
47+
return 0
48+
4649
cdef int err, driver_ver = 0
4750

4851
with gil, __symbol_lock:
49-
if __py_nvjitlink_init:
50-
return 0
51-
5252
# Load driver to check version
5353
try:
5454
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)

cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ cdef void* load_library(const int driver_ver) except* with gil:
6060

6161
cdef int _check_or_init_nvvm() except -1 nogil:
6262
global __py_nvvm_init
63+
if __py_nvvm_init:
64+
return 0
65+
66+
cdef void* handle = NULL
6367
cdef int err, driver_ver = 0
6468

6569
with gil, __symbol_lock:
66-
if __py_nvvm_init:
67-
return 0
68-
6970
# Load driver to check version
7071
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
7172
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ cdef void* __nvvmGetProgramLog = NULL
4242

4343
cdef int _check_or_init_nvvm() except -1 nogil:
4444
global __py_nvvm_init
45+
if __py_nvvm_init:
46+
return 0
47+
4548
cdef int err, driver_ver = 0
4649

4750
with gil, __symbol_lock:
48-
if __py_nvvm_init:
49-
return 0
50-
5151
# Load driver to check version
5252
try:
5353
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)

0 commit comments

Comments
 (0)