1212 IS_WINDOWS ,
1313 is_suppressed_dll_file ,
1414)
15- from cuda .pathfinder ._utils .find_sub_dirs import find_sub_dirs_all_sitepackages
15+ from cuda .pathfinder ._utils .find_sub_dirs import find_sub_dirs , find_sub_dirs_all_sitepackages
1616
1717
1818def _no_such_file_in_sub_dirs (
@@ -28,18 +28,21 @@ def _no_such_file_in_sub_dirs(
2828def _find_so_using_nvidia_lib_dirs (
2929 libname : str , so_basename : str , error_messages : list [str ], attachments : list [str ]
3030) -> Optional [str ]:
31- nvidia_sub_dirs = ("nvidia" , "*" , "nvvm" , "lib64" ) if libname == "nvvm" else ("nvidia" , "*" , "lib" )
3231 file_wild = so_basename + "*"
33- for lib_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
34- # First look for an exact match
35- so_name = os .path .join (lib_dir , so_basename )
36- if os .path .isfile (so_name ):
37- return so_name
38- # Look for a versioned library
39- # Using sort here mainly to make the result deterministic.
40- for so_name in sorted (glob .glob (os .path .join (lib_dir , file_wild ))):
32+ nvidia_sub_dirs_list : list [tuple [str , ...]] = [("nvidia" , "*" , "lib" )] # works also for CTK 13 nvvm
33+ if libname == "nvvm" :
34+ nvidia_sub_dirs_list .append (("nvidia" , "*" , "nvvm" , "lib64" )) # CTK 12
35+ for nvidia_sub_dirs in nvidia_sub_dirs_list :
36+ for lib_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
37+ # First look for an exact match
38+ so_name = os .path .join (lib_dir , so_basename )
4139 if os .path .isfile (so_name ):
4240 return so_name
41+ # Look for a versioned library
42+ # Using sort here mainly to make the result deterministic.
43+ for so_name in sorted (glob .glob (os .path .join (lib_dir , file_wild ))):
44+ if os .path .isfile (so_name ):
45+ return so_name
4346 _no_such_file_in_sub_dirs (nvidia_sub_dirs , file_wild , error_messages , attachments )
4447 return None
4548
@@ -56,11 +59,17 @@ def _find_dll_under_dir(dirpath: str, file_wild: str) -> Optional[str]:
5659def _find_dll_using_nvidia_bin_dirs (
5760 libname : str , lib_searched_for : str , error_messages : list [str ], attachments : list [str ]
5861) -> Optional [str ]:
59- nvidia_sub_dirs = ("nvidia" , "*" , "nvvm" , "bin" ) if libname == "nvvm" else ("nvidia" , "*" , "bin" )
60- for bin_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
61- dll_name = _find_dll_under_dir (bin_dir , lib_searched_for )
62- if dll_name is not None :
63- return dll_name
62+ nvidia_sub_dirs_list : list [tuple [str , ...]] = [
63+ ("nvidia" , "*" , "bin" ), # CTK 12
64+ ("nvidia" , "*" , "bin" , "*" ), # CTK 13, e.g. site-packages\nvidia\cu13\bin\x86_64\
65+ ]
66+ if libname == "nvvm" :
67+ nvidia_sub_dirs_list .append (("nvidia" , "*" , "nvvm" , "bin" )) # Only for CTK 12
68+ for nvidia_sub_dirs in nvidia_sub_dirs_list :
69+ for bin_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
70+ dll_name = _find_dll_under_dir (bin_dir , lib_searched_for )
71+ if dll_name is not None :
72+ return dll_name
6473 _no_such_file_in_sub_dirs (nvidia_sub_dirs , lib_searched_for , error_messages , attachments )
6574 return None
6675
@@ -76,21 +85,29 @@ def _find_lib_dir_using_cuda_home(libname: str) -> Optional[str]:
7685 cuda_home = _get_cuda_home ()
7786 if cuda_home is None :
7887 return None
79- subdirs : tuple [str , ...]
88+ subdirs_list : tuple [tuple [ str , ...] , ...]
8089 if IS_WINDOWS :
81- subdirs = (os .path .join ("nvvm" , "bin" ),) if libname == "nvvm" else ("bin" ,)
90+ if libname == "nvvm" : # noqa: SIM108
91+ subdirs_list = (
92+ ("nvvm" , "bin" , "*" ), # CTK 13
93+ ("nvvm" , "bin" ), # CTK 12
94+ )
95+ else :
96+ subdirs_list = (
97+ ("bin" , "x64" ), # CTK 13
98+ ("bin" ,), # CTK 12
99+ )
82100 else :
83- subdirs = (
84- ( os . path . join ("nvvm" , "lib64" ),)
85- if libname == "nvvm"
86- else (
87- "lib64" , # CTK
88- "lib" , # Conda
101+ if libname == "nvvm" : # noqa: SIM108
102+ subdirs_list = ( ("nvvm" , "lib64" ),)
103+ else :
104+ subdirs_list = (
105+ ( "lib64" ,) , # CTK
106+ ( "lib" ,) , # Conda
89107 )
90- )
91- for subdir in subdirs :
92- dirname = os .path .join (cuda_home , subdir )
93- if os .path .isdir (dirname ):
108+ for sub_dirs in subdirs_list :
109+ dirname : str # work around bug in mypy
110+ for dirname in find_sub_dirs ((cuda_home ,), sub_dirs ):
94111 return dirname
95112 return None
96113
0 commit comments