@@ -18,10 +18,7 @@ def _no_such_file_in_sub_dirs(sub_dirs, file_wild, error_messages, attachments):
1818
1919
2020def _find_so_using_nvidia_lib_dirs (libname , so_basename , error_messages , attachments ):
21- if libname == "nvvm" : # noqa: SIM108
22- nvidia_sub_dirs = ("nvidia" , "*" , "nvvm" , "lib64" )
23- else :
24- nvidia_sub_dirs = ("nvidia" , "*" , "lib" )
21+ nvidia_sub_dirs = ("nvidia" , "*" , "nvvm" , "lib64" ) if libname == "nvvm" else ("nvidia" , "*" , "lib" )
2522 file_wild = so_basename + "*"
2623 for lib_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
2724 # First look for an exact match
@@ -47,10 +44,7 @@ def _find_dll_under_dir(dirpath, file_wild):
4744
4845
4946def _find_dll_using_nvidia_bin_dirs (libname , lib_searched_for , error_messages , attachments ):
50- if libname == "nvvm" : # noqa: SIM108
51- nvidia_sub_dirs = ("nvidia" , "*" , "nvvm" , "bin" )
52- else :
53- nvidia_sub_dirs = ("nvidia" , "*" , "bin" )
47+ nvidia_sub_dirs = ("nvidia" , "*" , "nvvm" , "bin" ) if libname == "nvvm" else ("nvidia" , "*" , "bin" )
5448 for bin_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
5549 dll_name = _find_dll_under_dir (bin_dir , lib_searched_for )
5650 if dll_name is not None :
@@ -71,18 +65,16 @@ def _find_lib_dir_using_cuda_home(libname):
7165 if cuda_home is None :
7266 return None
7367 if IS_WINDOWS :
74- if libname == "nvvm" : # noqa: SIM108
75- subdirs = (os .path .join ("nvvm" , "bin" ),)
76- else :
77- subdirs = ("bin" ,)
68+ subdirs = (os .path .join ("nvvm" , "bin" ),) if libname == "nvvm" else ("bin" ,)
7869 else :
79- if libname == "nvvm" : # noqa: SIM108
80- subdirs = (os .path .join ("nvvm" , "lib64" ),)
81- else :
82- subdirs = (
70+ subdirs = (
71+ (os .path .join ("nvvm" , "lib64" ),)
72+ if libname == "nvvm"
73+ else (
8374 "lib64" , # CTK
8475 "lib" , # Conda
8576 )
77+ )
8678 for subdir in subdirs :
8779 dirname = os .path .join (cuda_home , subdir )
8880 if os .path .isdir (dirname ):
@@ -116,14 +108,14 @@ def _find_dll_using_lib_dir(lib_dir, libname, error_messages, attachments):
116108 return None
117109
118110
119- def _find_nvvm_lib_dir_from_other_abs_path ( other_abs_path ):
111+ def _find_nvvm_lib_dir_from_anchor_abs_path ( anchor_abs_path ):
120112 nvvm_subdir = "bin" if IS_WINDOWS else "lib64"
121- while other_abs_path :
122- if os .path .isdir (other_abs_path ):
123- nvvm_lib_dir = os .path .join (other_abs_path , "nvvm" , nvvm_subdir )
113+ while anchor_abs_path :
114+ if os .path .isdir (anchor_abs_path ):
115+ nvvm_lib_dir = os .path .join (anchor_abs_path , "nvvm" , nvvm_subdir )
124116 if os .path .isdir (nvvm_lib_dir ):
125117 return nvvm_lib_dir
126- other_abs_path = os .path .dirname (other_abs_path )
118+ anchor_abs_path = os .path .dirname (anchor_abs_path )
127119 return None
128120
129121
@@ -134,31 +126,34 @@ def __init__(self, libname: str):
134126 self .attachments = []
135127 self .abs_path = None
136128
137- cuda_home_lib_dir = _find_lib_dir_using_cuda_home (libname )
138129 if IS_WINDOWS :
139130 self .lib_searched_for = f"{ libname } *.dll"
140- if cuda_home_lib_dir is not None :
141- self .abs_path = _find_dll_using_lib_dir (
142- cuda_home_lib_dir , libname , self .error_messages , self .attachments
143- )
144131 if self .abs_path is None :
145132 self .abs_path = _find_dll_using_nvidia_bin_dirs (
146133 libname , self .lib_searched_for , self .error_messages , self .attachments
147134 )
148135 else :
149136 self .lib_searched_for = f"lib{ libname } .so"
150- if cuda_home_lib_dir is not None :
151- self .abs_path = _find_so_using_lib_dir (
152- cuda_home_lib_dir , self .lib_searched_for , self .error_messages , self .attachments
153- )
154137 if self .abs_path is None :
155138 self .abs_path = _find_so_using_nvidia_lib_dirs (
156139 libname , self .lib_searched_for , self .error_messages , self .attachments
157140 )
158141
159- def retry_with_other_abs_path (self , other_abs_path ):
142+ def retry_with_cuda_home_priority_last (self ):
143+ cuda_home_lib_dir = _find_lib_dir_using_cuda_home (self .libname )
144+ if cuda_home_lib_dir is not None :
145+ if IS_WINDOWS :
146+ self .abs_path = _find_dll_using_lib_dir (
147+ cuda_home_lib_dir , self .libname , self .error_messages , self .attachments
148+ )
149+ else :
150+ self .abs_path = _find_so_using_lib_dir (
151+ cuda_home_lib_dir , self .lib_searched_for , self .error_messages , self .attachments
152+ )
153+
154+ def retry_with_anchor_abs_path (self , anchor_abs_path ):
160155 assert self .libname == "nvvm"
161- nvvm_lib_dir = _find_nvvm_lib_dir_from_other_abs_path ( other_abs_path )
156+ nvvm_lib_dir = _find_nvvm_lib_dir_from_anchor_abs_path ( anchor_abs_path )
162157 if nvvm_lib_dir is None :
163158 return
164159 if IS_WINDOWS :
0 commit comments