77import urllib .request
88from typing import Union
99from modules import rocm
10+ from modules .launch_utils import args
1011
1112
1213DLL_MAPPING = {
1819}
1920HIPSDK_TARGETS = ['rocblas.dll' , 'rocsolver.dll' , 'hipfft.dll' ,]
2021
21- hipBLASLt_available = False
22- MIOpen_available = False
22+ MIOpen_enabled = False
2323
2424path = os .path .abspath (os .environ .get ('ZLUDA' , '.zluda' ))
2525default_agent : Union [rocm .Agent , None ] = None
2626hipBLASLt_enabled = False
2727
28- nightly = os .environ .get ("ZLUDA_NIGHTLY" , "0" ) == "1"
29-
3028
3129class ZLUDAResult (ctypes .Structure ):
3230 _fields_ = [
@@ -66,42 +64,25 @@ def get_nightly_flag(self) -> int:
6664ml = None
6765
6866
69- def load_core_modules ():
70- global core , ml # pylint: disable=global-statement
71- core = Core (ctypes .windll .LoadLibrary (os .path .join (path , 'nvcuda.dll' )))
72- ml = ZLUDALibrary (ctypes .windll .LoadLibrary (os .path .join (path , 'nvml.dll' )))
73-
74-
7567def set_default_agent (agent : rocm .Agent ):
7668 global default_agent # pylint: disable=global-statement
7769 default_agent = agent
7870
79- is_nightly = False
80- try :
81- load_core_modules ()
82- is_nightly = core .get_nightly_flag () == 1
83- except Exception :
84- pass
85-
86- global hipBLASLt_available , hipBLASLt_enabled # pylint: disable=global-statement
87- hipBLASLt_available = is_nightly and os .path .exists (rocm .blaslt_tensile_libpath )
88- hipBLASLt_enabled = hipBLASLt_available and os .path .exists (os .path .join (rocm .path , "bin" , "hipblaslt.dll" ))
89-
90- global MIOpen_available # pylint: disable=global-statement
91- MIOpen_available = is_nightly and os .path .exists (os .path .join (rocm .path , "bin" , "MIOpen.dll" ))
92-
9371
9472def is_reinstall_needed () -> bool : # ZLUDA<3.8.7
9573 return not os .path .exists (os .path .join (path , 'cufftw.dll' ))
9674
9775
98- def install () -> None :
76+ def install ():
9977 if os .path .exists (path ):
10078 return
10179
10280 platform = "windows"
103- commit = os .environ .get ("ZLUDA_HASH" , "ae0540beb129ffd140226ce956b386619b38f84c" )
104- if nightly :
81+ commit = os .environ .get ("ZLUDA_HASH" , "dba64c0966df2c71e82255e942c96e2e1cea3a2d" )
82+ if os .environ .get ("ZLUDA_NIGHTLY" , "0" ) == "1" :
83+ print ("Warning: Environment variable 'ZLUDA_NIGHTLY' will be removed. Please use command-line argument '--use-nightly' instead." )
84+ args .use_nightly = True
85+ if args .use_nightly :
10586 platform = "nightly-" + platform
10687 urllib .request .urlretrieve (f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{ commit } /ZLUDA-{ platform } -rocm{ rocm .version [0 ]} -amd64.zip' , '_zluda' )
10788 with zipfile .ZipFile ('_zluda' , 'r' ) as archive :
@@ -113,7 +94,7 @@ def install() -> None:
11394 os .remove ('_zluda' )
11495
11596
116- def uninstall () -> None :
97+ def uninstall ():
11798 if os .path .exists (path ):
11899 shutil .rmtree (path )
119100
@@ -137,23 +118,29 @@ def link_or_copy(src: os.PathLike, dst: os.PathLike):
137118 shutil .copyfile (src , dst )
138119
139120
140- def make_copy () -> None :
121+ def load ():
122+ global core , ml , hipBLASLt_enabled , MIOpen_enabled # pylint: disable=global-statement
123+ core = Core (ctypes .windll .LoadLibrary (os .path .join (path , 'nvcuda.dll' )))
124+ ml = ZLUDALibrary (ctypes .windll .LoadLibrary (os .path .join (path , 'nvml.dll' )))
125+ is_nightly = core .get_nightly_flag () == 1
126+ hipBLASLt_enabled = is_nightly and os .path .exists (rocm .blaslt_tensile_libpath ) and os .path .exists (os .path .join (rocm .path , "bin" , "hipblaslt.dll" ))
127+ MIOpen_enabled = is_nightly and os .path .exists (os .path .join (rocm .path , "bin" , "MIOpen.dll" ))
128+
141129 for k , v in DLL_MAPPING .items ():
142130 if not os .path .exists (os .path .join (path , v )):
143131 link_or_copy (os .path .join (path , k ), os .path .join (path , v ))
144132
145133 if hipBLASLt_enabled and not os .path .exists (os .path .join (path , 'cublasLt64_11.dll' )):
146134 link_or_copy (os .path .join (path , 'cublasLt.dll' ), os .path .join (path , 'cublasLt64_11.dll' ))
147135
148- if MIOpen_available and not os .path .exists (os .path .join (path , 'cudnn64_9.dll' )):
136+ if MIOpen_enabled and not os .path .exists (os .path .join (path , 'cudnn64_9.dll' )):
149137 link_or_copy (os .path .join (path , 'cudnn.dll' ), os .path .join (path , 'cudnn64_9.dll' ))
150138
139+ print (f"ZLUDA load: path='{ path } ' nightly={ bool (core .get_nightly_flag ())} " )
151140
152- def load () -> None :
153141 os .environ ["ZLUDA_COMGR_LOG_LEVEL" ] = "1"
154142 os .environ ["ZLUDA_NVRTC_LIB" ] = os .path .join ([v for v in site .getsitepackages () if v .endswith ("site-packages" )][0 ], "torch" , "lib" , "nvrtc64_112_0.dll" )
155143
156- load_core_modules ()
157144 for v in HIPSDK_TARGETS :
158145 ctypes .windll .LoadLibrary (os .path .join (rocm .path , 'bin' , v ))
159146 for v in DLL_MAPPING .values ():
@@ -166,12 +153,13 @@ def load() -> None:
166153 else :
167154 os .environ ["DISABLE_ADDMM_CUDA_LT" ] = "1"
168155
169- if MIOpen_available :
156+ if MIOpen_enabled :
170157 ctypes .windll .LoadLibrary (os .path .join (rocm .path , 'bin' , 'MIOpen.dll' ))
171158 ctypes .windll .LoadLibrary (os .path .join (path , 'cudnn64_9.dll' ))
172159
173160 def conceal ():
174- import torch # noqa: F401
161+ import torch
162+ torch .version .hip = rocm .version
175163 platform = sys .platform
176164 sys .platform = ""
177165 from torch .utils import cpp_extension
0 commit comments