Skip to content

Commit 74b328b

Browse files
committed
path_finder tweaks (experiment)
1 parent a8285b0 commit 74b328b

1 file changed

Lines changed: 24 additions & 17 deletions

File tree

cuda_bindings/cuda/bindings/_path_finder/find_nvidia_dynamic_library.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import glob
66
import os
77

8-
from cuda.bindings._path_finder.find_sub_dirs import find_sub_dirs_all_sitepackages
8+
from cuda.bindings._path_finder.find_sub_dirs import find_sub_dirs, find_sub_dirs_all_sitepackages
99
from cuda.bindings._path_finder.supported_libs import IS_WINDOWS, is_suppressed_dll_file
1010

1111

@@ -44,11 +44,14 @@ def _find_dll_under_dir(dirpath, file_wild):
4444

4545

4646
def _find_dll_using_nvidia_bin_dirs(libname, lib_searched_for, error_messages, attachments):
47-
nvidia_sub_dirs = ("nvidia", "*", "nvvm", "bin") if libname == "nvvm" else ("nvidia", "*", "bin")
48-
for bin_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs):
49-
dll_name = _find_dll_under_dir(bin_dir, lib_searched_for)
50-
if dll_name is not None:
51-
return dll_name
47+
nvidia_sub_dirs_list = [("nvidia", "*", "bin")]
48+
if libname == "nvvm":
49+
nvidia_sub_dirs_list.append(("nvidia", "*", "nvvm", "bin"))
50+
for nvidia_sub_dirs in nvidia_sub_dirs_list:
51+
for bin_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs):
52+
dll_name = _find_dll_under_dir(bin_dir, lib_searched_for)
53+
if dll_name is not None:
54+
return dll_name
5255
_no_such_file_in_sub_dirs(nvidia_sub_dirs, lib_searched_for, error_messages, attachments)
5356
return None
5457

@@ -65,19 +68,23 @@ def _find_lib_dir_using_cuda_home(libname):
6568
if cuda_home is None:
6669
return None
6770
if IS_WINDOWS:
68-
subdirs = (os.path.join("nvvm", "bin"),) if libname == "nvvm" else ("bin",)
71+
if libname == "nvvm": # noqa: SIM108
72+
subdirs_list = (
73+
("nvvm", "bin", "*"),
74+
("nvvm", "bin"),
75+
)
76+
else:
77+
subdirs_list = (("bin",),)
6978
else:
70-
subdirs = (
71-
(os.path.join("nvvm", "lib64"),)
72-
if libname == "nvvm"
73-
else (
74-
"lib64", # CTK
75-
"lib", # Conda
79+
if libname == "nvvm": # noqa: SIM108
80+
subdirs_list = (("nvvm", "lib64"),)
81+
else:
82+
subdirs_list = (
83+
("lib64",), # CTK
84+
("lib",), # Conda
7685
)
77-
)
78-
for subdir in subdirs:
79-
dirname = os.path.join(cuda_home, subdir)
80-
if os.path.isdir(dirname):
86+
for sub_dirs in subdirs_list:
87+
for dirname in find_sub_dirs((cuda_home,), sub_dirs):
8188
return dirname
8289
return None
8390

0 commit comments

Comments
 (0)