55import glob
66import os
77
8- from cuda .bindings ._path_finder .find_sub_dirs import find_sub_dirs_all_sitepackages
8+ from cuda .bindings ._path_finder .find_sub_dirs import find_sub_dirs , find_sub_dirs_all_sitepackages
99from cuda .bindings ._path_finder .supported_libs import IS_WINDOWS , is_suppressed_dll_file
1010
1111
@@ -44,11 +44,14 @@ def _find_dll_under_dir(dirpath, file_wild):
4444
4545
4646def _find_dll_using_nvidia_bin_dirs (libname , lib_searched_for , error_messages , attachments ):
47- nvidia_sub_dirs = ("nvidia" , "*" , "nvvm" , "bin" ) if libname == "nvvm" else ("nvidia" , "*" , "bin" )
48- for bin_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
49- dll_name = _find_dll_under_dir (bin_dir , lib_searched_for )
50- if dll_name is not None :
51- return dll_name
47+ nvidia_sub_dirs_list = [("nvidia" , "*" , "bin" )]
48+ if libname == "nvvm" :
49+ nvidia_sub_dirs_list .append (("nvidia" , "*" , "nvvm" , "bin" ))
50+ for nvidia_sub_dirs in nvidia_sub_dirs_list :
51+ for bin_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
52+ dll_name = _find_dll_under_dir (bin_dir , lib_searched_for )
53+ if dll_name is not None :
54+ return dll_name
5255 _no_such_file_in_sub_dirs (nvidia_sub_dirs , lib_searched_for , error_messages , attachments )
5356 return None
5457
@@ -65,19 +68,23 @@ def _find_lib_dir_using_cuda_home(libname):
6568 if cuda_home is None :
6669 return None
6770 if IS_WINDOWS :
68- subdirs = (os .path .join ("nvvm" , "bin" ),) if libname == "nvvm" else ("bin" ,)
71+ if libname == "nvvm" : # noqa: SIM108
72+ subdirs_list = (
73+ ("nvvm" , "bin" , "*" ),
74+ ("nvvm" , "bin" ),
75+ )
76+ else :
77+ subdirs_list = (("bin" ,),)
6978 else :
70- subdirs = (
71- ( os . path . join ("nvvm" , "lib64" ),)
72- if libname == "nvvm"
73- else (
74- "lib64" , # CTK
75- "lib" , # Conda
79+ if libname == "nvvm" : # noqa: SIM108
80+ subdirs_list = ( ("nvvm" , "lib64" ),)
81+ else :
82+ subdirs_list = (
83+ ( "lib64" ,) , # CTK
84+ ( "lib" ,) , # Conda
7685 )
77- )
78- for subdir in subdirs :
79- dirname = os .path .join (cuda_home , subdir )
80- if os .path .isdir (dirname ):
86+ for sub_dirs in subdirs_list :
87+ for dirname in find_sub_dirs ((cuda_home ,), sub_dirs ):
8188 return dirname
8289 return None
8390
0 commit comments