Skip to content

Commit 0dba0f5

Browse files
committed
add therock support
1 parent ae5ea24 commit 0dba0f5

7 files changed

Lines changed: 365 additions & 148 deletions

File tree

modules/launch_utils.py

Lines changed: 53 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434

3535
def check_python_version():
36-
is_windows = platform.system() == "Windows"
36+
is_windows = sys.platform == "win32"
3737
major = sys.version_info.major
3838
minor = sys.version_info.minor
3939
micro = sys.version_info.micro
@@ -427,16 +427,15 @@ def requirements_met(requirements_file):
427427
def prepare_environment():
428428
from modules import rocm
429429

430-
system = platform.system()
431430
nvidia_driver_found = False
432-
backend = "cuda"
433-
torch_command = "pip install torch==2.7.0 torchvision numpy==1.26.4 --extra-index-url https://download.pytorch.org/whl/cu121"
431+
backend = "unknown"
432+
torch_command = "pip install torch==2.7.1 torchvision numpy==1.26.4"
434433

435434
if args.use_cpu_torch:
436435
backend = "cpu"
437436
torch_command = os.environ.get(
438437
"TORCH_COMMAND",
439-
"pip install torch==2.7.0 torchvision numpy==1.26.4",
438+
"pip install torch==2.7.1 torchvision numpy==1.26.4",
440439
)
441440
elif args.use_directml:
442441
backend = "directml"
@@ -450,7 +449,7 @@ def prepare_environment():
450449
backend = "zluda"
451450
elif args.use_ipex:
452451
backend = "ipex"
453-
if system == "Windows":
452+
if sys.platform == "win32":
454453
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
455454
# This is NOT an Intel official release so please use it at your own risk!!
456455
# See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
@@ -469,19 +468,6 @@ def prepare_environment():
469468
# See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
470469
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
471470
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
472-
elif rocm.is_installed:
473-
if system == "Windows": # ZLUDA
474-
args.use_zluda = True
475-
backend = "zluda"
476-
else:
477-
backend = "rocm"
478-
torch_index_url = os.environ.get(
479-
"TORCH_INDEX_URL", "https://download.pytorch.org/whl/rocm6.3"
480-
)
481-
torch_command = os.environ.get(
482-
"TORCH_COMMAND",
483-
f"pip install torch==2.7.0 torchvision numpy==1.26.4 --extra-index-url {torch_index_url}",
484-
)
485471
else:
486472
nvidia_driver_found = shutil.which("nvidia-smi") is not None
487473
if nvidia_driver_found:
@@ -492,7 +478,7 @@ def prepare_environment():
492478
)
493479
torch_command = os.environ.get(
494480
"TORCH_COMMAND",
495-
f"pip install torch==2.7.0 torchvision numpy==1.26.4 --extra-index-url {torch_index_url}",
481+
f"pip install torch==2.7.1 torchvision numpy==1.26.4 --extra-index-url {torch_index_url}",
496482
)
497483

498484
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
@@ -537,37 +523,60 @@ def prepare_environment():
537523
if args.skip_torch_cuda_test:
538524
print("WARNING: you should not skip torch test unless you want CPU to work.")
539525

540-
if backend in ("rocm", "zluda",):
541-
device = None
526+
device = None
527+
if backend in ("rocm", "zluda", "unknown"):
528+
amd_gpus = []
542529
try:
543530
amd_gpus = rocm.get_agents()
544-
if len(amd_gpus) == 0:
545-
print('ROCm: no agent was found')
546-
else:
547-
print(f'ROCm: agents={[gpu.name for gpu in amd_gpus]}')
548-
if args.device_id is None:
549-
index = 0
550-
for idx, gpu in enumerate(amd_gpus):
551-
index = idx
552-
if not gpu.is_apu:
553-
# although apu was found, there can be a dedicated card. do not break loop.
554-
# if no dedicated card was found, apu will be used.
555-
break
556-
os.environ.setdefault('HIP_VISIBLE_DEVICES', str(index))
557-
device = amd_gpus[index]
558-
else:
559-
device_id = int(args.device_id)
560-
if device_id < len(amd_gpus):
561-
device = amd_gpus[device_id]
531+
print('ROCm: AMD toolkit detected')
562532
except Exception as e:
563533
print(f'ROCm agent enumerator failed: {e}')
564534

535+
if len(amd_gpus) == 0:
536+
if args.use_rocm or args.use_zluda:
537+
print('No ROCm agent was found. Please make sure that graphics driver is installed and up to date.')
538+
backend = "cpu"
539+
else:
540+
print(f'ROCm: agents={[gpu.name for gpu in amd_gpus]}')
541+
if args.device_id is None:
542+
index = 0
543+
for idx, gpu in enumerate(amd_gpus):
544+
index = idx
545+
if not gpu.is_apu:
546+
# an APU may be enumerated before a dedicated card, so keep scanning past APUs and stop at the first dedicated GPU.
547+
# if no dedicated card is found, the last enumerated APU is used.
548+
break
549+
os.environ.setdefault('HIP_VISIBLE_DEVICES', str(index))
550+
device = amd_gpus[index]
551+
else:
552+
device_id = int(args.device_id)
553+
if device_id < len(amd_gpus):
554+
device = amd_gpus[device_id]
555+
556+
if backend != "zluda":
557+
backend = "rocm"
558+
559+
if backend in ("rocm", "zluda"):
560+
assert device is not None
561+
562+
if sys.platform == "win32" and backend == "rocm":
563+
if device.therock is None:
564+
backend = "zluda"
565+
else:
566+
run_pip(f"install rocm rocm-sdk-core --index-url https://rocm.nightlies.amd.com/v2-staging/{device.therock}", "rocm")
567+
rocm.refresh()
568+
565569
msg = f'ROCm: version={rocm.version}'
566570
if device is not None:
567571
msg += f', using agent {device.name}'
568572
print(msg)
569573

570-
if system == "Windows":
574+
if backend == "rocm":
575+
if isinstance(rocm.environment, rocm.PythonPackageEnvironment):
576+
torch_command = os.environ.get('TORCH_COMMAND', f'pip install torch torchvision --index-url https://rocm.nightlies.amd.com/v2-staging/{device.therock}')
577+
else:
578+
torch_command = os.environ.get('TORCH_COMMAND', 'pip install --no-cache-dir https://repo.radeon.com/rocm/windows/rocm-rel-6.4.4/torch-2.8.0a0%2Bgitfc14c65-cp312-cp312-win_amd64.whl https://repo.radeon.com/rocm/windows/rocm-rel-6.4.4/torchvision-0.24.0a0%2Bc85f008-cp312-cp312-win_amd64.whl')
579+
else:
571580
if args.device_id is not None:
572581
if os.environ.get('HIP_VISIBLE_DEVICES', None) is not None:
573582
print('Setting HIP_VISIBLE_DEVICES and --device-id at the same time may be mistake.')
@@ -589,16 +598,16 @@ def prepare_environment():
589598
if error is None:
590599
try:
591600
zluda_installer.load()
592-
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.7.0 torchvision numpy==1.26.4 --extra-index-url https://download.pytorch.org/whl/cu118')
601+
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.7.1 torchvision numpy==1.26.4 --extra-index-url https://download.pytorch.org/whl/cu118')
593602
except Exception as e:
594603
error = e
595604
print(f'Failed to load ZLUDA: {e}')
596605
if error is not None:
597606
print('Using CPU-only torch')
598-
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.7.0 torchvision numpy==1.26.4')
607+
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.7.1 torchvision numpy==1.26.4')
599608

600609
if rocm.is_wsl:
601-
rocm.load_hsa_runtime()
610+
rocm.postinstall()
602611

603612
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
604613
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
@@ -607,9 +616,6 @@ def prepare_environment():
607616
if args.use_ipex or args.use_directml or args.use_cpu_torch:
608617
args.skip_torch_cuda_test = True
609618

610-
if rocm.is_installed:
611-
rocm.conceal()
612-
613619
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
614620
raise RuntimeError(
615621
'Torch is not able to use GPU; '

0 commit comments

Comments
 (0)