Skip to content

Commit 7f961ad

Browse files
author
Godnight1006
committed
Move CUDA package-based demotion to platform detection
1 parent 5327187 commit 7f961ad

4 files changed

Lines changed: 114 additions & 153 deletions

File tree

tests/test_installer.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22
import pytest
33
from unittest.mock import patch
4-
from torchruntime.installer import get_install_commands, get_pip_commands, run_commands, install
4+
from torchruntime.installer import get_install_commands, get_pip_commands, run_commands
55

66

77
def test_empty_args():
@@ -125,25 +125,3 @@ def test_run_commands():
125125
# Check that subprocess.run was called with the correct arguments
126126
mock_run.assert_any_call(cmds[0])
127127
mock_run.assert_any_call(cmds[1])
128-
129-
130-
def test_install_demotes_cu128_to_cu124_for_torch_2_6(monkeypatch):
131-
# Simulate a system where the detected platform would be cu128.
132-
monkeypatch.setattr("torchruntime.installer.get_gpus", lambda: ["dummy_gpu"])
133-
monkeypatch.setattr("torchruntime.installer.get_torch_platform", lambda gpu_infos: "cu128")
134-
135-
seen = {}
136-
137-
def fake_get_install_commands(torch_platform, packages):
138-
seen["torch_platform"] = torch_platform
139-
seen["packages"] = packages
140-
return [packages]
141-
142-
monkeypatch.setattr("torchruntime.installer.get_install_commands", fake_get_install_commands)
143-
monkeypatch.setattr("torchruntime.installer.get_pip_commands", lambda cmds, use_uv=False: cmds)
144-
monkeypatch.setattr("torchruntime.installer.run_commands", lambda cmds: None)
145-
146-
install(["torch==2.6.0"])
147-
148-
assert seen["packages"] == ["torch==2.6.0"]
149-
assert seen["torch_platform"] == "cu124"

tests/test_platform_detection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,18 @@ def test_nvidia_gpu_linux(monkeypatch):
121121
assert get_torch_platform(gpu_infos) == expected
122122

123123

124+
def test_nvidia_gpu_demotes_to_cu124_for_pinned_torch_below_2_7(monkeypatch):
125+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
126+
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
127+
monkeypatch.setattr("torchruntime.platform_detection.py_version", (3, 11))
128+
monkeypatch.setattr("torchruntime.platform_detection.get_nvidia_arch", lambda device_names: 8.6)
129+
130+
gpu_infos = [GPU(NVIDIA, "NVIDIA", 0x1234, "GeForce", True)]
131+
132+
assert get_torch_platform(gpu_infos) == "cu128"
133+
assert get_torch_platform(gpu_infos, packages=["torch==2.6.0"]) == "cu124"
134+
135+
124136
def test_nvidia_gpu_mac(monkeypatch):
125137
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Darwin")
126138
monkeypatch.setattr("torchruntime.platform_detection.arch", "arm64")

torchruntime/installer.py

Lines changed: 1 addition & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -13,131 +13,6 @@
1313
CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$")
1414
ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$")
1515

16-
_CUDA_12_8_PLATFORM = "cu128"
17-
_CUDA_12_4_PLATFORM = "cu124"
18-
_CUDA_12_8_MIN_VERSIONS = {
19-
"torch": (2, 7, 0),
20-
"torchaudio": (2, 7, 0),
21-
"torchvision": (0, 22, 0),
22-
}
23-
24-
25-
def _parse_version_segments(text):
26-
text = text.strip().split("+", 1)[0]
27-
segments = []
28-
for part in text.split("."):
29-
m = re.match(r"^(\d+)", part)
30-
if not m:
31-
break
32-
segments.append(int(m.group(1)))
33-
return segments
34-
35-
36-
def _as_version_tuple(version_segments):
37-
padded = list(version_segments[:3])
38-
while len(padded) < 3:
39-
padded.append(0)
40-
return tuple(padded)
41-
42-
43-
def _version_lt(a, b):
44-
return _as_version_tuple(a) < _as_version_tuple(b)
45-
46-
47-
def _version_le(a, b):
48-
return _as_version_tuple(a) <= _as_version_tuple(b)
49-
50-
51-
def _get_requirement_name_and_specifier(requirement):
52-
req = requirement.strip()
53-
if not req or req.startswith("-") or "@" in req:
54-
return None, None
55-
56-
match = re.match(r"^([A-Za-z0-9][A-Za-z0-9_.-]*)(?:\[[^\]]+\])?", req)
57-
if not match:
58-
return None, None
59-
60-
name = match.group(1).lower().replace("_", "-")
61-
spec = req[match.end() :].split(";", 1)[0].strip()
62-
return name, spec
63-
64-
65-
def _upper_bound_for_specifier(specifier):
66-
"""
67-
Returns (upper_bound_segments, is_inclusive) for specifiers that impose an upper bound,
68-
or (None, None) if there is no upper bound.
69-
"""
70-
71-
s = specifier.strip()
72-
73-
if s.startswith("=="):
74-
value = s[2:].strip()
75-
if "*" in value:
76-
prefix = value.split("*", 1)[0].rstrip(".")
77-
prefix_segments = _parse_version_segments(prefix)
78-
if not prefix_segments:
79-
return None, None
80-
upper = list(prefix_segments)
81-
upper[-1] += 1
82-
upper.append(0)
83-
return upper, False
84-
85-
return _parse_version_segments(value), True
86-
87-
if s.startswith("<="):
88-
return _parse_version_segments(s[2:].strip()), True
89-
90-
if s.startswith("<"):
91-
return _parse_version_segments(s[1:].strip()), False
92-
93-
if s.startswith("~="):
94-
value_segments = _parse_version_segments(s[2:].strip())
95-
if len(value_segments) < 2:
96-
return None, None
97-
upper = list(value_segments[:-1])
98-
upper[-1] += 1
99-
upper.append(0)
100-
return upper, False
101-
102-
return None, None
103-
104-
105-
def _packages_require_cuda_12_4(packages):
106-
"""
107-
True if the requested torch package versions cannot be satisfied by the CUDA 12.8 wheel index.
108-
109-
This happens when a package is pinned (or capped) below the first version that has CUDA 12.8 wheels.
110-
"""
111-
112-
if not packages:
113-
return False
114-
115-
for package in packages:
116-
name, spec = _get_requirement_name_and_specifier(package)
117-
if not name or name not in _CUDA_12_8_MIN_VERSIONS or not spec:
118-
continue
119-
120-
threshold = _CUDA_12_8_MIN_VERSIONS[name]
121-
for raw in spec.split(","):
122-
upper, inclusive = _upper_bound_for_specifier(raw)
123-
if not upper:
124-
continue
125-
126-
if inclusive:
127-
if _version_lt(upper, threshold):
128-
return True
129-
else:
130-
if _version_le(upper, threshold):
131-
return True
132-
133-
return False
134-
135-
136-
def _adjust_cuda_platform_for_requested_packages(torch_platform, packages):
137-
if torch_platform == _CUDA_12_8_PLATFORM and _packages_require_cuda_12_4(packages):
138-
return _CUDA_12_4_PLATFORM
139-
return torch_platform
140-
14116

14217
def get_install_commands(torch_platform, packages):
14318
"""
@@ -223,8 +98,7 @@ def install(packages=[], use_uv=False):
22398
"""
22499

225100
gpu_infos = get_gpus()
226-
torch_platform = get_torch_platform(gpu_infos)
227-
torch_platform = _adjust_cuda_platform_for_requested_packages(torch_platform, packages)
101+
torch_platform = get_torch_platform(gpu_infos, packages=packages)
228102
cmds = get_install_commands(torch_platform, packages)
229103
cmds = get_pip_commands(cmds, use_uv=use_uv)
230104
run_commands(cmds)

torchruntime/platform_detection.py

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,115 @@
22
import sys
33
import platform
44

5+
from packaging.requirements import Requirement
6+
from packaging.version import Version
7+
58
from .gpu_db import get_nvidia_arch, get_amd_gfx_info
69
from .consts import AMD, INTEL, NVIDIA, CONTACT_LINK
710

811
os_name = platform.system()
912
arch = platform.machine().lower()
1013
py_version = sys.version_info
1114

15+
_CUDA_12_8_PLATFORM = "cu128"
16+
_CUDA_12_4_PLATFORM = "cu124"
17+
_CUDA_12_8_MIN_VERSIONS = {
18+
"torch": Version("2.7.0"),
19+
"torchaudio": Version("2.7.0"),
20+
"torchvision": Version("0.22.0"),
21+
}
22+
23+
24+
def _parse_release_segments(text):
25+
segments = []
26+
for part in text.split("."):
27+
match = re.match(r"^(\d+)", part)
28+
if not match:
29+
break
30+
segments.append(int(match.group(1)))
31+
return segments
32+
33+
34+
def _upper_bound_for_specifier(specifier):
35+
operator = specifier.operator
36+
version = specifier.version
37+
38+
if operator == "<":
39+
return Version(version), False
40+
if operator == "<=":
41+
return Version(version), True
42+
if operator == "==":
43+
if "*" in version:
44+
prefix = version.split("*", 1)[0].rstrip(".")
45+
prefix_segments = _parse_release_segments(prefix)
46+
if not prefix_segments:
47+
return None, None
48+
prefix_segments[-1] += 1
49+
upper = Version(".".join(str(s) for s in prefix_segments))
50+
return upper, False
51+
return Version(version), True
52+
if operator == "~=":
53+
release_segments = _parse_release_segments(version)
54+
if len(release_segments) < 2:
55+
return None, None
56+
bump_index = len(release_segments) - 2
57+
upper_segments = release_segments[: bump_index + 1]
58+
upper_segments[bump_index] += 1
59+
upper = Version(".".join(str(s) for s in upper_segments))
60+
return upper, False
61+
62+
return None, None
63+
64+
65+
def _packages_require_cuda_12_4(packages):
66+
if not packages:
67+
return False
68+
69+
for package in packages:
70+
try:
71+
requirement = Requirement(package)
72+
except Exception:
73+
continue
74+
75+
name = requirement.name.lower().replace("_", "-")
76+
threshold = _CUDA_12_8_MIN_VERSIONS.get(name)
77+
if not threshold or not requirement.specifier:
78+
continue
79+
80+
threshold_allowed = None
81+
for specifier in requirement.specifier:
82+
upper, inclusive = _upper_bound_for_specifier(specifier)
83+
if not upper:
84+
continue
85+
86+
if upper < threshold:
87+
return True
88+
89+
if upper == threshold and not inclusive:
90+
return True
91+
92+
if upper == threshold and inclusive:
93+
if threshold_allowed is None:
94+
threshold_allowed = requirement.specifier.contains(threshold, prereleases=True)
95+
if not threshold_allowed:
96+
return True
97+
98+
return False
99+
100+
101+
def _adjust_cuda_platform_for_requested_packages(torch_platform, packages):
102+
if torch_platform == _CUDA_12_8_PLATFORM and _packages_require_cuda_12_4(packages):
103+
return _CUDA_12_4_PLATFORM
104+
return torch_platform
105+
12106

13-
def get_torch_platform(gpu_infos):
107+
def get_torch_platform(gpu_infos, packages=[]):
14108
"""
15109
Determine the appropriate PyTorch platform to use based on the system architecture, OS, and GPU information.
16110
17111
Args:
18112
gpu_infos (list of `torchruntime.device_db.GPU` instances)
113+
packages (list of str): Optional list of torch/torchvision/torchaudio requirement strings.
19114
20115
Returns:
21116
str: A string representing the platform to use. Possible values:
@@ -53,9 +148,11 @@ def get_torch_platform(gpu_infos):
53148
integrated_devices.append(device)
54149

55150
if discrete_devices:
56-
return _get_platform_for_discrete(discrete_devices)
151+
torch_platform = _get_platform_for_discrete(discrete_devices)
152+
return _adjust_cuda_platform_for_requested_packages(torch_platform, packages)
57153

58-
return _get_platform_for_integrated(integrated_devices)
154+
torch_platform = _get_platform_for_integrated(integrated_devices)
155+
return _adjust_cuda_platform_for_requested_packages(torch_platform, packages)
59156

60157

61158
def _get_platform_for_discrete(gpu_infos):

0 commit comments

Comments
 (0)