@@ -41,7 +41,7 @@ cdef void* __nvvmGetProgramLog = NULL
4141
4242
4343cdef inline list get_site_packages():
44- return [site.getusersitepackages()] + site.getsitepackages()
44+ return [site.getusersitepackages()] + site.getsitepackages() + [ " conda " ]
4545
4646
4747cdef load_library(const int driver_ver):
@@ -50,44 +50,42 @@ cdef load_library(const int driver_ver):
5050 for suffix in get_nvvm_dso_version_suffix(driver_ver):
5151 if len (suffix) == 0 :
5252 continue
53- dll_name = " nvvm64_40_0"
53+ dll_name = " nvvm64_40_0.dll "
5454
5555 # First check if the DLL has been loaded by 3rd parties
5656 try :
57- handle = win32api.GetModuleHandle(dll_name)
57+ return win32api.GetModuleHandle(dll_name)
5858 except :
5959 pass
60- else :
61- break
6260
63- # Next, check if DLLs are installed via pip
61+ # Next, check if DLLs are installed via pip or conda
6462 for sp in get_site_packages():
65- mod_path = os.path.join(sp, " nvidia" , " cuda_nvcc" , " nvvm" , " bin" )
66- if not os.path.isdir(mod_path):
67- continue
68- os.add_dll_directory(mod_path)
69- try :
70- handle = win32api.LoadLibraryEx(
71- # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
72- os.path.join(mod_path, dll_name),
73- 0 , LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
74- except :
75- pass
76- else :
77- break
63+ if sp == " conda" :
64+ # nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path
65+ conda_prefix = os.environ.get(" CONDA_PREFIX" )
66+ if conda_prefix is None :
67+ continue
68+ mod_path = os.path.join(conda_prefix, " Library" , " nvvm" , " bin" )
69+ else :
70+ mod_path = os.path.join(sp, " nvidia" , " cuda_nvcc" , " nvvm" , " bin" )
71+ if os.path.isdir(mod_path):
72+ os.add_dll_directory(mod_path)
73+ try :
74+ return win32api.LoadLibraryEx(
75+ # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
76+ os.path.join(mod_path, dll_name),
77+ 0 , LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
78+ except :
79+ pass
7880
7981 # Finally, try default search
82+ # Only reached if DLL wasn't found in any site-package path
8083 try :
81- handle = win32api.LoadLibrary(dll_name)
84+ return win32api.LoadLibrary(dll_name)
8285 except :
8386 pass
84- else :
85- break
86- else :
87- raise RuntimeError (' Failed to load nvvm' )
8887
89- assert handle != 0
90- return handle
88+ raise RuntimeError (' Failed to load nvvm' )
9189
9290
9391cdef int _check_or_init_nvvm() except - 1 nogil:
0 commit comments