Skip to content

Commit e992427

Browse files
committed
feat: add a script to check for kernel public api changes
1 parent 303e2f7 commit e992427

2 files changed

Lines changed: 220 additions & 0 deletions

File tree

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Runs scripts/check_public_api.py on pull requests and fails the job when the
# script reports a public API change in any touched kernel.
name: Check Public API

on:
  pull_request:
    # Changes limited to these paths do not trigger the check.
    paths-ignore:
      - "**/*.md"
      - ".github/**"
      - "scripts/**"
      - "*.json"
      - "*.nix"
      - "*.lock"
      - ".gitignore"
      - ".gitmodules"

# A newer push to the same PR branch cancels the run that is still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

permissions:
  contents: read

jobs:
  check-public-api:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          # Full history: the script runs `git merge-base` and checks the base
          # commit out into a temporary worktree.
          fetch-depth: 0
          # Everything after checkout is a local git operation, so do not
          # leave the token in .git/config for later steps to read.
          persist-credentials: false

      - name: Check public API
        env:
          # Passed through the environment rather than interpolated into the
          # shell command line.
          BASE_SHA: ${{ github.event.pull_request.base.sha }}
        run: |
          python3 scripts/check_public_api.py --base-ref "$BASE_SHA"

scripts/check_public_api.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import argparse
2+
import ast
3+
import json
4+
import re
5+
import subprocess
6+
import sys
7+
import tempfile
8+
from pathlib import Path
9+
10+
# Private by name but still part of the public contract.
PUBLIC_DUNDERS = {"__init__", "__call__"}
# Matches `.def(` followed by optional whitespace and a double-quoted string
# literal; group 1 is the literal's content (backslash escapes are kept as-is).
DEF_RE = re.compile(r"\.def\(\s*\"((?:[^\"\\]|\\.)*)\"")
13+
14+
15+
def public(name: str) -> bool:
    """Return True unless *name* begins with an underscore."""
    return name[:1] != "_"
17+
18+
19+
def signature(node) -> str:
    """Render a function node's parameter list and return annotation.

    Args:
        node: An ``ast.FunctionDef`` or ``ast.AsyncFunctionDef`` node.

    Returns:
        ``"(<params>)"``, followed by ``" -> <annotation>"`` when the function
        declares a return annotation.
    """
    params = ast.unparse(node.args)
    if node.returns is None:
        return f"({params})"
    return f"({params}) -> {ast.unparse(node.returns)}"
23+
24+
25+
def class_api(node: ast.ClassDef, rel: str) -> dict:
    """Describe the public surface of one class.

    Args:
        node: The class definition to inspect.
        rel: Label of the module containing the class; used as the key prefix.

    Returns:
        A dict with a ``"<rel> <Class>(bases)"`` entry, one
        ``"<rel> <Class>.<method>"`` entry per public method (plus the dunders
        listed in ``PUBLIC_DUNDERS``) and, when the class body assigns any
        public names, a ``"<rel> <Class>(attrs)"`` entry listing them sorted.
    """
    prefix = f"{rel} {node.name}"
    base_names = [ast.unparse(base) for base in node.bases]
    api = {f"{prefix}(bases)": ", ".join(base_names)}

    attr_names = set()
    for member in node.body:
        if isinstance(member, (ast.FunctionDef, ast.AsyncFunctionDef)):
            if public(member.name) or member.name in PUBLIC_DUNDERS:
                api[f"{prefix}.{member.name}"] = signature(member)
            continue

        # Collect assignment targets; only plain names count as attributes.
        if isinstance(member, ast.AnnAssign):
            targets = [member.target]
        elif isinstance(member, ast.Assign):
            targets = member.targets
        else:
            continue
        for target in targets:
            if isinstance(target, ast.Name) and public(target.id):
                attr_names.add(target.id)

    if attr_names:
        api[f"{prefix}(attrs)"] = ", ".join(sorted(attr_names))
    return api
42+
43+
44+
def python_api(path: Path, rel: str) -> dict:
    """Collect the public top-level API of one Python module.

    Args:
        path: The module file to parse.
        rel: Label for the module; used as the prefix of every key.

    Returns:
        A dict with one entry per public top-level function (its signature),
        the entries produced by ``class_api`` for each public class, and a
        ``"<rel> __all__"`` entry when ``__all__`` is assigned a literal that
        can be evaluated and sorted.

    Raises:
        SyntaxError: If the file is not valid Python.
    """
    # Parse from bytes so Python's own source-encoding rules apply (UTF-8 by
    # default, coding cookie, BOM) instead of the locale's default encoding.
    tree = ast.parse(path.read_bytes(), filename=str(path))
    out = {}
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            if public(node.name):
                out[f"{rel} {node.name}"] = signature(node)
        elif isinstance(node, ast.ClassDef) and public(node.name):
            out.update(class_api(node, rel))
        elif isinstance(node, ast.Assign):
            if any(isinstance(t, ast.Name) and t.id == "__all__" for t in node.targets):
                try:
                    out[f"{rel} __all__"] = repr(sorted(ast.literal_eval(node.value)))
                except (ValueError, TypeError):
                    # __all__ is computed or not sortable; it cannot be
                    # compared statically, so it is left out.
                    pass
    return out
59+
60+
61+
def extract_api(kernel_root: Path) -> dict:
    """Extract the public API of the kernel rooted at *kernel_root*.

    Scans ``<kernel_root>/torch-ext`` for ``.def("...")`` registrations in
    ``*.cpp`` / ``*.cc`` files and for the public API of its Python modules.

    Args:
        kernel_root: Directory of a single kernel.

    Returns:
        A dict mapping ``"op <name>"`` to the whitespace-normalized schema
        string for each registration, merged with the ``python_api`` entries
        of every non-private Python module. Empty when there is no
        ``torch-ext`` directory.
    """
    ext = kernel_root / "torch-ext"
    if not ext.is_dir():
        return {}

    # A further string literal separated from the previous one by whitespace
    # only. C++ concatenates adjacent literals, and long schemas are commonly
    # split that way; capturing just the first piece would truncate the schema
    # and hide changes to the rest of it.
    next_literal = re.compile(r'\s*"((?:[^"\\]|\\.)*)"')

    api = {}
    for src in sorted([*ext.rglob("*.cpp"), *ext.rglob("*.cc")]):
        # Decode explicitly as UTF-8 rather than with the locale default;
        # undecodable bytes are replaced so a stray byte cannot abort the scan.
        text = src.read_text(encoding="utf-8", errors="replace")
        for match in DEF_RE.finditer(text):
            pieces = [match.group(1)]
            follow = next_literal.match(text, match.end())
            while follow:
                pieces.append(follow.group(1))
                follow = next_literal.match(text, follow.end())
            schema = " ".join("".join(pieces).split())
            if schema:
                api[f"op {schema.split('(', 1)[0].strip()}"] = schema
    for py in sorted(ext.rglob("*.py")):
        # Skip private modules (e.g. _ops.py) but keep dunders (__init__.py).
        if py.stem.startswith("_") and not py.stem.startswith("__"):
            continue
        api.update(python_api(py, py.relative_to(ext).as_posix()))
    return api
77+
78+
79+
def git(*args: str) -> str:
    """Run ``git`` with *args* and return its standard output, stripped.

    Raises:
        subprocess.CalledProcessError: If git exits with a non-zero status.
    """
    command = ["git", *args]
    completed = subprocess.run(command, check=True, capture_output=True, text=True)
    return completed.stdout.strip()
83+
84+
85+
def changed_kernels(ref: str) -> list:
    """List kernels with files that differ from *ref*.

    A kernel is a top-level directory (relative to the current working
    directory) that contains a ``build.toml`` file.

    Args:
        ref: Git revision passed to ``git diff --name-only``.

    Returns:
        Sorted, de-duplicated kernel directory names.
    """
    kernels = set()
    for changed_path in git("diff", "--name-only", ref).splitlines():
        top, sep, _rest = changed_path.partition("/")
        if not sep:
            # A file directly at the top level belongs to no kernel.
            continue
        if (Path(top) / "build.toml").is_file():
            kernels.add(top)
    return sorted(kernels)
92+
93+
94+
def api_at_ref(kernel: str, ref: str) -> dict:
    """Extract the API of *kernel* as it exists at git revision *ref*.

    The revision is checked out into a detached worktree inside a temporary
    directory. The worktree is removed afterwards whether or not extraction
    succeeded; the outcome of that removal is not checked.

    Args:
        kernel: Kernel directory name, relative to the repository root.
        ref: Git revision to inspect.

    Returns:
        The result of ``extract_api`` for the kernel at *ref*.

    Raises:
        subprocess.CalledProcessError: If the worktree cannot be created.
    """
    with tempfile.TemporaryDirectory(prefix="api-base-") as checkout:
        remove_worktree = ["git", "worktree", "remove", "--force", checkout]
        try:
            git("worktree", "add", "--detach", "--quiet", checkout, ref)
            snapshot = extract_api(Path(checkout) / kernel)
        finally:
            subprocess.run(remove_worktree, capture_output=True, text=True)
        return snapshot
105+
106+
107+
def report(kernel: str, base: dict, head: dict) -> bool:
    """Print the API differences between *base* and *head* for one kernel.

    Args:
        kernel: Kernel name used in the printed messages.
        base: API mapping of the baseline revision.
        head: API mapping of the working tree.

    Returns:
        True when any entry was removed, changed or added; False otherwise.
    """
    removed_keys = sorted(base.keys() - head.keys())
    added_keys = sorted(head.keys() - base.keys())
    changed_keys = sorted(
        key for key in base.keys() & head.keys() if base[key] != head[key]
    )

    if not removed_keys and not changed_keys and not added_keys:
        print(f"✅ {kernel}: public API unchanged.")
        return False

    print(f"\n🔎 {kernel}: public API changes detected.")
    for key in removed_keys:
        print(f" ❌ removed: {key} = {base[key]}")
    for key in changed_keys:
        before, after = base[key], head[key]
        print(f" ❌ changed: {key}\n before: {before}\n after: {after}")
    for key in added_keys:
        print(f" ➕ added: {key} = {head[key]}")
    return True
124+
125+
126+
def main() -> int:
    """Command-line entry point.

    With ``--dump``, prints the working-tree API of each named kernel as JSON.
    Otherwise compares each kernel's API in the working tree against the
    merge base of ``--base-ref`` and ``HEAD`` and reports the differences.

    Returns:
        0 when nothing changed (or after ``--dump``), 1 when any checked
        kernel's public API differs from the baseline.
    """
    parser = argparse.ArgumentParser(
        # Fall back to a fixed summary when the module has no docstring, so
        # --help never comes out without a description.
        description=__doc__
        or "Report public API changes in kernels relative to a base ref.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "kernels",
        nargs="*",
        help="Kernels to check (default: those changed vs --base-ref).",
    )
    parser.add_argument(
        "--base-ref",
        default="origin/main",
        help="Target branch to compare against (default: origin/main).",
    )
    parser.add_argument(
        "--dump",
        action="store_true",
        help="Print the working-tree API for each kernel and exit.",
    )
    args = parser.parse_args()

    if args.dump:
        if not args.kernels:
            parser.error("--dump requires explicit kernel name(s).")
        for kernel in args.kernels:
            print(f"# {kernel}")
            print(json.dumps(extract_api(Path(kernel)), indent=2, sort_keys=True))
        return 0

    # Compare against where the branch diverged, so changes others made on the
    # target branch aren't mistaken for changes made here.
    try:
        baseline = git("merge-base", args.base_ref, "HEAD")
    except subprocess.CalledProcessError:
        # No merge base could be determined; compare against the ref itself.
        baseline = args.base_ref

    kernels = args.kernels or changed_kernels(baseline)
    if not kernels:
        print("No kernel sources changed; nothing to check.")
        return 0

    changed = False
    for kernel in kernels:
        if not Path(kernel).is_dir():
            print(f"⚠️ {kernel}: directory not found, skipping.", file=sys.stderr)
            continue
        # Every kernel is reported, even after a change has been found.
        if report(kernel, api_at_ref(kernel, baseline), extract_api(Path(kernel))):
            changed = True

    if changed:
        print(
            "\n💥 Public API changed. If intentional, note it in the PR "
            "description so reviewers can approve the change.",
            file=sys.stderr,
        )
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())

0 commit comments

Comments
 (0)