Skip to content

Commit a3ae3a3

Browse files
committed
Factor out _find_dll_under_dir(dirpath, file_wild) and reuse from _find_dll_using_nvidia_bin_dirs(), _find_dll_using_cudalib_dir() (to fix loading nvrtc64_120_0.dll from local CTK)
1 parent 1344621 commit a3ae3a3

1 file changed

Lines changed: 32 additions & 28 deletions

File tree

cuda_bindings/cuda/bindings/_path_finder/find_nvidia_dynamic_library.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -43,39 +43,43 @@ def _append_to_os_environ_path(dirpath):
4343
os.environ["PATH"] = dirpath if curr_path is None else os.pathsep.join((curr_path, dirpath))
4444

4545

46+
def _find_dll_under_dir(dirpath, file_wild):
47+
dll_name = None
48+
have_builtins = False
49+
for path in sorted(glob.glob(os.path.join(dirpath, file_wild))):
50+
# nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl:
51+
# nvidia\cuda_nvrtc\bin\
52+
# nvrtc-builtins64_128.dll
53+
# nvrtc64_120_0.alt.dll
54+
# nvrtc64_120_0.dll
55+
node = os.path.basename(path)
56+
if node.endswith(".alt.dll"):
57+
continue
58+
if "-builtins" in node:
59+
have_builtins = True
60+
continue
61+
if dll_name is not None:
62+
continue
63+
if os.path.isfile(path):
64+
dll_name = path
65+
if dll_name is not None:
66+
if have_builtins:
67+
# Add the DLL directory to the search path
68+
os.add_dll_directory(dirpath)
69+
# Update PATH as a fallback for dependent DLL resolution
70+
_append_to_os_environ_path(dirpath)
71+
return dll_name
72+
73+
4674
def _find_dll_using_nvidia_bin_dirs(libname, error_messages, attachments):
4775
if libname == "nvvm": # noqa: SIM108
4876
nvidia_sub_dirs = ("nvidia", "*", "nvvm", "bin")
4977
else:
5078
nvidia_sub_dirs = ("nvidia", "*", "bin")
5179
file_wild = libname + "*.dll"
5280
for bin_dir in sys_path_find_sub_dirs(nvidia_sub_dirs):
53-
dll_name = None
54-
have_builtins = False
55-
for path in sorted(glob.glob(os.path.join(bin_dir, file_wild))):
56-
# nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl:
57-
# nvidia\cuda_nvrtc\bin\
58-
# nvrtc-builtins64_128.dll
59-
# nvrtc64_120_0.alt.dll
60-
# nvrtc64_120_0.dll
61-
# See also:
62-
# https://github.com/NVIDIA/cuda-python/pull/563#discussion_r2054427641
63-
node = os.path.basename(path)
64-
if node.endswith(".alt.dll"):
65-
continue
66-
if "-builtins" in node:
67-
have_builtins = True
68-
continue
69-
if dll_name is not None:
70-
continue
71-
if os.path.isfile(path):
72-
dll_name = path
81+
dll_name = _find_dll_under_dir(bin_dir, file_wild)
7382
if dll_name is not None:
74-
if have_builtins:
75-
# Add the DLL directory to the search path
76-
os.add_dll_directory(bin_dir)
77-
# Update PATH as a fallback for dependent DLL resolution
78-
_append_to_os_environ_path(bin_dir)
7983
return dll_name
8084
_no_such_file_in_sub_dirs(nvidia_sub_dirs, file_wild, error_messages, attachments)
8185
return None
@@ -124,9 +128,9 @@ def _find_dll_using_cudalib_dir(libname, error_messages, attachments):
124128
if cudalib_dir is None:
125129
return None
126130
file_wild = libname + "*.dll"
127-
for dll_name in sorted(glob.glob(os.path.join(cudalib_dir, file_wild))):
128-
if os.path.isfile(dll_name):
129-
return dll_name
131+
dll_name = _find_dll_under_dir(cudalib_dir, file_wild)
132+
if dll_name is not None:
133+
return dll_name
130134
error_messages.append(f"No such file: {file_wild}")
131135
attachments.append(f' listdir("{cudalib_dir}"):')
132136
for node in sorted(os.listdir(cudalib_dir)):

0 commit comments

Comments
 (0)