Skip to content

Commit 9ff46d8

Browse files
committed
Introduce dataclass LoadedDL as return type for load_nvidia_dynamic_library()
1 parent c1a4983 commit 9ff46d8

7 files changed

Lines changed: 38 additions & 41 deletions

File tree

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ cdef int cuPythonInit() except -1 nogil:
5656

5757
{{if 'Windows' == platform.system()}}
5858
with gil:
59-
handle = path_finder.load_nvidia_dynamic_library("nvrtc")
59+
handle = path_finder.load_nvidia_dynamic_library("nvrtc").handle
6060
{{if 'nvrtcGetErrorString' in found_functions}}
6161
try:
6262
global __nvrtcGetErrorString
@@ -242,7 +242,7 @@ cdef int cuPythonInit() except -1 nogil:
242242

243243
{{else}}
244244
with gil:
245-
handle = <void*><uintptr_t>path_finder.load_nvidia_dynamic_library("nvrtc")
245+
handle = <void*><uintptr_t>path_finder.load_nvidia_dynamic_library("nvrtc").handle
246246
{{if 'nvrtcGetErrorString' in found_functions}}
247247
global __nvrtcGetErrorString
248248
__nvrtcGetErrorString = dlfcn.dlsym(handle, 'nvrtcGetErrorString')

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ cdef void* __nvJitLinkVersion = NULL
5353

5454

5555
cdef void* load_library(int driver_ver) except* with gil:
56-
cdef uintptr_t handle = path_finder.load_nvidia_dynamic_library("nvJitLink")
56+
cdef uintptr_t handle = path_finder.load_nvidia_dynamic_library("nvJitLink").handle
5757
return <void*>handle
5858

5959

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ cdef void* __nvJitLinkVersion = NULL
4040

4141

4242
cdef void* load_library(int driver_ver) except* with gil:
43-
cdef intptr_t handle = path_finder.load_nvidia_dynamic_library("nvJitLink")
43+
cdef intptr_t handle = path_finder.load_nvidia_dynamic_library("nvJitLink").handle
4444
return <void*>handle
4545

4646

cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ cdef void* __nvvmGetProgramLog = NULL
5151

5252

5353
cdef void* load_library(const int driver_ver) except* with gil:
54-
cdef uintptr_t handle = path_finder.load_nvidia_dynamic_library("nvvm")
54+
cdef uintptr_t handle = path_finder.load_nvidia_dynamic_library("nvvm").handle
5555
return <void*>handle
5656

5757

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ cdef void* __nvvmGetProgramLog = NULL
3838

3939

4040
cdef void* load_library(int driver_ver) except* with gil:
41-
cdef intptr_t handle = path_finder.load_nvidia_dynamic_library("nvvm")
41+
cdef intptr_t handle = path_finder.load_nvidia_dynamic_library("nvvm").handle
4242
return <void*>handle
4343

4444

cuda_bindings/cuda/bindings/_path_finder/load_nvidia_dynamic_library.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import functools
77
import os
88
import sys
9+
from dataclasses import dataclass
910
from typing import Optional, Tuple
1011

1112
if sys.platform == "win32":
@@ -114,19 +115,29 @@ def _load_and_report_path_linux(libname, soname: str) -> Tuple[int, str]:
114115
return handle, info.dli_fname.decode()
115116

116117

118+
@dataclass
119+
class LoadedDL:
120+
# ATTENTION: To convert `handle` back to `void*` in cython:
121+
# Linux: `cdef void* ptr = <void*><uintptr_t>`
122+
# Windows: `cdef void* ptr = <void*><intptr_t>`
123+
handle: int
124+
abs_path: Optional[str]
125+
was_already_loaded_from_elsewhere: bool
126+
127+
117128
@functools.cache
118-
def load_nvidia_dynamic_library(libname: str) -> int:
129+
def load_nvidia_dynamic_library(libname: str) -> LoadedDL:
119130
# Detect if the library was loaded already in some other way (i.e. not via this function).
120131
if sys.platform == "win32":
121-
for dll_name in SUPPORTED_WINDOWS_DLLS.get(libname):
132+
for dll_name in SUPPORTED_WINDOWS_DLLS.get(libname, ()):
122133
try:
123-
return win32api.GetModuleHandle(dll_name)
134+
return LoadedDL(win32api.GetModuleHandle(dll_name), None, True)
124135
except pywintypes.error:
125136
pass
126137
else:
127-
for soname in SUPPORTED_LINUX_SONAMES.get(libname):
138+
for soname in SUPPORTED_LINUX_SONAMES.get(libname, ()):
128139
try:
129-
return ctypes.CDLL(soname, mode=os.RTLD_NOLOAD)
140+
return LoadedDL(ctypes.CDLL(soname, mode=os.RTLD_NOLOAD), None, True)
130141
except OSError:
131142
pass
132143

@@ -138,18 +149,14 @@ def load_nvidia_dynamic_library(libname: str) -> int:
138149
if sys.platform == "win32":
139150
handle, abs_path = _windows_load_with_dll_basename(libname)
140151
if handle:
141-
# Use `cdef void* ptr = <void*><intptr_t>` in cython to convert back to void*
142-
print(f"SYSTEM ABS_PATH for {libname=!r}: {abs_path}", flush=True)
143-
return handle
152+
return LoadedDL(handle, abs_path, False)
144153
else:
145154
try:
146155
handle, abs_path = _load_and_report_path_linux(libname, found.lib_searched_for)
147-
except OSError as e:
148-
print(f"SYSTEM OSError for {libname=!r}: {e!r}", flush=True)
156+
except OSError:
157+
pass
149158
else:
150-
# Use `cdef void* ptr = <void*><uintptr_t>` in cython to convert back to void*
151-
print(f"SYSTEM ABS_PATH for {libname=!r}: {abs_path}", flush=True)
152-
return handle._handle # C unsigned int
159+
return LoadedDL(handle._handle, abs_path, False)
153160
found.raise_if_abs_path_is_None()
154161

155162
if sys.platform == "win32":
@@ -160,14 +167,10 @@ def load_nvidia_dynamic_library(libname: str) -> int:
160167
handle = win32api.LoadLibraryEx(found.abs_path, 0, flags)
161168
except pywintypes.error as e:
162169
raise RuntimeError(f"Failed to load DLL at {found.abs_path}: {e}") from e
163-
# Use `cdef void* ptr = <void*><intptr_t>` in cython to convert back to void*
164-
print(f"FOUND ABS_PATH for {libname=!r}: {found.abs_path}", flush=True)
165-
return handle # C signed int, matches win32api.GetProcAddress
170+
return LoadedDL(handle, found.abs_path, False)
166171
else:
167172
try:
168173
handle = ctypes.CDLL(found.abs_path, _LINUX_CDLL_MODE)
169174
except OSError as e:
170175
raise RuntimeError(f"Failed to dlopen {found.abs_path}: {e}") from e
171-
# Use `cdef void* ptr = <void*><uintptr_t>` in cython to convert back to void*
172-
print(f"FOUND ABS_PATH for {libname=!r}: {found.abs_path}", flush=True)
173-
return handle._handle # C unsigned int
176+
return LoadedDL(handle._handle, found.abs_path, False)

cuda_bindings/tests/path_finder.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import sys
2+
import traceback
23

34
from cuda.bindings import path_finder
45
from cuda.bindings._path_finder import cuda_paths, supported_libs
5-
from cuda.bindings._path_finder.find_nvidia_dynamic_library import find_nvidia_dynamic_library
66

77
ALL_LIBNAMES = path_finder.SUPPORTED_LIBNAMES + supported_libs.PARTIALLY_SUPPORTED_LIBNAMES
88

@@ -11,26 +11,20 @@ def run(args):
1111
assert len(args) == 0
1212

1313
paths = cuda_paths.get_cuda_paths()
14-
1514
for k, v in paths.items():
1615
print(f"{k}: {v}", flush=True)
1716
print()
1817

19-
for libname in supported_libs.SUPPORTED_WINDOWS_DLLS:
20-
if libname not in ALL_LIBNAMES:
21-
print(f"MISSING IN SUPPORTED_LIBNAMES: {libname}")
22-
2318
for libname in ALL_LIBNAMES:
24-
print(libname)
25-
dlls = supported_libs.SUPPORTED_WINDOWS_DLLS.get(libname)
26-
if dlls is None:
27-
print(f"MISSING IN SUPPORTED_WINDOWS_DLLS: {libname}")
28-
for fun in (find_nvidia_dynamic_library, path_finder.load_nvidia_dynamic_library):
29-
try:
30-
out = fun(libname)
31-
except Exception as e:
32-
out = f"EXCEPTION: {type(e)} {str(e)}"
33-
print(out)
19+
print(f"{libname=}")
20+
try:
21+
loaded_dl = path_finder.load_nvidia_dynamic_library(libname)
22+
except Exception as e:
23+
print(f"EXCEPTION for {libname=}:")
24+
traceback.print_exc(file=sys.stdout)
25+
else:
26+
print(f" {loaded_dl.abs_path=!r}")
27+
print(f" {loaded_dl.was_already_loaded_from_elsewhere=!r}")
3428
print()
3529

3630

0 commit comments

Comments
 (0)