Skip to content

Commit c1a4983

Browse files
committed
Move nvrtc-specific code from find_nvidia_dynamic_library.py to supported_libs.is_suppressed_dll_file()
1 parent 74d7230 commit c1a4983

2 files changed

Lines changed: 17 additions & 15 deletions

File tree

cuda_bindings/cuda/bindings/_path_finder/find_nvidia_dynamic_library.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88

99
from .cuda_paths import IS_WIN32, get_cuda_paths
10+
from .supported_libs import is_suppressed_dll_file
1011
from .sys_path_find_sub_dirs import sys_path_find_sub_dirs
1112

1213

@@ -39,23 +40,12 @@ def _find_so_using_nvidia_lib_dirs(libname, so_basename, error_messages, attachm
3940

4041

4142
def _find_dll_under_dir(dirpath, file_wild):
42-
dll_name = None
4343
for path in sorted(glob.glob(os.path.join(dirpath, file_wild))):
44-
# nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl:
45-
# nvidia\cuda_nvrtc\bin\
46-
# nvrtc-builtins64_128.dll
47-
# nvrtc64_120_0.alt.dll
48-
# nvrtc64_120_0.dll
49-
node = os.path.basename(path)
50-
if node.endswith(".alt.dll"):
44+
if not os.path.isfile(path):
5145
continue
52-
if "-builtins" in node:
53-
continue
54-
if dll_name is not None:
55-
continue
56-
if os.path.isfile(path):
57-
dll_name = path
58-
return dll_name
46+
if not is_suppressed_dll_file(os.path.basename(path)):
47+
return path
48+
return None
5949

6050

6151
def _find_dll_using_nvidia_bin_dirs(libname, error_messages, attachments):

cuda_bindings/cuda/bindings/_path_finder/supported_libs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,18 @@
319319
"nvrtc",
320320
)
321321

322+
323+
def is_suppressed_dll_file(path_basename: str) -> bool:
324+
if path_basename.startswith("nvrtc"):
325+
# nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl:
326+
# nvidia\cuda_nvrtc\bin\
327+
# nvrtc-builtins64_128.dll
328+
# nvrtc64_120_0.alt.dll
329+
# nvrtc64_120_0.dll
330+
return path_basename.endswith(".alt.dll") or "-builtins" in path_basename
331+
return False
332+
333+
322334
# Based on nm output for Linux x86_64 /usr/local/cuda (12.8.1)
323335
EXPECTED_LIB_SYMBOLS = {
324336
"nvJitLink": ("nvJitLinkVersion",),

0 commit comments

Comments
 (0)