Skip to content

Commit 58b7074

Browse files
committed
move init check inside of lock
1 parent 2459e09 commit 58b7074

7 files changed

Lines changed: 28 additions & 21 deletions

File tree

cuda_bindings/cuda/bindings/_bindings/cydriver.pyx.in

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,13 +492,14 @@ cdef bint __cuPythonInit = False
492492

493493
cdef int cuPythonInit() except -1 nogil:
494494
global __cuPythonInit
495-
if __cuPythonInit:
496-
return 0
497-
__cuPythonInit = True
498-
499495
cdef bint usePTDS
500496
cdef char libPath[260]
497+
501498
with gil, __symbol_lock:
499+
if __cuPythonInit:
500+
return 0
501+
__cuPythonInit = True
502+
502503
usePTDS = os.getenv('CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM', default=0)
503504

504505
# Load library

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,12 @@ cdef bint __cuPythonInit = False
4343

4444
cdef int cuPythonInit() except -1 nogil:
4545
global __cuPythonInit
46-
if __cuPythonInit:
47-
return 0
4846

4947
# Load library
5048
with gil, __symbol_lock:
49+
if __cuPythonInit:
50+
return 0
51+
5152
{{if 'Windows' == platform.system()}}
5253
handle = load_nvidia_dynamic_lib("nvrtc")._handle_uint
5354

cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,12 @@ cdef void* load_library(const int driver_ver) except* with gil:
7979

8080
cdef int _check_or_init_cufile() except -1 nogil:
8181
global __py_cufile_init
82-
if __py_cufile_init:
83-
return 0
84-
8582
cdef int err, driver_ver = 0
83+
8684
with gil, __symbol_lock:
85+
if __py_cufile_init:
86+
return 0
87+
8788
# Load driver to check version
8889
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
8990
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,12 @@ cdef void* load_library(int driver_ver) except* with gil:
6161

6262
cdef int _check_or_init_nvjitlink() except -1 nogil:
6363
global __py_nvjitlink_init
64-
if __py_nvjitlink_init:
65-
return 0
66-
6764
cdef int err, driver_ver = 0
65+
6866
with gil, __symbol_lock:
67+
if __py_nvjitlink_init:
68+
return 0
69+
6970
# Load driver to check version
7071
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
7172
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,12 @@ cdef void* __nvJitLinkVersion = NULL
4343

4444
cdef int _check_or_init_nvjitlink() except -1 nogil:
4545
global __py_nvjitlink_init
46-
if __py_nvjitlink_init:
47-
return 0
48-
4946
cdef int err, driver_ver = 0
47+
5048
with gil, __symbol_lock:
49+
if __py_nvjitlink_init:
50+
return 0
51+
5152
# Load driver to check version
5253
try:
5354
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)

cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,12 @@ cdef void* load_library(const int driver_ver) except* with gil:
6060

6161
cdef int _check_or_init_nvvm() except -1 nogil:
6262
global __py_nvvm_init
63-
if __py_nvvm_init:
64-
return 0
65-
6663
cdef int err, driver_ver = 0
64+
6765
with gil, __symbol_lock:
66+
if __py_nvvm_init:
67+
return 0
68+
6869
# Load driver to check version
6970
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
7071
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,12 @@ cdef void* __nvvmGetProgramLog = NULL
4242

4343
cdef int _check_or_init_nvvm() except -1 nogil:
4444
global __py_nvvm_init
45-
if __py_nvvm_init:
46-
return 0
47-
4845
cdef int err, driver_ver = 0
46+
4947
with gil, __symbol_lock:
48+
if __py_nvvm_init:
49+
return 0
50+
5051
# Load driver to check version
5152
try:
5253
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)

0 commit comments

Comments
 (0)