Skip to content

Commit d6e645d

Browse files
committed
ci: accept post-release wheel versions
1 parent 7ea8a46 commit d6e645d

2 files changed

Lines changed: 75 additions & 18 deletions

File tree

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import os
7+
import sys
8+
from importlib.machinery import SourceFileLoader
9+
from importlib.util import module_from_spec, spec_from_loader
10+
11+
import pytest
12+
13+
TOOLS_DIR = os.path.join(os.path.dirname(__file__), "..")
14+
SCRIPT_PATH = os.path.join(TOOLS_DIR, "validate-release-wheels")
15+
16+
sys.path.insert(0, TOOLS_DIR)
17+
18+
loader = SourceFileLoader("validate_release_wheels", SCRIPT_PATH)
19+
spec = spec_from_loader(loader.name, loader)
20+
validate_release_wheels = module_from_spec(spec)
21+
loader.exec_module(validate_release_wheels)
22+
23+
24+
class TestVersionFromTag:
25+
def test_plain_post_release_tag(self):
26+
assert validate_release_wheels.version_from_tag("v12.6.2.post1", "cuda-bindings") == "12.6.2.post1"
27+
28+
def test_component_prefix_mismatch(self):
29+
with pytest.raises(ValueError, match="cuda-pathfinder"):
30+
validate_release_wheels.version_from_tag("cuda-core-v0.7.0", "cuda-pathfinder")
31+
32+
33+
class TestMain:
34+
def test_accepts_post_release_wheel(self, tmp_path):
35+
(tmp_path / "cuda_bindings-12.6.2.post1-py3-none-any.whl").touch()
36+
37+
rc = validate_release_wheels.main(["v12.6.2.post1", "cuda-bindings", str(tmp_path)])
38+
39+
assert rc == 0
40+
41+
def test_rejects_wheel_version_mismatch(self, tmp_path):
42+
(tmp_path / "cuda_bindings-12.6.2.post1-py3-none-any.whl").touch()
43+
44+
rc = validate_release_wheels.main(["v12.6.2", "cuda-bindings", str(tmp_path)])
45+
46+
assert rc == 1

ci/tools/validate-release-wheels

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
from __future__ import annotations
1010

1111
import argparse
12-
import re
1312
import sys
1413
from collections import defaultdict
1514
from pathlib import Path
1615

16+
from check_release_notes import parse_version_from_tag
17+
1718
COMPONENT_TO_DISTRIBUTIONS: dict[str, set[str]] = {
1819
"cuda-core": {"cuda_core"},
1920
"cuda-bindings": {"cuda_bindings"},
@@ -22,14 +23,16 @@ COMPONENT_TO_DISTRIBUTIONS: dict[str, set[str]] = {
2223
"all": {"cuda_core", "cuda_bindings", "cuda_pathfinder", "cuda_python"},
2324
}
2425

25-
TAG_PATTERNS = (
26-
re.compile(r"^v(?P<version>\d+\.\d+\.\d+)"),
27-
re.compile(r"^cuda-core-v(?P<version>\d+\.\d+\.\d+)"),
28-
re.compile(r"^cuda-pathfinder-v(?P<version>\d+\.\d+\.\d+)"),
29-
)
26+
COMPONENT_TO_TAG_COMPONENTS: dict[str, tuple[str, ...]] = {
27+
"cuda-core": ("cuda-core",),
28+
"cuda-bindings": ("cuda-bindings",),
29+
"cuda-pathfinder": ("cuda-pathfinder",),
30+
"cuda-python": ("cuda-python",),
31+
"all": ("cuda-core", "cuda-bindings", "cuda-pathfinder", "cuda-python"),
32+
}
3033

3134

32-
def parse_args() -> argparse.Namespace:
35+
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
3336
parser = argparse.ArgumentParser(
3437
description=(
3538
"Validate that wheel versions match the release tag. "
@@ -39,18 +42,21 @@ def parse_args() -> argparse.Namespace:
3942
parser.add_argument("git_tag", help="Release git tag (for example: v13.0.0)")
4043
parser.add_argument("component", choices=sorted(COMPONENT_TO_DISTRIBUTIONS.keys()))
4144
parser.add_argument("wheel_dir", help="Directory containing wheel files")
42-
return parser.parse_args()
45+
return parser.parse_args(argv)
4346

4447

45-
def version_from_tag(tag: str) -> str:
46-
for pattern in TAG_PATTERNS:
47-
match = pattern.match(tag)
48-
if match:
49-
return match.group("version")
48+
def version_from_tag(tag: str, component: str) -> str:
49+
versions = {
50+
version
51+
for tag_component in COMPONENT_TO_TAG_COMPONENTS[component]
52+
if (version := parse_version_from_tag(tag, tag_component)) is not None
53+
}
54+
if len(versions) == 1:
55+
return versions.pop()
5056
raise ValueError(
5157
"Unsupported git tag format "
52-
f"{tag!r}; expected tags beginning with vX.Y.Z, cuda-core-vX.Y.Z, "
53-
"or cuda-pathfinder-vX.Y.Z."
58+
f"{tag!r} for component {component!r}; expected vX.Y.Z, cuda-core-vX.Y.Z, "
59+
"or cuda-pathfinder-vX.Y.Z with a valid release version."
5460
)
5561

5662

@@ -62,9 +68,14 @@ def parse_wheel_dist_and_version(path: Path) -> tuple[str, str]:
6268
return parts[0], parts[1]
6369

6470

65-
def main() -> int:
66-
args = parse_args()
67-
expected_version = version_from_tag(args.git_tag)
71+
def main(argv: list[str] | None = None) -> int:
72+
args = parse_args(argv)
73+
try:
74+
expected_version = version_from_tag(args.git_tag, args.component)
75+
except ValueError as exc:
76+
print(f"Error: {exc}", file=sys.stderr)
77+
return 1
78+
6879
expected_distributions = COMPONENT_TO_DISTRIBUTIONS[args.component]
6980
wheel_dir = Path(args.wheel_dir)
7081

0 commit comments

Comments
 (0)