Skip to content

Commit 4379bb5

Browse files
committed
Merge branch 'path_finder_search_priority_v2' into path_finder_search_priority_v2_use_in_bindings
2 parents 34c2874 + 782fcf6 commit 4379bb5

File tree

2 files changed

+29
-42
lines changed

2 files changed

+29
-42
lines changed

cuda_bindings/cuda/bindings/_path_finder/find_nvidia_dynamic_library.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@ def _no_such_file_in_sub_dirs(sub_dirs, file_wild, error_messages, attachments):
1818

1919

2020
def _find_so_using_nvidia_lib_dirs(libname, so_basename, error_messages, attachments):
21-
if libname == "nvvm": # noqa: SIM108
22-
nvidia_sub_dirs = ("nvidia", "*", "nvvm", "lib64")
23-
else:
24-
nvidia_sub_dirs = ("nvidia", "*", "lib")
21+
nvidia_sub_dirs = ("nvidia", "*", "nvvm", "lib64") if libname == "nvvm" else ("nvidia", "*", "lib")
2522
file_wild = so_basename + "*"
2623
for lib_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs):
2724
# First look for an exact match
@@ -47,10 +44,7 @@ def _find_dll_under_dir(dirpath, file_wild):
4744

4845

4946
def _find_dll_using_nvidia_bin_dirs(libname, lib_searched_for, error_messages, attachments):
50-
if libname == "nvvm": # noqa: SIM108
51-
nvidia_sub_dirs = ("nvidia", "*", "nvvm", "bin")
52-
else:
53-
nvidia_sub_dirs = ("nvidia", "*", "bin")
47+
nvidia_sub_dirs = ("nvidia", "*", "nvvm", "bin") if libname == "nvvm" else ("nvidia", "*", "bin")
5448
for bin_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs):
5549
dll_name = _find_dll_under_dir(bin_dir, lib_searched_for)
5650
if dll_name is not None:
@@ -71,18 +65,16 @@ def _find_lib_dir_using_cuda_home(libname):
7165
if cuda_home is None:
7266
return None
7367
if IS_WINDOWS:
74-
if libname == "nvvm": # noqa: SIM108
75-
subdirs = (os.path.join("nvvm", "bin"),)
76-
else:
77-
subdirs = ("bin",)
68+
subdirs = (os.path.join("nvvm", "bin"),) if libname == "nvvm" else ("bin",)
7869
else:
79-
if libname == "nvvm": # noqa: SIM108
80-
subdirs = (os.path.join("nvvm", "lib64"),)
81-
else:
82-
subdirs = (
70+
subdirs = (
71+
(os.path.join("nvvm", "lib64"),)
72+
if libname == "nvvm"
73+
else (
8374
"lib64", # CTK
8475
"lib", # Conda
8576
)
77+
)
8678
for subdir in subdirs:
8779
dirname = os.path.join(cuda_home, subdir)
8880
if os.path.isdir(dirname):
@@ -116,14 +108,14 @@ def _find_dll_using_lib_dir(lib_dir, libname, error_messages, attachments):
116108
return None
117109

118110

119-
def _find_nvvm_lib_dir_from_other_abs_path(other_abs_path):
111+
def _find_nvvm_lib_dir_from_anchor_abs_path(anchor_abs_path):
120112
nvvm_subdir = "bin" if IS_WINDOWS else "lib64"
121-
while other_abs_path:
122-
if os.path.isdir(other_abs_path):
123-
nvvm_lib_dir = os.path.join(other_abs_path, "nvvm", nvvm_subdir)
113+
while anchor_abs_path:
114+
if os.path.isdir(anchor_abs_path):
115+
nvvm_lib_dir = os.path.join(anchor_abs_path, "nvvm", nvvm_subdir)
124116
if os.path.isdir(nvvm_lib_dir):
125117
return nvvm_lib_dir
126-
other_abs_path = os.path.dirname(other_abs_path)
118+
anchor_abs_path = os.path.dirname(anchor_abs_path)
127119
return None
128120

129121

@@ -134,31 +126,34 @@ def __init__(self, libname: str):
134126
self.attachments = []
135127
self.abs_path = None
136128

137-
cuda_home_lib_dir = _find_lib_dir_using_cuda_home(libname)
138129
if IS_WINDOWS:
139130
self.lib_searched_for = f"{libname}*.dll"
140-
if cuda_home_lib_dir is not None:
141-
self.abs_path = _find_dll_using_lib_dir(
142-
cuda_home_lib_dir, libname, self.error_messages, self.attachments
143-
)
144131
if self.abs_path is None:
145132
self.abs_path = _find_dll_using_nvidia_bin_dirs(
146133
libname, self.lib_searched_for, self.error_messages, self.attachments
147134
)
148135
else:
149136
self.lib_searched_for = f"lib{libname}.so"
150-
if cuda_home_lib_dir is not None:
151-
self.abs_path = _find_so_using_lib_dir(
152-
cuda_home_lib_dir, self.lib_searched_for, self.error_messages, self.attachments
153-
)
154137
if self.abs_path is None:
155138
self.abs_path = _find_so_using_nvidia_lib_dirs(
156139
libname, self.lib_searched_for, self.error_messages, self.attachments
157140
)
158141

159-
def retry_with_other_abs_path(self, other_abs_path):
142+
def retry_with_cuda_home_priority_last(self):
143+
cuda_home_lib_dir = _find_lib_dir_using_cuda_home(self.libname)
144+
if cuda_home_lib_dir is not None:
145+
if IS_WINDOWS:
146+
self.abs_path = _find_dll_using_lib_dir(
147+
cuda_home_lib_dir, self.libname, self.error_messages, self.attachments
148+
)
149+
else:
150+
self.abs_path = _find_so_using_lib_dir(
151+
cuda_home_lib_dir, self.lib_searched_for, self.error_messages, self.attachments
152+
)
153+
154+
def retry_with_anchor_abs_path(self, anchor_abs_path):
160155
assert self.libname == "nvvm"
161-
nvvm_lib_dir = _find_nvvm_lib_dir_from_other_abs_path(other_abs_path)
156+
nvvm_lib_dir = _find_nvvm_lib_dir_from_anchor_abs_path(anchor_abs_path)
162157
if nvvm_lib_dir is None:
163158
return
164159
if IS_WINDOWS:

cuda_bindings/cuda/bindings/_path_finder/load_nvidia_dynamic_library.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828

2929

30-
def _load_other_in_subprocess(libname, error_messages):
30+
def _load_anchor_in_subprocess(libname, error_messages):
3131
code = f"""\
3232
from cuda.bindings._path_finder.load_nvidia_dynamic_library import load_nvidia_dynamic_library
3333
import json
@@ -59,15 +59,7 @@ def _load_nvidia_dynamic_library_no_cache(libname: str) -> LoadedDL:
5959
loaded = load_with_system_search(libname, found.lib_searched_for)
6060
if loaded is not None:
6161
return loaded
62-
if libname == "nvvm":
63-
# Use cudart as anchor point (libcudart.so.12 is only ~720K, cudart64_12.dll ~560K).
64-
loaded_cudart = check_if_already_loaded_from_elsewhere("cudart")
65-
if loaded_cudart is not None:
66-
found.retry_with_other_abs_path(loaded_cudart.abs_path)
67-
else:
68-
cudart_abs_path = _load_other_in_subprocess("cudart", found.error_messages)
69-
if cudart_abs_path is not None:
70-
found.retry_with_other_abs_path(cudart_abs_path)
62+
found.retry_with_cuda_home_priority_last()
7163
found.raise_if_abs_path_is_None()
7264

7365
# Load the library from the found path

0 commit comments

Comments
 (0)