@@ -77,6 +77,14 @@ def _windows_cuDriverGetVersion() -> int:
7777 return driver_ver .value
7878
7979
80+ def _abs_path_for_dynamic_library_windows (handle : int ) -> str :
81+ buf = ctypes .create_unicode_buffer (260 )
82+ n_chars = ctypes .windll .kernel32 .GetModuleFileNameW (ctypes .wintypes .HMODULE (handle ), buf , len (buf ))
83+ if n_chars == 0 :
84+ raise OSError ("GetModuleFileNameW failed" )
85+ return buf .value
86+
87+
8088@functools .cache
8189def _windows_load_with_dll_basename (name : str ) -> Tuple [Optional [int ], Optional [str ]]:
8290 driver_ver = _windows_cuDriverGetVersion ()
@@ -86,33 +94,34 @@ def _windows_load_with_dll_basename(name: str) -> Tuple[Optional[int], Optional[
8694 if dll_names is None :
8795 return None
8896
89- kernel32 = ctypes .windll .kernel32
90-
9197 for dll_name in dll_names :
92- handle = kernel32 .LoadLibraryW (ctypes .c_wchar_p (dll_name ))
98+ handle = ctypes . windll . kernel32 .LoadLibraryW (ctypes .c_wchar_p (dll_name ))
9399 if handle :
94- buf = ctypes .create_unicode_buffer (260 )
95- n_chars = kernel32 .GetModuleFileNameW (ctypes .wintypes .HMODULE (handle ), buf , len (buf ))
96- if n_chars == 0 :
97- raise OSError ("GetModuleFileNameW failed" )
98- return handle , buf .value
100+ return handle , _abs_path_for_dynamic_library_windows (handle )
99101
100102 return None , None
101103
102104
103- def _load_and_report_path_linux (libname , soname : str ) -> Tuple [int , str ]:
104- handle = ctypes .CDLL (soname , _LINUX_CDLL_MODE )
105+ def _abs_path_for_dynamic_library_linux (libname : str , handle : int ) -> str :
105106 for symbol_name in EXPECTED_LIB_SYMBOLS [libname ]:
106107 symbol = getattr (handle , symbol_name , None )
107108 if symbol is not None :
108109 break
109110 else :
110- raise RuntimeError ( f"No expected symbol for { libname = !r } " )
111+ return None
111112 addr = ctypes .cast (symbol , ctypes .c_void_p )
112113 info = Dl_info ()
113114 if _LIBDL .dladdr (addr , ctypes .byref (info )) == 0 :
114- raise OSError (f"dladdr failed for { soname } " )
115- return handle , info .dli_fname .decode ()
115+ raise OSError (f"dladdr failed for { libname = !r} " )
116+ return info .dli_fname .decode ()
117+
118+
119+ def _load_and_report_path_linux (libname : str , soname : str ) -> Tuple [int , str ]:
120+ handle = ctypes .CDLL (soname , _LINUX_CDLL_MODE )
121+ abs_path = _abs_path_for_dynamic_library_linux (libname , handle )
122+ if abs_path is None :
123+ raise RuntimeError (f"No expected symbol for { libname = !r} " )
124+ return handle , abs_path
116125
117126
118127@dataclass
@@ -131,15 +140,19 @@ def load_nvidia_dynamic_library(libname: str) -> LoadedDL:
131140 if sys .platform == "win32" :
132141 for dll_name in SUPPORTED_WINDOWS_DLLS .get (libname , ()):
133142 try :
134- return LoadedDL ( win32api .GetModuleHandle (dll_name ), None , True )
143+ handle = win32api .GetModuleHandle (dll_name )
135144 except pywintypes .error :
136145 pass
146+ else :
147+ return LoadedDL (handle , _abs_path_for_dynamic_library_windows (handle ), True )
137148 else :
138149 for soname in SUPPORTED_LINUX_SONAMES .get (libname , ()):
139150 try :
140- return LoadedDL ( ctypes .CDLL (soname , mode = os .RTLD_NOLOAD ), None , True )
151+ handle = ctypes .CDLL (soname , mode = os .RTLD_NOLOAD )
141152 except OSError :
142153 pass
154+ else :
155+ return LoadedDL (handle , _abs_path_for_dynamic_library_linux (libname , handle ), True )
143156
144157 for dep in DIRECT_DEPENDENCIES .get (libname , ()):
145158 load_nvidia_dynamic_library (dep )
0 commit comments