Skip to content

Commit 3e5aa1d

Browse files
committed
Add in ._try_with_conda_prefix()
1 parent 37cb325 commit 3e5aa1d

File tree

1 file changed

+47
-24
lines changed

1 file changed

+47
-24
lines changed

cuda_pathfinder/cuda/pathfinder/_dynamic_libs/find_nvidia_dynamic_lib.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,7 @@ def _find_dll_using_nvidia_bin_dirs(
8080
return None
8181

8282

83-
def _find_lib_dir_using_cuda_home(libname: str) -> Optional[str]:
84-
cuda_home = get_cuda_home_or_path()
85-
if cuda_home is None:
86-
return None
83+
def _find_lib_dir_using_anchor_point(libname: str, anchor_point: str) -> Optional[str]:
8784
subdirs_list: tuple[tuple[str, ...], ...]
8885
if IS_WINDOWS:
8986
if libname == "nvvm": # noqa: SIM108
@@ -106,11 +103,25 @@ def _find_lib_dir_using_cuda_home(libname: str) -> Optional[str]:
106103
)
107104
for sub_dirs in subdirs_list:
108105
dirname: str # work around bug in mypy
109-
for dirname in find_sub_dirs((cuda_home,), sub_dirs):
106+
for dirname in find_sub_dirs((anchor_point,), sub_dirs):
110107
return dirname
111108
return None
112109

113110

111+
def _find_lib_dir_using_cuda_home(libname: str) -> Optional[str]:
112+
cuda_home = get_cuda_home_or_path()
113+
if cuda_home is None:
114+
return None
115+
return _find_lib_dir_using_anchor_point(libname, cuda_home)
116+
117+
118+
def _find_lib_dir_using_conda_prefix(libname: str) -> Optional[str]:
119+
conda_prefix = os.getenv("CONDA_PREFIX")
120+
if not conda_prefix:
121+
return None
122+
return _find_lib_dir_using_anchor_point(libname, conda_prefix)
123+
124+
114125
def _find_so_using_lib_dir(
115126
lib_dir: str, so_basename: str, error_messages: list[str], attachments: list[str]
116127
) -> Optional[str]:
@@ -146,44 +157,56 @@ def __init__(self, libname: str):
146157
self.libname = libname
147158
self.error_messages: list[str] = []
148159
self.attachments: list[str] = []
149-
self.abs_path = None
160+
self.abs_path: Optional[str] = None
161+
162+
self._try_site_packages()
163+
self._try_with_conda_prefix()
150164

165+
def _try_site_packages(self) -> None:
151166
if IS_WINDOWS:
152-
self.lib_searched_for = f"{libname}*.dll"
167+
self.lib_searched_for = f"{self.libname}*.dll"
153168
if self.abs_path is None:
154169
self.abs_path = _find_dll_using_nvidia_bin_dirs(
155-
libname,
170+
self.libname,
156171
self.lib_searched_for,
157172
self.error_messages,
158173
self.attachments,
159174
)
160175
else:
161-
self.lib_searched_for = f"lib{libname}.so"
176+
self.lib_searched_for = f"lib{self.libname}.so"
162177
if self.abs_path is None:
163178
self.abs_path = _find_so_using_nvidia_lib_dirs(
164-
libname,
179+
self.libname,
165180
self.lib_searched_for,
166181
self.error_messages,
167182
self.attachments,
168183
)
169184

185+
def _try_with_conda_prefix(self) -> None:
186+
conda_lib_dir = _find_lib_dir_using_conda_prefix(self.libname)
187+
if conda_lib_dir is not None:
188+
self._find_using_lib_dir(conda_lib_dir)
189+
170190
def try_with_cuda_home(self) -> None:
171191
cuda_home_lib_dir = _find_lib_dir_using_cuda_home(self.libname)
172192
if cuda_home_lib_dir is not None:
173-
if IS_WINDOWS:
174-
self.abs_path = _find_dll_using_lib_dir(
175-
cuda_home_lib_dir,
176-
self.libname,
177-
self.error_messages,
178-
self.attachments,
179-
)
180-
else:
181-
self.abs_path = _find_so_using_lib_dir(
182-
cuda_home_lib_dir,
183-
self.lib_searched_for,
184-
self.error_messages,
185-
self.attachments,
186-
)
193+
self._find_using_lib_dir(cuda_home_lib_dir)
194+
195+
def _find_using_lib_dir(self, lib_dir: str) -> None:
196+
if IS_WINDOWS:
197+
self.abs_path = _find_dll_using_lib_dir(
198+
lib_dir,
199+
self.libname,
200+
self.error_messages,
201+
self.attachments,
202+
)
203+
else:
204+
self.abs_path = _find_so_using_lib_dir(
205+
lib_dir,
206+
self.lib_searched_for,
207+
self.error_messages,
208+
self.attachments,
209+
)
187210

188211
def raise_if_abs_path_is_None(self) -> str: # noqa: N802
189212
if self.abs_path:

0 commit comments

Comments
 (0)