55import ctypes
66import functools
77import sys
8+ from typing import Optional , Tuple
89
910if sys .platform == "win32" :
1011 import ctypes .wintypes
@@ -60,24 +61,29 @@ def _windows_cuDriverGetVersion() -> int:
6061
6162
6263@functools .cache
63- def _windows_load_with_dll_basename (name : str ) -> int :
64+ def _windows_load_with_dll_basename (name : str ) -> Tuple [ Optional [ int ], Optional [ str ]] :
6465 driver_ver = _windows_cuDriverGetVersion ()
6566 del driver_ver # Keeping this here because it will probably be needed in the future.
6667
6768 dll_names = SUPPORTED_WINDOWS_DLLS .get (name )
6869 if dll_names is None :
6970 return None
7071
72+ kernel32 = ctypes .windll .kernel32
73+
7174 for dll_name in dll_names :
72- try :
73- return win32api .LoadLibrary (dll_name )
74- except pywintypes .error :
75- pass
75+ handle = kernel32 .LoadLibraryW (ctypes .c_wchar_p (dll_name ))
76+ if handle :
77+ buf = ctypes .create_unicode_buffer (260 )
78+ n_chars = kernel32 .GetModuleFileNameW (ctypes .wintypes .HMODULE (handle ), buf , len (buf ))
79+ if n_chars == 0 :
80+ raise OSError ("GetModuleFileNameW failed" )
81+ return handle , buf .value
7682
77- return None
83+ return None , None
7884
7985
80- def _load_and_report_path_linux (libname , soname : str ) -> ( int , str ) :
86+ def _load_and_report_path_linux (libname , soname : str ) -> Tuple [ int , str ] :
8187 handle = ctypes .CDLL (soname , _LINUX_CDLL_MODE )
8288 for symbol_name in EXPECTED_LIB_SYMBOLS [libname ]:
8389 symbol = getattr (handle , symbol_name , None )
@@ -100,9 +106,10 @@ def load_nvidia_dynamic_library(libname: str) -> int:
100106 found = _find_nvidia_dynamic_library (libname )
101107 if found .abs_path is None :
102108 if sys .platform == "win32" :
103- handle = _windows_load_with_dll_basename (libname )
109+ handle , abs_path = _windows_load_with_dll_basename (libname )
104110 if handle :
105111 # Use `cdef void* ptr = <void*><intptr_t>` in cython to convert back to void*
112+ print (f"SYSTEM ABS_PATH for { libname = !r} : { abs_path } " , flush = True )
106113 return handle
107114 else :
108115 try :
@@ -122,6 +129,7 @@ def load_nvidia_dynamic_library(libname: str) -> int:
122129 except pywintypes .error as e :
123130 raise RuntimeError (f"Failed to load DLL at { found .abs_path } : { e } " ) from e
124131 # Use `cdef void* ptr = <void*><intptr_t>` in cython to convert back to void*
132+ print (f"FOUND ABS_PATH for { libname = !r} : { found .abs_path } " , flush = True )
125133 return handle # C signed int, matches win32api.GetProcAddress
126134 else :
127135 try :
0 commit comments