@@ -53,15 +53,23 @@ def _find_dll_using_nvidia_bin_dirs(libname, lib_searched_for, error_messages, a
5353 return None
5454
5555
56- def _get_cuda_home ():
56+ def _get_cuda_home (priority ):
57+ supported_priorities = ("first" , "last" )
58+ assert priority in supported_priorities
59+ env_priority = os .environ .get ("CUDA_PYTHON_CUDA_HOME_PRIORITY" )
60+ if env_priority :
61+ if env_priority not in supported_priorities :
62+ raise RuntimeError (f"Invalid CUDA_PYTHON_CUDA_HOME_PRIORITY { env_priority !r} ({ supported_priorities = } )" )
63+ if priority != env_priority :
64+ return None
5765 cuda_home = os .environ .get ("CUDA_HOME" )
5866 if cuda_home is None :
5967 cuda_home = os .environ .get ("CUDA_PATH" )
6068 return cuda_home
6169
6270
63- def _find_lib_dir_using_cuda_home (libname ):
64- cuda_home = _get_cuda_home ()
71+ def _find_lib_dir_using_cuda_home (libname , priority ):
72+ cuda_home = _get_cuda_home (priority )
6573 if cuda_home is None :
6674 return None
6775 if IS_WINDOWS :
@@ -126,7 +134,7 @@ def __init__(self, libname: str):
126134 self .attachments = []
127135 self .abs_path = None
128136
129- cuda_home_lib_dir = _find_lib_dir_using_cuda_home (libname )
137+ cuda_home_lib_dir = _find_lib_dir_using_cuda_home (libname , "first" )
130138 if IS_WINDOWS :
131139 self .lib_searched_for = f"{ libname } *.dll"
132140 if cuda_home_lib_dir is not None :
@@ -148,6 +156,18 @@ def __init__(self, libname: str):
148156 libname , self .lib_searched_for , self .error_messages , self .attachments
149157 )
150158
159+ def retry_with_cuda_home_priority_last (self ):
160+ cuda_home_lib_dir = _find_lib_dir_using_cuda_home (self .libname , "last" )
161+ if cuda_home_lib_dir is not None :
162+ if IS_WINDOWS :
163+ self .abs_path = _find_dll_using_lib_dir (
164+ cuda_home_lib_dir , self .libname , self .error_messages , self .attachments
165+ )
166+ else :
167+ self .abs_path = _find_so_using_lib_dir (
168+ cuda_home_lib_dir , self .lib_searched_for , self .error_messages , self .attachments
169+ )
170+
151171 def retry_with_anchor_abs_path (self , anchor_abs_path ):
152172 assert self .libname == "nvvm"
153173 nvvm_lib_dir = _find_nvvm_lib_dir_from_anchor_abs_path (anchor_abs_path )
0 commit comments