Skip to content

Commit 5c3234d

Browse files
committed
Add --policy presets for preview/unsupported gating
1 parent 04eeda3 commit 5c3234d

6 files changed

Lines changed: 179 additions & 7 deletions

File tree

API.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import torchruntime
88
You can use the command line:
99
`python -m torchruntime install <optional list of package names and versions>`
1010

11+
CLI flags: `--policy <compat|stable|preview>`, `--preview`, `--no-unsupported`, `--uv`
12+
1113
Or you can use the library:
1214
```py
1315
torchruntime.install(["torch", "torchvision<0.20"])

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the approp
3434

3535
**Tip:** You can also add the `--uv` flag to install packages using [uv](https://docs.astral.sh/uv/) (instead of `pip`). For e.g. `python -m torchruntime install --uv`
3636

37+
Build-selection options:
38+
- `--policy <name>`: `compat` (default), `stable`, `preview` (or `nightly`)
39+
- Overrides: `--preview`, `--no-unsupported`
40+
3741
### Step 2. Configure torch
3842
This should be run inside your program, to initialize the required environment variables (if any) for the variant of torch being used.
3943

tests/test_policy_parsing.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import pytest
2+
from torchruntime.utils.args import parse_policy_args
3+
4+
5+
def test_default_policy():
6+
args = ["pkg1"]
7+
preview, unsupported, cleaned = parse_policy_args(args)
8+
assert preview is False
9+
assert unsupported is True
10+
assert cleaned == ["pkg1"]
11+
12+
13+
def test_stable_policy():
14+
args = ["--policy", "stable", "pkg1"]
15+
preview, unsupported, cleaned = parse_policy_args(args)
16+
assert preview is False
17+
assert unsupported is False
18+
assert cleaned == ["pkg1"]
19+
20+
21+
def test_nightly_policy():
22+
args = ["--policy", "nightly"]
23+
preview, unsupported, cleaned = parse_policy_args(args)
24+
assert preview is True
25+
assert unsupported is True
26+
assert cleaned == []
27+
28+
29+
def test_preview_policy_alias():
30+
args = ["--policy", "preview"]
31+
preview, unsupported, cleaned = parse_policy_args(args)
32+
assert preview is True
33+
assert unsupported is True
34+
assert cleaned == []
35+
36+
37+
def test_policy_equals_syntax():
38+
args = ["--policy=stable", "pkg1"]
39+
preview, unsupported, cleaned = parse_policy_args(args)
40+
assert preview is False
41+
assert unsupported is False
42+
assert cleaned == ["pkg1"]
43+
44+
45+
def test_policy_override_preview():
46+
# stable is p=F, u=F. --preview should make p=T
47+
args = ["--policy", "stable", "--preview"]
48+
preview, unsupported, cleaned = parse_policy_args(args)
49+
assert preview is True
50+
assert unsupported is False
51+
52+
def test_policy_override_unsupported():
53+
# nightly is p=T, u=T. --no-unsupported should make u=F
54+
args = ["--policy", "nightly", "--no-unsupported"]
55+
preview, unsupported, cleaned = parse_policy_args(args)
56+
assert preview is True
57+
assert unsupported is False
58+
59+
60+
def test_unknown_policy():
61+
args = ["--policy", "nonexistent"]
62+
with pytest.raises(ValueError, match="Unknown policy"):
63+
parse_policy_args(args)
64+
65+
66+
def test_missing_policy_arg():
67+
args = ["--policy"]
68+
with pytest.raises(ValueError, match="--policy requires an argument"):
69+
parse_policy_args(args)
70+
71+
72+
def test_mixed_args():
73+
args = ["torch", "--preview", "--policy", "stable", "--uv"]
74+
# stable: p=F, u=F
75+
# --preview: p=T
76+
# Result: p=T, u=F
77+
# cleaned: ["torch", "--uv"]
78+
preview, unsupported, cleaned = parse_policy_args(args)
79+
assert preview is True
80+
assert unsupported is False
81+
assert cleaned == ["torch", "--uv"]

torchruntime/__main__.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .installer import install
22
from .utils.torch_test import test
33
from .utils import info
4+
from .utils.args import parse_policy_args
45

56

67
def print_usage(entry_command: str):
@@ -18,6 +19,7 @@ def print_usage(entry_command: str):
1819
{entry_command} install --uv
1920
{entry_command} install --preview
2021
{entry_command} install --no-unsupported
22+
{entry_command} install --policy stable
2123
{entry_command} install torch==2.2.0 torchvision==0.17.0
2224
{entry_command} install --uv torch>=2.0.0 torchaudio
2325
{entry_command} install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0
@@ -39,6 +41,7 @@ def print_usage(entry_command: str):
3941
--uv Use uv instead of pip for installation
4042
--preview Allow preview builds (e.g. ROCm 6.4)
4143
--no-unsupported Forbid EOL/unsupported builds (e.g. DirectML / IPEX / Torch 1.x)
44+
--policy <name> Set configuration policy (stable, compat, preview|nightly). Default: compat
4245
4346
Version specification formats (follows pip format):
4447
package==2.1.0 Exact version
@@ -66,19 +69,26 @@ def main():
6669

6770
if command == "install":
6871
args = sys.argv[2:] if len(sys.argv) > 2 else []
69-
use_uv = "--uv" in args
70-
preview = "--preview" in args
71-
unsupported = "--no-unsupported" not in args
72-
# Remove flags from args to get package list
73-
package_versions = [arg for arg in args if arg not in ("--uv", "--preview", "--no-unsupported")] if args else None
72+
try:
73+
preview, unsupported, cleaned_args = parse_policy_args(args)
74+
except ValueError as e:
75+
print(f"Error: {e}")
76+
return
77+
78+
use_uv = "--uv" in cleaned_args
79+
# Remove --uv from package list
80+
package_versions = [arg for arg in cleaned_args if arg != "--uv"]
7481
install(package_versions, use_uv=use_uv, preview=preview, unsupported=unsupported)
7582
elif command == "test":
7683
subcommand = sys.argv[2] if len(sys.argv) > 2 else "all"
7784
test(subcommand)
7885
elif command == "info":
7986
args = sys.argv[2:] if len(sys.argv) > 2 else []
80-
preview = "--preview" in args
81-
unsupported = "--no-unsupported" not in args
87+
try:
88+
preview, unsupported, _ = parse_policy_args(args)
89+
except ValueError as e:
90+
print(f"Error: {e}")
91+
return
8292
from .utils import info
8393
info(preview=preview, unsupported=unsupported)
8494
else:

torchruntime/consts.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,10 @@
33
AMD = "1002"
44
NVIDIA = "10de"
55
INTEL = "8086"
6+
7+
POLICIES = {
8+
"stable": (False, False), # preview=False, unsupported=False
9+
"compat": (False, True), # preview=False, unsupported=True (Default)
10+
"preview": (True, True), # preview=True, unsupported=True
11+
"nightly": (True, True), # alias for preview
12+
}

torchruntime/utils/args.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from ..consts import POLICIES
2+
3+
4+
def parse_policy_args(args):
5+
"""
6+
Parses arguments for policy and flags.
7+
Returns (preview, unsupported, cleaned_args)
8+
9+
Supports both `--policy NAME` and `--policy=NAME`.
10+
11+
Logic:
12+
1. Determine base configuration from the LAST provided --policy argument (or default 'compat').
13+
2. Apply explicit flags (--preview, --no-unsupported) which ALWAYS override the policy.
14+
3. Remove policy and flags from args to produce cleaned_args.
15+
"""
16+
# Default: compat
17+
preview, unsupported = POLICIES["compat"]
18+
19+
# 1. Scan for the last policy to set the baseline
20+
last_policy_name = None
21+
i = 0
22+
while i < len(args):
23+
arg = args[i]
24+
if arg == "--policy":
25+
if i + 1 < len(args):
26+
last_policy_name = args[i+1]
27+
i += 2
28+
else:
29+
# We will catch this error in the second pass or we can raise it now.
30+
# Raising now is safer.
31+
raise ValueError("--policy requires an argument")
32+
elif arg.startswith("--policy="):
33+
last_policy_name = arg.split("=", 1)[1]
34+
if not last_policy_name:
35+
raise ValueError("--policy requires an argument")
36+
i += 1
37+
else:
38+
i += 1
39+
40+
if last_policy_name:
41+
if last_policy_name in POLICIES:
42+
preview, unsupported = POLICIES[last_policy_name]
43+
else:
44+
raise ValueError(f"Unknown policy: {last_policy_name}")
45+
46+
# 2. Apply flags and build cleaned_args
47+
cleaned_args = []
48+
i = 0
49+
while i < len(args):
50+
arg = args[i]
51+
if arg == "--policy":
52+
# Skip policy and its value (already processed)
53+
i += 2
54+
continue
55+
elif arg.startswith("--policy="):
56+
i += 1
57+
continue
58+
elif arg == "--preview":
59+
preview = True
60+
i += 1
61+
elif arg == "--no-unsupported":
62+
unsupported = False
63+
i += 1
64+
else:
65+
cleaned_args.append(arg)
66+
i += 1
67+
68+
return preview, unsupported, cleaned_args

0 commit comments

Comments
 (0)