99from __future__ import annotations
1010
1111import argparse
12- import re
1312import sys
1413from collections import defaultdict
1514from pathlib import Path
1615
16+ from check_release_notes import parse_version_from_tag
17+
1718COMPONENT_TO_DISTRIBUTIONS : dict [str , set [str ]] = {
1819 "cuda-core" : {"cuda_core" },
1920 "cuda-bindings" : {"cuda_bindings" },
@@ -22,14 +23,16 @@ COMPONENT_TO_DISTRIBUTIONS: dict[str, set[str]] = {
2223 "all" : {"cuda_core" , "cuda_bindings" , "cuda_pathfinder" , "cuda_python" },
2324}
2425
25- TAG_PATTERNS = (
26- re .compile (r"^v(?P<version>\d+\.\d+\.\d+)" ),
27- re .compile (r"^cuda-core-v(?P<version>\d+\.\d+\.\d+)" ),
28- re .compile (r"^cuda-pathfinder-v(?P<version>\d+\.\d+\.\d+)" ),
29- )
26+ COMPONENT_TO_TAG_COMPONENTS : dict [str , tuple [str , ...]] = {
27+ "cuda-core" : ("cuda-core" ,),
28+ "cuda-bindings" : ("cuda-bindings" ,),
29+ "cuda-pathfinder" : ("cuda-pathfinder" ,),
30+ "cuda-python" : ("cuda-python" ,),
31+ "all" : ("cuda-core" , "cuda-bindings" , "cuda-pathfinder" , "cuda-python" ),
32+ }
3033
3134
32- def parse_args () -> argparse .Namespace :
35+ def parse_args (argv : list [ str ] | None = None ) -> argparse .Namespace :
3336 parser = argparse .ArgumentParser (
3437 description = (
3538 "Validate that wheel versions match the release tag. "
@@ -39,18 +42,21 @@ def parse_args() -> argparse.Namespace:
3942 parser .add_argument ("git_tag" , help = "Release git tag (for example: v13.0.0)" )
4043 parser .add_argument ("component" , choices = sorted (COMPONENT_TO_DISTRIBUTIONS .keys ()))
4144 parser .add_argument ("wheel_dir" , help = "Directory containing wheel files" )
42- return parser .parse_args ()
45+ return parser .parse_args (argv )
4346
4447
45- def version_from_tag (tag : str ) -> str :
46- for pattern in TAG_PATTERNS :
47- match = pattern .match (tag )
48- if match :
49- return match .group ("version" )
48+ def version_from_tag (tag : str , component : str ) -> str :
49+ versions = {
50+ version
51+ for tag_component in COMPONENT_TO_TAG_COMPONENTS [component ]
52+ if (version := parse_version_from_tag (tag , tag_component )) is not None
53+ }
54+ if len (versions ) == 1 :
55+ return versions .pop ()
5056 raise ValueError (
5157 "Unsupported git tag format "
52- f"{ tag !r} ; expected tags beginning with vX.Y.Z, cuda-core-vX.Y.Z, "
53- "or cuda-pathfinder-vX.Y.Z."
58+ f"{ tag !r} for component { component !r } ; expected vX.Y.Z, cuda-core-vX.Y.Z, "
59+ "or cuda-pathfinder-vX.Y.Z with a valid release version ."
5460 )
5561
5662
@@ -62,9 +68,14 @@ def parse_wheel_dist_and_version(path: Path) -> tuple[str, str]:
6268 return parts [0 ], parts [1 ]
6369
6470
65- def main () -> int :
66- args = parse_args ()
67- expected_version = version_from_tag (args .git_tag )
71+ def main (argv : list [str ] | None = None ) -> int :
72+ args = parse_args (argv )
73+ try :
74+ expected_version = version_from_tag (args .git_tag , args .component )
75+ except ValueError as exc :
76+ print (f"Error: { exc } " , file = sys .stderr )
77+ return 1
78+
6879 expected_distributions = COMPONENT_TO_DISTRIBUTIONS [args .component ]
6980 wheel_dir = Path (args .wheel_dir )
7081
0 commit comments