Skip to content

Commit 5f4125e

Browse files
committed
move init check inside lock
1 parent 47c1c52 commit 5f4125e

7 files changed

Lines changed: 29 additions & 22 deletions

File tree

cuda_bindings/cuda/bindings/_bindings/cydriver.pyx.in

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import os
1414
import sys
1515
import threading
1616
cimport cuda.bindings._bindings.loader as loader
17-
cdef object __symbol_lock = threading.Lock()
17+
cdef object __symbol_lock = threading.RLock()
1818
cdef bint __cuPythonInit = False
1919
{{if 'cuGetErrorString' in found_functions}}cdef void *__cuGetErrorString = NULL{{endif}}
2020
{{if 'cuGetErrorName' in found_functions}}cdef void *__cuGetErrorName = NULL{{endif}}
@@ -488,13 +488,14 @@ cdef bint __cuPythonInit = False
488488

489489
cdef int cuPythonInit() except -1 nogil:
490490
global __cuPythonInit
491-
if __cuPythonInit:
492-
return 0
493-
__cuPythonInit = True
494-
495491
cdef bint usePTDS
496492
cdef char libPath[260]
493+
497494
with gil, __symbol_lock:
495+
if __cuPythonInit:
496+
return 0
497+
__cuPythonInit = True
498+
498499
usePTDS = os.getenv('CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM', default=0)
499500

500501
# Load library

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@ cdef bint __cuPythonInit = False
4242

4343
cdef int cuPythonInit() except -1 nogil:
4444
global __cuPythonInit
45-
if __cuPythonInit:
46-
return 0
4745

4846
with gil, __symbol_lock:
47+
if __cuPythonInit:
48+
return 0
49+
4950
{{if 'Windows' == platform.system()}}
5051
handle = load_nvidia_dynamic_lib("nvrtc")._handle_uint
5152

cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,13 @@ cdef void* load_library(const int driver_ver) except* with gil:
9292

9393
cdef int _check_or_init_cufile() except -1 nogil:
9494
global __py_cufile_init
95-
if __py_cufile_init:
96-
return 0
97-
9895
cdef void* handle = NULL
9996
cdef int err, driver_ver = 0
97+
10098
with gil, __symbol_lock:
99+
if __py_cufile_init:
100+
return 0
101+
101102
# Load driver to check version
102103
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
103104
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,13 @@ cdef void* load_library(int driver_ver) except* with gil:
6262

6363
cdef int _check_or_init_nvjitlink() except -1 nogil:
6464
global __py_nvjitlink_init
65-
if __py_nvjitlink_init:
66-
return 0
67-
6865
cdef void* handle = NULL
6966
cdef int err, driver_ver = 0
67+
7068
with gil, __symbol_lock:
69+
if __py_nvjitlink_init:
70+
return 0
71+
7172
# Load driver to check version
7273
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
7374
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,12 @@ cdef void* __nvJitLinkVersion = NULL
4343

4444
cdef int _check_or_init_nvjitlink() except -1 nogil:
4545
global __py_nvjitlink_init
46-
if __py_nvjitlink_init:
47-
return 0
48-
4946
cdef int err, driver_ver = 0
47+
5048
with gil, __symbol_lock:
49+
if __py_nvjitlink_init:
50+
return 0
51+
5152
# Load driver to check version
5253
try:
5354
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)

cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,13 @@ cdef void* load_library(const int driver_ver) except* with gil:
6161

6262
cdef int _check_or_init_nvvm() except -1 nogil:
6363
global __py_nvvm_init
64-
if __py_nvvm_init:
65-
return 0
66-
6764
cdef void* handle = NULL
6865
cdef int err, driver_ver = 0
66+
6967
with gil, __symbol_lock:
68+
if __py_nvvm_init:
69+
return 0
70+
7071
# Load driver to check version
7172
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
7273
if handle == NULL:

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,12 @@ cdef void* __nvvmGetProgramLog = NULL
4242

4343
cdef int _check_or_init_nvvm() except -1 nogil:
4444
global __py_nvvm_init
45-
if __py_nvvm_init:
46-
return 0
47-
4845
cdef int err, driver_ver = 0
46+
4947
with gil, __symbol_lock:
48+
if __py_nvvm_init:
49+
return 0
50+
5051
# Load driver to check version
5152
try:
5253
handle = win32api.LoadLibraryEx("nvcuda.dll", 0, LOAD_LIBRARY_SEARCH_SYSTEM32)

0 commit comments

Comments
 (0)