@@ -80,10 +80,7 @@ def _find_dll_using_nvidia_bin_dirs(
8080 return None
8181
8282
83- def _find_lib_dir_using_cuda_home (libname : str ) -> Optional [str ]:
84- cuda_home = get_cuda_home_or_path ()
85- if cuda_home is None :
86- return None
83+ def _find_lib_dir_using_anchor_point (libname : str , anchor_point : str ) -> Optional [str ]:
8784 subdirs_list : tuple [tuple [str , ...], ...]
8885 if IS_WINDOWS :
8986 if libname == "nvvm" : # noqa: SIM108
@@ -106,11 +103,25 @@ def _find_lib_dir_using_cuda_home(libname: str) -> Optional[str]:
106103 )
107104 for sub_dirs in subdirs_list :
108105 dirname : str # work around bug in mypy
109- for dirname in find_sub_dirs ((cuda_home ,), sub_dirs ):
106+ for dirname in find_sub_dirs ((anchor_point ,), sub_dirs ):
110107 return dirname
111108 return None
112109
113110
111+ def _find_lib_dir_using_cuda_home (libname : str ) -> Optional [str ]:
112+ cuda_home = get_cuda_home_or_path ()
113+ if cuda_home is None :
114+ return None
115+ return _find_lib_dir_using_anchor_point (libname , cuda_home )
116+
117+
118+ def _find_lib_dir_using_conda_prefix (libname : str ) -> Optional [str ]:
119+ conda_prefix = os .getenv ("CONDA_PREFIX" )
120+ if not conda_prefix :
121+ return None
122+ return _find_lib_dir_using_anchor_point (libname , conda_prefix )
123+
124+
114125def _find_so_using_lib_dir (
115126 lib_dir : str , so_basename : str , error_messages : list [str ], attachments : list [str ]
116127) -> Optional [str ]:
@@ -146,44 +157,56 @@ def __init__(self, libname: str):
146157 self .libname = libname
147158 self .error_messages : list [str ] = []
148159 self .attachments : list [str ] = []
149- self .abs_path = None
160+ self .abs_path : Optional [str ] = None
161+
162+ self ._try_site_packages ()
163+ self ._try_with_conda_prefix ()
150164
165+ def _try_site_packages (self ) -> None :
151166 if IS_WINDOWS :
152- self .lib_searched_for = f"{ libname } *.dll"
167+ self .lib_searched_for = f"{ self . libname } *.dll"
153168 if self .abs_path is None :
154169 self .abs_path = _find_dll_using_nvidia_bin_dirs (
155- libname ,
170+ self . libname ,
156171 self .lib_searched_for ,
157172 self .error_messages ,
158173 self .attachments ,
159174 )
160175 else :
161- self .lib_searched_for = f"lib{ libname } .so"
176+ self .lib_searched_for = f"lib{ self . libname } .so"
162177 if self .abs_path is None :
163178 self .abs_path = _find_so_using_nvidia_lib_dirs (
164- libname ,
179+ self . libname ,
165180 self .lib_searched_for ,
166181 self .error_messages ,
167182 self .attachments ,
168183 )
169184
185+ def _try_with_conda_prefix (self ) -> None :
186+ conda_lib_dir = _find_lib_dir_using_conda_prefix (self .libname )
187+ if conda_lib_dir is not None :
188+ self ._find_using_lib_dir (conda_lib_dir )
189+
170190 def try_with_cuda_home (self ) -> None :
171191 cuda_home_lib_dir = _find_lib_dir_using_cuda_home (self .libname )
172192 if cuda_home_lib_dir is not None :
173- if IS_WINDOWS :
174- self .abs_path = _find_dll_using_lib_dir (
175- cuda_home_lib_dir ,
176- self .libname ,
177- self .error_messages ,
178- self .attachments ,
179- )
180- else :
181- self .abs_path = _find_so_using_lib_dir (
182- cuda_home_lib_dir ,
183- self .lib_searched_for ,
184- self .error_messages ,
185- self .attachments ,
186- )
193+ self ._find_using_lib_dir (cuda_home_lib_dir )
194+
195+ def _find_using_lib_dir (self , lib_dir : str ) -> None :
196+ if IS_WINDOWS :
197+ self .abs_path = _find_dll_using_lib_dir (
198+ lib_dir ,
199+ self .libname ,
200+ self .error_messages ,
201+ self .attachments ,
202+ )
203+ else :
204+ self .abs_path = _find_so_using_lib_dir (
205+ lib_dir ,
206+ self .lib_searched_for ,
207+ self .error_messages ,
208+ self .attachments ,
209+ )
187210
188211 def raise_if_abs_path_is_None (self ) -> str : # noqa: N802
189212 if self .abs_path :
0 commit comments