Skip to content

Commit fb0a430

Browse files
committed
Factor out _abs_path_for_dynamic_library_* and use on handle obtained through "is already loaded" checks
1 parent 9ff46d8 commit fb0a430

1 file changed

Lines changed: 28 additions & 15 deletions

File tree

cuda_bindings/cuda/bindings/_path_finder/load_nvidia_dynamic_library.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ def _windows_cuDriverGetVersion() -> int:
7777
return driver_ver.value
7878

7979

80+
def _abs_path_for_dynamic_library_windows(handle: int) -> str:
81+
buf = ctypes.create_unicode_buffer(260)
82+
n_chars = ctypes.windll.kernel32.GetModuleFileNameW(ctypes.wintypes.HMODULE(handle), buf, len(buf))
83+
if n_chars == 0:
84+
raise OSError("GetModuleFileNameW failed")
85+
return buf.value
86+
87+
8088
@functools.cache
8189
def _windows_load_with_dll_basename(name: str) -> Tuple[Optional[int], Optional[str]]:
8290
driver_ver = _windows_cuDriverGetVersion()
@@ -86,33 +94,34 @@ def _windows_load_with_dll_basename(name: str) -> Tuple[Optional[int], Optional[
8694
if dll_names is None:
8795
return None
8896

89-
kernel32 = ctypes.windll.kernel32
90-
9197
for dll_name in dll_names:
92-
handle = kernel32.LoadLibraryW(ctypes.c_wchar_p(dll_name))
98+
handle = ctypes.windll.kernel32.LoadLibraryW(ctypes.c_wchar_p(dll_name))
9399
if handle:
94-
buf = ctypes.create_unicode_buffer(260)
95-
n_chars = kernel32.GetModuleFileNameW(ctypes.wintypes.HMODULE(handle), buf, len(buf))
96-
if n_chars == 0:
97-
raise OSError("GetModuleFileNameW failed")
98-
return handle, buf.value
100+
return handle, _abs_path_for_dynamic_library_windows(handle)
99101

100102
return None, None
101103

102104

103-
def _load_and_report_path_linux(libname, soname: str) -> Tuple[int, str]:
104-
handle = ctypes.CDLL(soname, _LINUX_CDLL_MODE)
105+
def _abs_path_for_dynamic_library_linux(libname: str, handle: int) -> str:
105106
for symbol_name in EXPECTED_LIB_SYMBOLS[libname]:
106107
symbol = getattr(handle, symbol_name, None)
107108
if symbol is not None:
108109
break
109110
else:
110-
raise RuntimeError(f"No expected symbol for {libname=!r}")
111+
return None
111112
addr = ctypes.cast(symbol, ctypes.c_void_p)
112113
info = Dl_info()
113114
if _LIBDL.dladdr(addr, ctypes.byref(info)) == 0:
114-
raise OSError(f"dladdr failed for {soname}")
115-
return handle, info.dli_fname.decode()
115+
raise OSError(f"dladdr failed for {libname=!r}")
116+
return info.dli_fname.decode()
117+
118+
119+
def _load_and_report_path_linux(libname: str, soname: str) -> Tuple[int, str]:
120+
handle = ctypes.CDLL(soname, _LINUX_CDLL_MODE)
121+
abs_path = _abs_path_for_dynamic_library_linux(libname, handle)
122+
if abs_path is None:
123+
raise RuntimeError(f"No expected symbol for {libname=!r}")
124+
return handle, abs_path
116125

117126

118127
@dataclass
@@ -131,15 +140,19 @@ def load_nvidia_dynamic_library(libname: str) -> LoadedDL:
131140
if sys.platform == "win32":
132141
for dll_name in SUPPORTED_WINDOWS_DLLS.get(libname, ()):
133142
try:
134-
return LoadedDL(win32api.GetModuleHandle(dll_name), None, True)
143+
handle = win32api.GetModuleHandle(dll_name)
135144
except pywintypes.error:
136145
pass
146+
else:
147+
return LoadedDL(handle, _abs_path_for_dynamic_library_windows(handle), True)
137148
else:
138149
for soname in SUPPORTED_LINUX_SONAMES.get(libname, ()):
139150
try:
140-
return LoadedDL(ctypes.CDLL(soname, mode=os.RTLD_NOLOAD), None, True)
151+
handle = ctypes.CDLL(soname, mode=os.RTLD_NOLOAD)
141152
except OSError:
142153
pass
154+
else:
155+
return LoadedDL(handle, _abs_path_for_dynamic_library_linux(libname, handle), True)
143156

144157
for dep in DIRECT_DEPENDENCIES.get(libname, ()):
145158
load_nvidia_dynamic_library(dep)

0 commit comments

Comments
 (0)