Skip to content

Commit 1a118b7

Browse files
authored
Update external library bindings (NVIDIA#1243)
* Update cybind-generated bindings * Remove safe_decode_string
1 parent 6ed6f6c commit 1a118b7

7 files changed

Lines changed: 75 additions & 28 deletions

File tree

cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,15 @@ cdef void* load_library() except* with gil:
111111
return <void*>handle
112112

113113

114-
cdef int __check_or_init_cufile() except -1 nogil:
114+
cdef int _init_cufile() except -1 nogil:
115115
global __py_cufile_init
116116

117117
cdef void* handle = NULL
118118

119119
with gil, __symbol_lock:
120+
# Recheck the flag after obtaining the locks
121+
if __py_cufile_init:
122+
return 0
120123
# Load function
121124
global __cuFileHandleRegister
122125
__cuFileHandleRegister = dlsym(RTLD_DEFAULT, 'cuFileHandleRegister')
@@ -427,7 +430,7 @@ cdef inline int _check_or_init_cufile() except -1 nogil:
427430
if __py_cufile_init:
428431
return 0
429432

430-
return __check_or_init_cufile()
433+
return _init_cufile()
431434

432435

433436
cdef dict func_ptrs = None

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,16 @@ cdef void* load_library() except* with gil:
8080
return <void*>handle
8181

8282

83-
cdef int __check_or_init_nvjitlink() except -1 nogil:
83+
cdef int _init_nvjitlink() except -1 nogil:
8484
global __py_nvjitlink_init
8585

8686
cdef void* handle = NULL
8787

8888
with gil, __symbol_lock:
89+
# Recheck the flag after obtaining the locks
90+
if __py_nvjitlink_init:
91+
return 0
92+
8993
# Load function
9094
global __nvJitLinkCreate
9195
__nvJitLinkCreate = dlsym(RTLD_DEFAULT, 'nvJitLinkCreate')
@@ -193,7 +197,7 @@ cdef inline int _check_or_init_nvjitlink() except -1 nogil:
193197
if __py_nvjitlink_init:
194198
return 0
195199

196-
return __check_or_init_nvjitlink()
200+
return _init_nvjitlink()
197201

198202
cdef dict func_ptrs = None
199203

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,14 @@ cdef void* __nvJitLinkGetInfoLog = NULL
9393
cdef void* __nvJitLinkVersion = NULL
9494

9595

96-
cdef int __check_or_init_nvjitlink() except -1 nogil:
96+
cdef int _init_nvjitlink() except -1 nogil:
9797
global __py_nvjitlink_init
9898

9999
with gil, __symbol_lock:
100+
# Recheck the flag after obtaining the locks
101+
if __py_nvjitlink_init:
102+
return 0
103+
100104
# Load library
101105
handle = load_nvidia_dynamic_lib("nvJitLink")._handle_uint
102106

@@ -151,7 +155,7 @@ cdef inline int _check_or_init_nvjitlink() except -1 nogil:
151155
if __py_nvjitlink_init:
152156
return 0
153157

154-
return __check_or_init_nvjitlink()
158+
return _init_nvjitlink()
155159

156160

157161
cdef dict func_ptrs = None

cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,16 @@ cdef void* load_library() except* with gil:
7979
return <void*>handle
8080

8181

82-
cdef int __check_or_init_nvvm() except -1 nogil:
82+
cdef int _init_nvvm() except -1 nogil:
8383
global __py_nvvm_init
8484

8585
cdef void* handle = NULL
8686

8787
with gil, __symbol_lock:
88+
# Recheck the flag after obtaining the locks
89+
if __py_nvvm_init:
90+
return 0
91+
8892
# Load function
8993
global __nvvmGetErrorString
9094
__nvvmGetErrorString = dlsym(RTLD_DEFAULT, 'nvvmGetErrorString')
@@ -185,7 +189,7 @@ cdef inline int _check_or_init_nvvm() except -1 nogil:
185189
if __py_nvvm_init:
186190
return 0
187191

188-
return __check_or_init_nvvm()
192+
return _init_nvvm()
189193

190194

191195
cdef dict func_ptrs = None

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,14 @@ cdef void* __nvvmGetProgramLogSize = NULL
9292
cdef void* __nvvmGetProgramLog = NULL
9393

9494

95-
cdef int __check_or_init_nvvm() except -1 nogil:
95+
cdef int _init_nvvm() except -1 nogil:
9696
global __py_nvvm_init
9797

9898
with gil, __symbol_lock:
99+
# Recheck the flag after obtaining the locks
100+
if __py_nvvm_init:
101+
return 0
102+
99103
# Load library
100104
handle = load_nvidia_dynamic_lib("nvvm")._handle_uint
101105

@@ -147,7 +151,7 @@ cdef inline int _check_or_init_nvvm() except -1 nogil:
147151
if __py_nvvm_init:
148152
return 0
149153

150-
return __check_or_init_nvvm()
154+
return _init_nvvm()
151155

152156

153157
cdef dict func_ptrs = None

cuda_bindings/cuda/bindings/cufile.pyx

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import numpy as _numpy
1212
from cpython cimport buffer as _buffer
1313
from cpython.memoryview cimport PyMemoryView_FromMemory
1414
from enum import IntEnum as _IntEnum
15+
cimport cpython
1516

1617
import cython
1718

@@ -54,6 +55,10 @@ cdef class _py_anon_pod1:
5455
"""Get the pointer address to the data as Python :class:`int`."""
5556
return self._data.ctypes.data
5657

58+
cdef intptr_t _get_ptr(self):
59+
"""Get the pointer address to the data as Python :class:`int`."""
60+
return self._data.ctypes.data
61+
5762
def __int__(self):
5863
return self._data.ctypes.data
5964

@@ -157,6 +162,10 @@ cdef class _py_anon_pod3:
157162
"""Get the pointer address to the data as Python :class:`int`."""
158163
return self._data.ctypes.data
159164

165+
cdef intptr_t _get_ptr(self):
166+
"""Get the pointer address to the data as Python :class:`int`."""
167+
return self._data.ctypes.data
168+
160169
def __int__(self):
161170
return self._data.ctypes.data
162171

@@ -286,6 +295,10 @@ cdef class IOEvents:
286295
"""Get the pointer address to the data as Python :class:`int`."""
287296
return self._data.ctypes.data
288297

298+
cdef intptr_t _get_ptr(self):
299+
"""Get the pointer address to the data as Python :class:`int`."""
300+
return self._data.ctypes.data
301+
289302
def __int__(self):
290303
if self._data.size > 1:
291304
raise TypeError("int() argument must be a bytes-like object of size 1. "
@@ -422,6 +435,10 @@ cdef class OpCounter:
422435
"""Get the pointer address to the data as Python :class:`int`."""
423436
return self._data.ctypes.data
424437

438+
cdef intptr_t _get_ptr(self):
439+
"""Get the pointer address to the data as Python :class:`int`."""
440+
return self._data.ctypes.data
441+
425442
def __int__(self):
426443
return self._data.ctypes.data
427444

@@ -551,6 +568,10 @@ cdef class PerGpuStats:
551568
"""Get the pointer address to the data as Python :class:`int`."""
552569
return self._data.ctypes.data
553570

571+
cdef intptr_t _get_ptr(self):
572+
"""Get the pointer address to the data as Python :class:`int`."""
573+
return self._data.ctypes.data
574+
554575
def __int__(self):
555576
return self._data.ctypes.data
556577

@@ -914,6 +935,10 @@ cdef class Descr:
914935
"""Get the pointer address to the data as Python :class:`int`."""
915936
return self._data.ctypes.data
916937

938+
cdef intptr_t _get_ptr(self):
939+
"""Get the pointer address to the data as Python :class:`int`."""
940+
return self._data.ctypes.data
941+
917942
def __int__(self):
918943
if self._data.size > 1:
919944
raise TypeError("int() argument must be a bytes-like object of size 1. "
@@ -1052,6 +1077,10 @@ cdef class _py_anon_pod2:
10521077
"""Get the pointer address to the data as Python :class:`int`."""
10531078
return self._data.ctypes.data
10541079

1080+
cdef intptr_t _get_ptr(self):
1081+
"""Get the pointer address to the data as Python :class:`int`."""
1082+
return self._data.ctypes.data
1083+
10551084
def __int__(self):
10561085
return self._data.ctypes.data
10571086

@@ -1185,6 +1214,10 @@ cdef class StatsLevel1:
11851214
"""Get the pointer address to the data as Python :class:`int`."""
11861215
return self._data.ctypes.data
11871216

1217+
cdef intptr_t _get_ptr(self):
1218+
"""Get the pointer address to the data as Python :class:`int`."""
1219+
return self._data.ctypes.data
1220+
11881221
def __int__(self):
11891222
return self._data.ctypes.data
11901223

@@ -1667,6 +1700,10 @@ cdef class IOParams:
16671700
"""Get the pointer address to the data as Python :class:`int`."""
16681701
return self._data.ctypes.data
16691702

1703+
cdef intptr_t _get_ptr(self):
1704+
"""Get the pointer address to the data as Python :class:`int`."""
1705+
return self._data.ctypes.data
1706+
16701707
def __int__(self):
16711708
if self._data.size > 1:
16721709
raise TypeError("int() argument must be a bytes-like object of size 1. "
@@ -1824,6 +1861,10 @@ cdef class StatsLevel2:
18241861
"""Get the pointer address to the data as Python :class:`int`."""
18251862
return self._data.ctypes.data
18261863

1864+
cdef intptr_t _get_ptr(self):
1865+
"""Get the pointer address to the data as Python :class:`int`."""
1866+
return self._data.ctypes.data
1867+
18271868
def __int__(self):
18281869
return self._data.ctypes.data
18291870

@@ -1935,6 +1976,10 @@ cdef class StatsLevel3:
19351976
"""Get the pointer address to the data as Python :class:`int`."""
19361977
return self._data.ctypes.data
19371978

1979+
cdef intptr_t _get_ptr(self):
1980+
"""Get the pointer address to the data as Python :class:`int`."""
1981+
return self._data.ctypes.data
1982+
19381983
def __int__(self):
19391984
return self._data.ctypes.data
19401985

@@ -2458,7 +2503,7 @@ cpdef str get_parameter_string(int param, int len):
24582503
with nogil:
24592504
__status__ = cuFileGetParameterString(<_StringConfigParameter>param, desc_str, len)
24602505
check_status(__status__)
2461-
return _desc_str_.decode()
2506+
return cpython.PyUnicode_FromString(desc_str)
24622507

24632508

24642509
cpdef set_parameter_size_t(int param, size_t value):

cuda_bindings/tests/test_cufile.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -121,20 +121,6 @@ def isSupportedFilesystem():
121121
pytestmark = pytest.mark.skipif(not cufileLibraryAvailable(), reason="cuFile library not available on this system")
122122

123123

124-
def safe_decode_string(raw_value):
125-
"""Safely decode a string value from ctypes buffer."""
126-
# Find null terminator if present
127-
null_pos = raw_value.find(b"\x00")
128-
if null_pos != -1:
129-
raw_value = raw_value[:null_pos]
130-
# Decode with error handling
131-
try:
132-
return raw_value.decode("utf-8", errors="ignore")
133-
except UnicodeDecodeError:
134-
# If UTF-8 fails, try to decode as bytes
135-
return str(raw_value)
136-
137-
138124
def test_cufile_success_defined():
139125
"""Check if CUFILE_SUCCESS is defined in OpError enum."""
140126
assert hasattr(cufile.OpError, "SUCCESS")
@@ -1774,8 +1760,6 @@ def test_set_get_parameter_string(tmp_path):
17741760

17751761
def test_param(param, val, default_val):
17761762
orig_val = cufile.get_parameter_string(param, 256)
1777-
# Use safe_decode_string to handle null terminators and padding
1778-
orig_val = safe_decode_string(orig_val.encode("utf-8"))
17791763

17801764
val_b = val.encode("utf-8")
17811765
val_buf = ctypes.create_string_buffer(val_b)
@@ -1787,7 +1771,6 @@ def test_param(param, val, default_val):
17871771
# Round-trip test
17881772
cufile.set_parameter_string(param, int(ctypes.addressof(val_buf)))
17891773
retrieved_val = cufile.get_parameter_string(param, 256)
1790-
retrieved_val = safe_decode_string(retrieved_val.encode("utf-8"))
17911774
assert retrieved_val == val
17921775

17931776
# Restore

0 commit comments

Comments
 (0)