Skip to content

Commit 349751f

Browse files
author
Godnight1006
committed
Add build segmentation flags and XPU index selection
1 parent 8ae6aae commit 349751f

8 files changed

Lines changed: 173 additions & 40 deletions

File tree

API.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ Or you can use the library:
1313
torchruntime.install(["torch", "torchvision<0.20"])
1414
```
1515

16+
Optional flags:
17+
- `preview=True` to allow preview builds (e.g. ROCm 6.4, nightly builds, XPU test index)
18+
- `unsupported=False` to forbid EOL/unsupported builds (raises an error instead of installing older builds)
19+
1620
On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`).
1721

1822
## Test torch

tests/test_installer.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,26 @@ def test_cuda_platform_windows_installs_triton(monkeypatch):
3535
def test_cuda_nightly_platform_linux(monkeypatch):
3636
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
3737
packages = ["torch", "torchvision"]
38-
result = get_install_commands("nightly/cu112", packages)
38+
result = get_install_commands("nightly/cu112", packages, preview=True)
3939
expected_url = "https://download.pytorch.org/whl/nightly/cu112"
4040
assert result == [packages + ["--index-url", expected_url]]
4141

4242

4343
def test_cuda_nightly_platform_windows_installs_triton(monkeypatch):
4444
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
4545
packages = ["torch", "torchvision"]
46-
result = get_install_commands("nightly/cu112", packages)
46+
result = get_install_commands("nightly/cu112", packages, preview=True)
4747
expected_url = "https://download.pytorch.org/whl/nightly/cu112"
4848
assert result == [packages + ["--index-url", expected_url], ["triton-windows"]]
4949

5050

51+
def test_cuda_nightly_platform_requires_preview(monkeypatch):
52+
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
53+
packages = ["torch", "torchvision"]
54+
with pytest.raises(ValueError, match="preview"):
55+
get_install_commands("nightly/cu112", packages, preview=False)
56+
57+
5158
def test_rocm_4_platform_does_not_install_triton(monkeypatch):
5259
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
5360
packages = ["torch", "torchvision"]
@@ -71,25 +78,43 @@ def test_rocm_6_platform_linux_installs_triton(monkeypatch):
7178
def test_xpu_platform_windows_with_torch_only(monkeypatch):
7279
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
7380
packages = ["torch"]
74-
result = get_install_commands("xpu", packages)
75-
expected_url = "https://download.pytorch.org/whl/test/xpu"
81+
result = get_install_commands("xpu", packages, preview=False)
82+
expected_url = "https://download.pytorch.org/whl/xpu"
83+
assert result == [packages + ["--index-url", expected_url]]
84+
85+
86+
def test_xpu_platform_windows_with_torchvision(monkeypatch):
87+
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
88+
packages = ["torch", "torchvision"]
89+
result = get_install_commands("xpu", packages, preview=False)
90+
expected_url = "https://download.pytorch.org/whl/xpu"
7691
assert result == [packages + ["--index-url", expected_url]]
7792

7893

79-
def test_xpu_platform_windows_with_torchvision(monkeypatch, capsys):
94+
def test_xpu_platform_windows_preview(monkeypatch):
8095
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
8196
packages = ["torch", "torchvision"]
82-
result = get_install_commands("xpu", packages)
83-
expected_url = "https://download.pytorch.org/whl/nightly/xpu"
97+
result = get_install_commands("xpu", packages, preview=True)
98+
expected_url = "https://download.pytorch.org/whl/test/xpu"
8499
assert result == [packages + ["--index-url", expected_url]]
85-
captured = capsys.readouterr()
86-
assert "[WARNING]" in captured.out
87100

88101

89102
def test_xpu_platform_linux(monkeypatch):
90103
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
91104
packages = ["torch", "torchvision"]
92-
result = get_install_commands("xpu", packages)
105+
result = get_install_commands("xpu", packages, preview=False)
106+
expected_url = "https://download.pytorch.org/whl/xpu"
107+
triton_index_url = "https://download.pytorch.org/whl"
108+
assert result == [
109+
packages + ["--index-url", expected_url],
110+
["pytorch-triton-xpu", "--index-url", triton_index_url],
111+
]
112+
113+
114+
def test_xpu_platform_linux_preview(monkeypatch):
115+
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
116+
packages = ["torch", "torchvision"]
117+
result = get_install_commands("xpu", packages, preview=True)
93118
expected_url = "https://download.pytorch.org/whl/test/xpu"
94119
triton_index_url = "https://download.pytorch.org/whl"
95120
assert result == [

tests/test_platform_detection.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_amd_gpu_navi4_linux(monkeypatch):
3636
with pytest.raises(NotImplementedError):
3737
get_torch_platform(gpu_infos)
3838
else:
39-
assert get_torch_platform(gpu_infos) == "rocm6.4"
39+
assert get_torch_platform(gpu_infos) == "rocm6.2"
4040

4141

4242
def test_amd_gpu_navi3_linux(monkeypatch, capsys):
@@ -89,6 +89,14 @@ def test_amd_gpu_ellesmere_linux(monkeypatch):
8989
assert get_torch_platform(gpu_infos) == "rocm4.2"
9090

9191

92+
def test_amd_gpu_ellesmere_linux_unsupported_false_raises(monkeypatch):
93+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
94+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
95+
gpu_infos = [GPU(AMD, "AMD", 0x1234, "Ellesmere", True)]
96+
with pytest.raises(ValueError, match="End-of-Life"):
97+
get_torch_platform(gpu_infos, unsupported=False)
98+
99+
92100
def test_amd_gpu_unsupported_linux(monkeypatch, capsys):
93101
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
94102
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")

tests/test_segmentation.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import pytest
2+
from torchruntime.device_db import GPU
3+
from torchruntime.platform_detection import AMD, INTEL, NVIDIA, get_torch_platform, py_version
4+
5+
6+
def test_preview_rocm_6_4_selection(monkeypatch):
7+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
8+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
9+
gpu_infos = [GPU(AMD, "AMD", 0x1234, "Navi 41", True)]
10+
11+
if py_version < (3, 9):
12+
pytest.skip("Navi 4 requires Python 3.9+")
13+
14+
# Default: preview=False -> rocm6.2
15+
assert get_torch_platform(gpu_infos) == "rocm6.2"
16+
assert get_torch_platform(gpu_infos, preview=False) == "rocm6.2"
17+
18+
# preview=True -> rocm6.4
19+
assert get_torch_platform(gpu_infos, preview=True) == "rocm6.4"
20+
21+
22+
def test_eol_cu118_selection(monkeypatch):
23+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
24+
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
25+
# Kepler architecture (e.g. GTX 780)
26+
gpu_infos = [GPU(NVIDIA, "NVIDIA", "1004", "GK110 [GeForce GTX 780]", True)]
27+
28+
# Default: unsupported=True -> cu118
29+
assert get_torch_platform(gpu_infos) == "cu118"
30+
assert get_torch_platform(gpu_infos, unsupported=True) == "cu118"
31+
32+
# unsupported=False -> raises ValueError
33+
with pytest.raises(ValueError, match="considered End-of-Life"):
34+
get_torch_platform(gpu_infos, unsupported=False)
35+
36+
37+
def test_eol_rocm42_selection(monkeypatch):
38+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
39+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
40+
# Ellesmere (e.g. RX 580)
41+
gpu_infos = [GPU(AMD, "AMD", "67df", "Ellesmere [Radeon RX 470/480/570/570X/580/580X/590]", True)]
42+
43+
# Default: unsupported=True -> rocm4.2
44+
assert get_torch_platform(gpu_infos) == "rocm4.2"
45+
46+
# unsupported=False -> raises ValueError
47+
with pytest.raises(ValueError, match="considered End-of-Life"):
48+
get_torch_platform(gpu_infos, unsupported=False)
49+
50+
51+
def test_eol_directml_selection(monkeypatch):
52+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
53+
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
54+
gpu_infos = [GPU(AMD, "AMD", 0x1234, "Radeon", True)]
55+
56+
assert get_torch_platform(gpu_infos) == "directml"
57+
58+
with pytest.raises(ValueError, match="considered End-of-Life"):
59+
get_torch_platform(gpu_infos, unsupported=False)
60+
61+
62+
def test_eol_ipex_selection(monkeypatch):
63+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
64+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
65+
monkeypatch.setattr("torchruntime.platform_detection.py_version", (3, 8))
66+
gpu_infos = [GPU(INTEL, "Intel", 0x1234, "Iris", True)]
67+
68+
assert get_torch_platform(gpu_infos) == "ipex"
69+
70+
# unsupported=False -> raises ValueError
71+
with pytest.raises(ValueError, match="considered End-of-Life"):
72+
get_torch_platform(gpu_infos, unsupported=False)

torchruntime/__main__.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ def print_usage(entry_command: str):
1616
Examples:
1717
{entry_command} install
1818
{entry_command} install --uv
19+
{entry_command} install --preview
20+
{entry_command} install --no-unsupported
1921
{entry_command} install torch==2.2.0 torchvision==0.17.0
2022
{entry_command} install --uv torch>=2.0.0 torchaudio
2123
{entry_command} install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0
@@ -35,6 +37,8 @@ def print_usage(entry_command: str):
3537
3638
Options:
3739
--uv Use uv instead of pip for installation
40+
--preview Allow preview builds (e.g. ROCm 6.4)
41+
--no-unsupported Forbid EOL/unsupported builds (e.g. CUDA 11.8)
3842
3943
Version specification formats (follows pip format):
4044
package==2.1.0 Exact version
@@ -63,14 +67,20 @@ def main():
6367
if command == "install":
6468
args = sys.argv[2:] if len(sys.argv) > 2 else []
6569
use_uv = "--uv" in args
66-
# Remove --uv from args to get package list
67-
package_versions = [arg for arg in args if arg != "--uv"] if args else None
68-
install(package_versions, use_uv=use_uv)
70+
preview = "--preview" in args
71+
unsupported = "--no-unsupported" not in args
72+
# Remove flags from args to get package list
73+
package_versions = [arg for arg in args if arg not in ("--uv", "--preview", "--no-unsupported")] if args else None
74+
install(package_versions, use_uv=use_uv, preview=preview, unsupported=unsupported)
6975
elif command == "test":
7076
subcommand = sys.argv[2] if len(sys.argv) > 2 else "all"
7177
test(subcommand)
7278
elif command == "info":
73-
info()
79+
args = sys.argv[2:] if len(sys.argv) > 2 else []
80+
preview = "--preview" in args
81+
unsupported = "--no-unsupported" not in args
82+
from .utils import info
83+
info(preview=preview, unsupported=unsupported)
7484
else:
7585
print(f"Unknown command: {command}")
7686
entry_path = sys.argv[0]

torchruntime/installer.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,17 @@
22
import sys
33
import platform
44
import subprocess
5-
6-
from .consts import CONTACT_LINK
75
from .device_db import get_gpus
86
from .platform_detection import get_torch_platform
97

108
os_name = platform.system()
119

12-
PIP_PREFIX = [sys.executable, "-m", "pip", "install"]
1310
CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$")
1411
ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$")
1512
ROCM_VERSION_REGEX = re.compile(r"^(?:nightly/)?rocm(?P<major>\d+)\.(?P<minor>\d+)$")
1613

1714

18-
def get_install_commands(torch_platform, packages):
15+
def get_install_commands(torch_platform, packages, preview=False):
1916
"""
2017
Generates pip installation commands for PyTorch and related packages based on the specified platform.
2118
@@ -30,6 +27,7 @@ def get_install_commands(torch_platform, packages):
3027
packages (list of str): List of package names (and optionally versions in pip format). Examples:
3128
- ["torch", "torchvision"]
3229
- ["torch>=2.0", "torchaudio==0.16.0"]
30+
preview (bool): If True, allow preview/nightly builds. Defaults to False.
3331
3432
Returns:
3533
list of list of str: Each sublist contains a pip install command (excluding the `pip install` prefix).
@@ -41,7 +39,7 @@ def get_install_commands(torch_platform, packages):
4139
ValueError: If an unsupported platform is provided.
4240
4341
Notes:
44-
- For "xpu" on Windows, if torchvision or torchaudio are included, the function switches to nightly builds.
42+
- For "xpu", if preview is True, the function installs from the test index.
4543
- For "directml", the "torch-directml" package is returned as part of the installation commands.
4644
- For "ipex", the "intel-extension-for-pytorch" package is returned as part of the installation commands.
4745
- For Windows CUDA, the function also installs "triton-windows" (for torch.compile and Triton kernels).
@@ -54,6 +52,9 @@ def get_install_commands(torch_platform, packages):
5452
if torch_platform == "cpu":
5553
return [packages]
5654

55+
if torch_platform.startswith("nightly/") and not preview:
56+
raise ValueError("preview=True is required for nightly builds")
57+
5758
if CUDA_REGEX.match(torch_platform) or ROCM_REGEX.match(torch_platform):
5859
index_url = f"https://download.pytorch.org/whl/{torch_platform}"
5960
cmds = [packages + ["--index-url", index_url]]
@@ -69,15 +70,11 @@ def get_install_commands(torch_platform, packages):
6970
return cmds
7071

7172
if torch_platform == "xpu":
72-
if os_name == "Windows" and ("torchvision" in packages or "torchaudio" in packages):
73-
print(
74-
f"[WARNING] The preview build of 'xpu' on Windows currently only supports torch, not torchvision/torchaudio. "
75-
f"torchruntime will instead use the nightly build, to get the 'xpu' version of torchaudio and torchvision as well. "
76-
f"Please contact torchruntime if this is no longer accurate: {CONTACT_LINK}"
77-
)
78-
index_url = f"https://download.pytorch.org/whl/nightly/{torch_platform}"
79-
else:
80-
index_url = f"https://download.pytorch.org/whl/test/{torch_platform}"
73+
index_url = (
74+
f"https://download.pytorch.org/whl/test/{torch_platform}"
75+
if preview
76+
else f"https://download.pytorch.org/whl/{torch_platform}"
77+
)
8178

8279
cmds = [packages + ["--index-url", index_url]]
8380
if os_name == "Linux":
@@ -108,14 +105,16 @@ def run_commands(cmds):
108105
subprocess.run(cmd)
109106

110107

111-
def install(packages=[], use_uv=False):
108+
def install(packages=[], use_uv=False, preview=False, unsupported=True):
112109
"""
113110
packages: a list of strings with package names (and optionally their versions in pip-format). e.g. ["torch", "torchvision"] or ["torch>=2.0", "torchaudio==0.16.0"]. Defaults to ["torch", "torchvision", "torchaudio"].
114111
use_uv: bool, whether to use uv for installation. Defaults to False.
112+
preview: bool, whether to allow preview/nightly builds. Defaults to False.
113+
unsupported: bool, whether to allow EOL/unsupported builds. Defaults to True.
115114
"""
116115

117116
gpu_infos = get_gpus()
118-
torch_platform = get_torch_platform(gpu_infos, packages=packages)
119-
cmds = get_install_commands(torch_platform, packages)
117+
torch_platform = get_torch_platform(gpu_infos, packages=packages, preview=preview, unsupported=unsupported)
118+
cmds = get_install_commands(torch_platform, packages, preview=preview)
120119
cmds = get_pip_commands(cmds, use_uv=use_uv)
121120
run_commands(cmds)

torchruntime/platform_detection.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@ def _packages_require_cuda_12_4(packages):
5151
return False
5252

5353

54-
def get_torch_platform(gpu_infos, packages=[]):
54+
def get_torch_platform(gpu_infos, packages=[], preview=False, unsupported=True):
5555
"""
5656
Determine the appropriate PyTorch platform to use based on the system architecture, OS, and GPU information.
5757
5858
Args:
5959
gpu_infos (list of `torchruntime.device_db.GPU` instances)
6060
packages (list of str): Optional list of torch/torchvision/torchaudio requirement strings.
61+
preview (bool): If True, allow preview/nightly builds (e.g. rocm6.4). Defaults to False.
62+
unsupported (bool): If False, forbid EOL/unsupported builds (e.g. cu118). Defaults to True.
6163
6264
Returns:
6365
str: A string representing the platform to use. Possible values:
@@ -70,6 +72,7 @@ def get_torch_platform(gpu_infos, packages=[]):
7072
7173
Raises:
7274
NotImplementedError: For unsupported architectures, OS-GPU combinations, or multiple GPU vendors.
75+
ValueError: If unsupported=False and only an EOL build is available.
7376
Warning: Outputs warnings for deprecated Python versions or fallback configurations.
7477
"""
7578

@@ -95,12 +98,23 @@ def get_torch_platform(gpu_infos, packages=[]):
9598
integrated_devices.append(device)
9699

97100
if discrete_devices:
98-
return _get_platform_for_discrete(discrete_devices, packages=packages)
101+
platform = _get_platform_for_discrete(discrete_devices, packages=packages, preview=preview)
102+
else:
103+
platform = _get_platform_for_integrated(integrated_devices, preview=preview)
99104

100-
return _get_platform_for_integrated(integrated_devices)
105+
# Segmentation Logic
106+
EOL_PLATFORMS = {"cu118", "directml", "ipex", "rocm5.7", "rocm5.5", "rocm5.2", "rocm4.2"}
101107

108+
if not unsupported and platform in EOL_PLATFORMS:
109+
raise ValueError(
110+
f"The recommended platform '{platform}' is considered End-of-Life (EOL) and is forbidden because 'unsupported' is set to False. "
111+
f"Please use a more recent GPU or set unsupported=True to allow this installation."
112+
)
113+
114+
return platform
102115

103-
def _get_platform_for_discrete(gpu_infos, packages=None):
116+
117+
def _get_platform_for_discrete(gpu_infos, packages=None, preview=False):
104118
vendor_ids = set(gpu.vendor_id for gpu in gpu_infos)
105119

106120
if len(vendor_ids) > 1:
@@ -124,7 +138,7 @@ def _get_platform_for_discrete(gpu_infos, packages=None):
124138
raise NotImplementedError(
125139
f"Torch does not support Navi 4x series of GPUs on Python 3.8. Please switch to a newer Python version to use the latest version of torch!"
126140
)
127-
return "rocm6.4"
141+
return "rocm6.4" if preview else "rocm6.2"
128142
if any(device_name.startswith("Navi") for device_name in device_names) and any(
129143
device_name.startswith("Vega 2") for device_name in device_names
130144
): # lowest-common denominator is rocm5.7, which works with both Navi and Vega 20
@@ -201,7 +215,7 @@ def _get_platform_for_discrete(gpu_infos, packages=None):
201215
return "cpu"
202216

203217

204-
def _get_platform_for_integrated(gpu_infos):
218+
def _get_platform_for_integrated(gpu_infos, preview=False):
205219
gpu = gpu_infos[0]
206220

207221
if os_name == "Windows":
@@ -235,3 +249,4 @@ def _get_platform_for_integrated(gpu_infos):
235249
)
236250

237251
return "cpu"
252+

0 commit comments

Comments
 (0)