@@ -13,7 +13,7 @@ def configure():
1313
1414
1515def _configure_internal (gpu_infos , torch_platform ):
16- if torch_platform . startswith ( "rocm" ) :
16+ if "rocm" in torch_platform :
1717 check_rocm_permissions ()
1818 set_rocm_env_vars (gpu_infos , torch_platform )
1919 elif os_name == "Darwin" :
@@ -62,6 +62,7 @@ def _set_rocm_vars_for_discrete(gpu_infos, all_gpu_info):
6262 # past settings from: https://github.com/easydiffusion/easydiffusion/blob/20d77a85a1ed766ece0cc4b6a55dca003bce262c/scripts/check_modules.py#L405-L420
6363
6464 # Determine GPU generations present
65+ has_navi4 = any ("Navi 4" in device_name for device_name in device_names ) # RX 9000 series
6566 has_navi3 = any ("Navi 3" in device_name for device_name in device_names ) # RX 7000 series
6667 has_navi2 = any ("Navi 2" in device_name for device_name in device_names ) # RX 6000 series
6768 has_navi1 = any ("Navi 1" in device_name for device_name in device_names ) # RX 5000 series
@@ -70,7 +71,11 @@ def _set_rocm_vars_for_discrete(gpu_infos, all_gpu_info):
7071 has_ellesmere = any ("Ellesmere" in device_name for device_name in device_names ) # RX 570/580/Polaris etc
7172
7273 # Select GPU generation settings based on priority
73- if has_navi3 :
74+ if has_navi4 :
75+ env ["HSA_OVERRIDE_GFX_VERSION" ] = "12.0.0"
76+ # Find the index of the first Navi 4 GPU
77+ env ["HIP_VISIBLE_DEVICES" ] = _visible_device_ids (all_gpu_info , "Navi 4" )
78+ elif has_navi3 :
7479 env ["HSA_OVERRIDE_GFX_VERSION" ] = "11.0.0"
7580 # Find the index of the first Navi 3 GPU
7681 env ["HIP_VISIBLE_DEVICES" ] = _visible_device_ids (all_gpu_info , "Navi 3" )
0 commit comments