Skip to content

Commit 3f7be4f

Browse files
committed
update zluda loader
1 parent a31ef08 commit 3f7be4f

5 files changed

Lines changed: 55 additions & 45 deletions

File tree

modules/launch_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,6 @@ def prepare_environment():
593593
if device is not None and zluda_installer.get_blaslt_enabled():
594594
print(f'ROCm hipBLASLt: arch={device.name} available={device.blaslt_supported}')
595595
zluda_installer.set_blaslt_enabled(device.blaslt_supported)
596-
zluda_installer.make_copy()
597596
zluda_installer.load()
598597
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.6.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
599598
print(f'Using ZLUDA in {zluda_installer.path}')

modules/rocm.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,12 @@ def get_version() -> str:
172172
return f'{arr[0]}.{arr[1]}' if len(arr) >= 2 else None
173173

174174
def get_agents() -> List[Agent]:
175-
if is_wsl: # WSL does not have 'rocm_agent_enumerator'
176-
agents = spawn("rocminfo").split("\n")
177-
agents = [x.strip().split(" ")[-1] for x in agents if x.startswith(' Name:') and "CPU" not in x]
178-
else:
175+
try:
179176
agents = spawn("rocm_agent_enumerator").split("\n")
180177
agents = [x for x in agents if x and x != 'gfx000']
178+
except Exception: # old version of ROCm WSL doesn't have rocm_agent_enumerator
179+
agents = spawn("rocminfo").split("\n")
180+
agents = [x.strip().split(" ")[-1] for x in agents if x.startswith(' Name:') and "CPU" not in x]
181181
return [Agent(x) for x in agents]
182182

183183
def load_hsa_runtime() -> None:
@@ -204,7 +204,9 @@ def get_blaslt_enabled() -> bool:
204204
def get_flash_attention_command(agent: Agent):
205205
default = "git+https://github.com/ROCm/flash-attention"
206206
if agent.gfx_version >= 0x1100 and os.environ.get("FLASH_ATTENTION_USE_TRITON_ROCM", "false").lower() != "true":
207-
default = "git+https://github.com/ROCm/flash-attention@howiejay/navi_support"
207+
# use the navi_rotary_fix fork because the original doesn't support rotary_emb for transformers
208+
# original: "git+https://github.com/ROCm/flash-attention@howiejay/navi_support"
209+
default = "https://github.com/Disty0/flash-attention@navi_rotary_fix"
208210
return os.environ.get("FLASH_ATTENTION_PACKAGE", default)
209211

210212
is_wsl: bool = os.environ.get('WSL_DISTRO_NAME', 'unknown' if spawn('wslpath -w /') else None) is not None

modules/zluda.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def initialize_zluda():
3333
from modules.zluda_hijacks import do_hijack
3434
do_hijack()
3535

36-
torch.backends.cudnn.enabled = zluda_installer.MIOpen_available
37-
if not zluda_installer.MIOpen_available:
36+
torch.backends.cudnn.enabled = zluda_installer.MIOpen_enabled
37+
if not zluda_installer.MIOpen_enabled:
3838
torch.backends.cuda.enable_cudnn_sdp(False)
3939
torch.backends.cuda.enable_cudnn_sdp = do_nothing
4040
torch.backends.cuda.enable_flash_sdp(False)

modules/zluda_hijacks.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from functools import wraps
22
import torch
33
import torch._dynamo.device_interface
4-
from modules import rocm, zluda, shared
4+
from modules import shared, zluda # pylint: disable=unused-import
55

66

77
MEM_BUS_WIDTH = {
@@ -13,11 +13,33 @@
1313
"AMD Radeon RX 7900 GRE": 256,
1414
"AMD Radeon RX 7800 XT": 256,
1515
"AMD Radeon RX 7700 XT": 192,
16+
"AMD Radeon RX 7700": 192,
17+
"AMD Radeon RX 7650 GRE": 128,
1618
"AMD Radeon RX 7600 XT": 128,
1719
"AMD Radeon RX 7600": 128,
20+
"AMD Radeon RX 7500 XT": 96,
21+
"AMD Radeon RX 6950 XT": 256,
22+
"AMD Radeon RX 6900 XT": 256,
23+
"AMD Radeon RX 6800 XT": 256,
24+
"AMD Radeon RX 6800": 256,
25+
"AMD Radeon RX 6750 XT": 192,
26+
"AMD Radeon RX 6700 XT": 192,
27+
"AMD Radeon RX 6700": 160,
28+
"AMD Radeon RX 6650 XT": 128,
29+
"AMD Radeon RX 6600 XT": 128,
30+
"AMD Radeon RX 6600": 128,
31+
"AMD Radeon RX 6500 XT": 64,
32+
"AMD Radeon RX 6400": 64,
1833
}
1934

2035

36+
_topk = torch.topk
37+
def topk(input: torch.Tensor, *args, **kwargs): # pylint: disable=redefined-builtin
38+
device = input.device
39+
values, indices = _topk(input.cpu(), *args, **kwargs)
40+
return torch.return_types.topk((values.to(device), indices.to(device),))
41+
42+
2143
class DeviceProperties:
2244
PROPERTIES_OVERRIDE = {"regs_per_multiprocessor": 65535, "gcnArchName": "UNKNOWN ARCHITECTURE"}
2345
internal: torch._C._CudaDeviceProperties
@@ -42,8 +64,7 @@ def torch__C__cuda_getCurrentRawStream(device):
4264

4365

4466
def do_hijack():
45-
torch.version.hip = rocm.version
46-
67+
torch.topk = topk
4768
if zluda.default_agent is not None:
4869
DeviceProperties.PROPERTIES_OVERRIDE["gcnArchName"] = zluda.default_agent.name
4970
torch.cuda._get_device_properties = torch_cuda__get_device_properties # pylint: disable=protected-access

modules/zluda_installer.py

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import urllib.request
88
from typing import Union
99
from modules import rocm
10+
from modules.launch_utils import args
1011

1112

1213
DLL_MAPPING = {
@@ -18,15 +19,12 @@
1819
}
1920
HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll', 'hipfft.dll',]
2021

21-
hipBLASLt_available = False
22-
MIOpen_available = False
22+
MIOpen_enabled = False
2323

2424
path = os.path.abspath(os.environ.get('ZLUDA', '.zluda'))
2525
default_agent: Union[rocm.Agent, None] = None
2626
hipBLASLt_enabled = False
2727

28-
nightly = os.environ.get("ZLUDA_NIGHTLY", "0") == "1"
29-
3028

3129
class ZLUDAResult(ctypes.Structure):
3230
_fields_ = [
@@ -66,42 +64,25 @@ def get_nightly_flag(self) -> int:
6664
ml = None
6765

6866

69-
def load_core_modules():
70-
global core, ml # pylint: disable=global-statement
71-
core = Core(ctypes.windll.LoadLibrary(os.path.join(path, 'nvcuda.dll')))
72-
ml = ZLUDALibrary(ctypes.windll.LoadLibrary(os.path.join(path, 'nvml.dll')))
73-
74-
7567
def set_default_agent(agent: rocm.Agent):
7668
global default_agent # pylint: disable=global-statement
7769
default_agent = agent
7870

79-
is_nightly = False
80-
try:
81-
load_core_modules()
82-
is_nightly = core.get_nightly_flag() == 1
83-
except Exception:
84-
pass
85-
86-
global hipBLASLt_available, hipBLASLt_enabled # pylint: disable=global-statement
87-
hipBLASLt_available = is_nightly and os.path.exists(rocm.blaslt_tensile_libpath)
88-
hipBLASLt_enabled = hipBLASLt_available and os.path.exists(os.path.join(rocm.path, "bin", "hipblaslt.dll"))
89-
90-
global MIOpen_available # pylint: disable=global-statement
91-
MIOpen_available = is_nightly and os.path.exists(os.path.join(rocm.path, "bin", "MIOpen.dll"))
92-
9371

9472
def is_reinstall_needed() -> bool: # ZLUDA<3.8.7
9573
return not os.path.exists(os.path.join(path, 'cufftw.dll'))
9674

9775

98-
def install() -> None:
76+
def install():
9977
if os.path.exists(path):
10078
return
10179

10280
platform = "windows"
103-
commit = os.environ.get("ZLUDA_HASH", "ae0540beb129ffd140226ce956b386619b38f84c")
104-
if nightly:
81+
commit = os.environ.get("ZLUDA_HASH", "dba64c0966df2c71e82255e942c96e2e1cea3a2d")
82+
if os.environ.get("ZLUDA_NIGHTLY", "0") == "1":
83+
print("Warning: Environment variable 'ZLUDA_NIGHTLY' will be removed. Please use command-line argument '--use-nightly' instead.")
84+
args.use_nightly = True
85+
if args.use_nightly:
10586
platform = "nightly-" + platform
10687
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{commit}/ZLUDA-{platform}-rocm{rocm.version[0]}-amd64.zip', '_zluda')
10788
with zipfile.ZipFile('_zluda', 'r') as archive:
@@ -113,7 +94,7 @@ def install() -> None:
11394
os.remove('_zluda')
11495

11596

116-
def uninstall() -> None:
97+
def uninstall():
11798
if os.path.exists(path):
11899
shutil.rmtree(path)
119100

@@ -137,23 +118,29 @@ def link_or_copy(src: os.PathLike, dst: os.PathLike):
137118
shutil.copyfile(src, dst)
138119

139120

140-
def make_copy() -> None:
121+
def load():
122+
global core, ml, hipBLASLt_enabled, MIOpen_enabled # pylint: disable=global-statement
123+
core = Core(ctypes.windll.LoadLibrary(os.path.join(path, 'nvcuda.dll')))
124+
ml = ZLUDALibrary(ctypes.windll.LoadLibrary(os.path.join(path, 'nvml.dll')))
125+
is_nightly = core.get_nightly_flag() == 1
126+
hipBLASLt_enabled = is_nightly and os.path.exists(rocm.blaslt_tensile_libpath) and os.path.exists(os.path.join(rocm.path, "bin", "hipblaslt.dll"))
127+
MIOpen_enabled = is_nightly and os.path.exists(os.path.join(rocm.path, "bin", "MIOpen.dll"))
128+
141129
for k, v in DLL_MAPPING.items():
142130
if not os.path.exists(os.path.join(path, v)):
143131
link_or_copy(os.path.join(path, k), os.path.join(path, v))
144132

145133
if hipBLASLt_enabled and not os.path.exists(os.path.join(path, 'cublasLt64_11.dll')):
146134
link_or_copy(os.path.join(path, 'cublasLt.dll'), os.path.join(path, 'cublasLt64_11.dll'))
147135

148-
if MIOpen_available and not os.path.exists(os.path.join(path, 'cudnn64_9.dll')):
136+
if MIOpen_enabled and not os.path.exists(os.path.join(path, 'cudnn64_9.dll')):
149137
link_or_copy(os.path.join(path, 'cudnn.dll'), os.path.join(path, 'cudnn64_9.dll'))
150138

139+
print(f"ZLUDA load: path='{path}' nightly={bool(core.get_nightly_flag())}")
151140

152-
def load() -> None:
153141
os.environ["ZLUDA_COMGR_LOG_LEVEL"] = "1"
154142
os.environ["ZLUDA_NVRTC_LIB"] = os.path.join([v for v in site.getsitepackages() if v.endswith("site-packages")][0], "torch", "lib", "nvrtc64_112_0.dll")
155143

156-
load_core_modules()
157144
for v in HIPSDK_TARGETS:
158145
ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', v))
159146
for v in DLL_MAPPING.values():
@@ -166,12 +153,13 @@ def load() -> None:
166153
else:
167154
os.environ["DISABLE_ADDMM_CUDA_LT"] = "1"
168155

169-
if MIOpen_available:
156+
if MIOpen_enabled:
170157
ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', 'MIOpen.dll'))
171158
ctypes.windll.LoadLibrary(os.path.join(path, 'cudnn64_9.dll'))
172159

173160
def conceal():
174-
import torch # noqa: F401
161+
import torch
162+
torch.version.hip = rocm.version
175163
platform = sys.platform
176164
sys.platform = ""
177165
from torch.utils import cpp_extension

0 commit comments

Comments
 (0)