diff --git a/cuda_bindings/cuda/bindings/_path_finder/find_nvidia_dynamic_library.py b/cuda_bindings/cuda/bindings/_path_finder/find_nvidia_dynamic_library.py index 9835b72d0e3..bb043b45508 100644 --- a/cuda_bindings/cuda/bindings/_path_finder/find_nvidia_dynamic_library.py +++ b/cuda_bindings/cuda/bindings/_path_finder/find_nvidia_dynamic_library.py @@ -5,7 +5,7 @@ import glob import os -from cuda.bindings._path_finder.find_sub_dirs import find_sub_dirs_all_sitepackages +from cuda.bindings._path_finder.find_sub_dirs import find_sub_dirs, find_sub_dirs_all_sitepackages from cuda.bindings._path_finder.supported_libs import IS_WINDOWS, is_suppressed_dll_file @@ -44,11 +44,14 @@ def _find_dll_under_dir(dirpath, file_wild): def _find_dll_using_nvidia_bin_dirs(libname, lib_searched_for, error_messages, attachments): - nvidia_sub_dirs = ("nvidia", "*", "nvvm", "bin") if libname == "nvvm" else ("nvidia", "*", "bin") - for bin_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs): - dll_name = _find_dll_under_dir(bin_dir, lib_searched_for) - if dll_name is not None: - return dll_name + nvidia_sub_dirs_list = [("nvidia", "*", "bin")] + if libname == "nvvm": + nvidia_sub_dirs_list.append(("nvidia", "*", "nvvm", "bin")) + for nvidia_sub_dirs in nvidia_sub_dirs_list: + for bin_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs): + dll_name = _find_dll_under_dir(bin_dir, lib_searched_for) + if dll_name is not None: + return dll_name _no_such_file_in_sub_dirs(nvidia_sub_dirs, lib_searched_for, error_messages, attachments) return None @@ -65,19 +68,23 @@ def _find_lib_dir_using_cuda_home(libname): if cuda_home is None: return None if IS_WINDOWS: - subdirs = (os.path.join("nvvm", "bin"),) if libname == "nvvm" else ("bin",) + if libname == "nvvm": # noqa: SIM108 + subdirs_list = ( + ("nvvm", "bin", "*"), + ("nvvm", "bin"), + ) + else: + subdirs_list = (("bin",),) else: - subdirs = ( - (os.path.join("nvvm", "lib64"),) - if libname == "nvvm" - else ( - "lib64", # CTK - "lib", # Conda + if libname == "nvvm": # noqa: SIM108 + subdirs_list = (("nvvm", "lib64"),) + else: + subdirs_list = ( + ("lib64",), # CTK + ("lib",), # Conda ) - ) - for subdir in subdirs: - dirname = os.path.join(cuda_home, subdir) - if os.path.isdir(dirname): + for sub_dirs in subdirs_list: + for dirname in find_sub_dirs((cuda_home,), sub_dirs): return dirname return None