Skip to content

Commit 0e2ddd9

Browse files
committed
Support Navi 4x graphics cards, gfx 12
1 parent 8f65813 commit 0e2ddd9

2 files changed

Lines changed: 17 additions & 2 deletions

File tree

tests/test_configuration.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ def clean_env():
3939
os.environ[var] = original_values[var]
4040

4141

42+
def test_rocm_navi_4_settings():
    """Check the env vars produced for one "Navi 44 XTX" GPU on "nightly/rocm6.4"."""
    navi4_gpu = create_gpu_info(AMD, "123", "Navi 44 XTX", True)
    configure([navi4_gpu], "nightly/rocm6.4")

    # Variables that must be present, with the exact value expected for gfx 12.
    expected_values = {
        "HSA_OVERRIDE_GFX_VERSION": "12.0.0",
        "HIP_VISIBLE_DEVICES": "0",
    }
    for var_name, var_value in expected_values.items():
        assert os.environ.get(var_name) == var_value

    # Variables that must not be set for this GPU/platform combination.
    for var_name in ("ROC_ENABLE_PRE_VEGA", "PYTORCH_ENABLE_MPS_FALLBACK"):
        assert var_name not in os.environ
50+
51+
4252
def test_rocm_navi_3_settings():
4353
gpus = [create_gpu_info(AMD, "123", "Navi 31 XTX", True)]
4454
configure(gpus, "rocm6.2")

torchruntime/configuration.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def configure():
1313

1414

1515
def _configure_internal(gpu_infos, torch_platform):
16-
if torch_platform.startswith("rocm"):
16+
if "rocm" in torch_platform:
1717
check_rocm_permissions()
1818
set_rocm_env_vars(gpu_infos, torch_platform)
1919
elif os_name == "Darwin":
@@ -62,6 +62,7 @@ def _set_rocm_vars_for_discrete(gpu_infos, all_gpu_info):
6262
# past settings from: https://github.com/easydiffusion/easydiffusion/blob/20d77a85a1ed766ece0cc4b6a55dca003bce262c/scripts/check_modules.py#L405-L420
6363

6464
# Determine GPU generations present
65+
has_navi4 = any("Navi 4" in device_name for device_name in device_names) # RX 9000 series
6566
has_navi3 = any("Navi 3" in device_name for device_name in device_names) # RX 7000 series
6667
has_navi2 = any("Navi 2" in device_name for device_name in device_names) # RX 6000 series
6768
has_navi1 = any("Navi 1" in device_name for device_name in device_names) # RX 5000 series
@@ -70,7 +71,11 @@ def _set_rocm_vars_for_discrete(gpu_infos, all_gpu_info):
7071
has_ellesmere = any("Ellesmere" in device_name for device_name in device_names) # RX 570/580/Polaris etc
7172

7273
# Select GPU generation settings based on priority
73-
if has_navi3:
74+
if has_navi4:
75+
env["HSA_OVERRIDE_GFX_VERSION"] = "12.0.0"
76+
# Expose only the Navi 4 GPU(s) to HIP, using the ids from _visible_device_ids
77+
env["HIP_VISIBLE_DEVICES"] = _visible_device_ids(all_gpu_info, "Navi 4")
78+
elif has_navi3:
7479
env["HSA_OVERRIDE_GFX_VERSION"] = "11.0.0"
7580
# Expose only the Navi 3 GPU(s) to HIP, using the ids from _visible_device_ids
7681
env["HIP_VISIBLE_DEVICES"] = _visible_device_ids(all_gpu_info, "Navi 3")

0 commit comments

Comments
 (0)