Skip to content

Commit bd5a9a9

Browse files
committed
ci: improve pr-labels job quality
1 parent 8b755c4 commit bd5a9a9

3 files changed

Lines changed: 390 additions & 266 deletions

File tree

.github/scripts/pr_labels.py

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
#!/usr/bin/env python3
2+
from __future__ import annotations
3+
4+
import argparse
5+
import json
6+
import os
7+
import pathlib
8+
import subprocess
9+
import sys
10+
from collections.abc import Sequence
11+
from typing import Any, Final
12+
13+
ALLOWED_LABELS: Final[set[str]] = {
14+
"documentation",
15+
"project",
16+
"bug",
17+
"enhancement",
18+
"dependencies",
19+
"feature:chat-completions",
20+
"feature:core",
21+
"feature:lite-llm",
22+
"feature:mcp",
23+
"feature:realtime",
24+
"feature:sessions",
25+
"feature:tracing",
26+
"feature:voice",
27+
}
28+
29+
SOURCE_FEATURE_PREFIXES: Final[dict[str, tuple[str, ...]]] = {
30+
"feature:realtime": ("src/agents/realtime/",),
31+
"feature:voice": ("src/agents/voice/",),
32+
"feature:mcp": ("src/agents/mcp/",),
33+
"feature:tracing": ("src/agents/tracing/",),
34+
"feature:sessions": ("src/agents/memory/",),
35+
}
36+
37+
CORE_EXCLUDED_PREFIXES: Final[tuple[str, ...]] = (
38+
"src/agents/realtime/",
39+
"src/agents/voice/",
40+
"src/agents/mcp/",
41+
"src/agents/tracing/",
42+
"src/agents/memory/",
43+
"src/agents/extensions/",
44+
"src/agents/models/",
45+
)
46+
47+
48+
def read_file_at(commit: str | None, path: str) -> str | None:
49+
if not commit:
50+
return None
51+
try:
52+
return subprocess.check_output(["git", "show", f"{commit}:{path}"], text=True)
53+
except subprocess.CalledProcessError:
54+
return None
55+
56+
57+
def dependency_lines_for_pyproject(text: str) -> set[int]:
58+
dependency_lines: set[int] = set()
59+
current_section: str | None = None
60+
in_project_dependencies = False
61+
62+
for line_number, raw_line in enumerate(text.splitlines(), start=1):
63+
stripped = raw_line.strip()
64+
if stripped.startswith("[") and stripped.endswith("]"):
65+
if stripped.startswith("[[") and stripped.endswith("]]"):
66+
current_section = stripped[2:-2].strip()
67+
else:
68+
current_section = stripped[1:-1].strip()
69+
in_project_dependencies = False
70+
if current_section in ("project.optional-dependencies", "dependency-groups"):
71+
dependency_lines.add(line_number)
72+
continue
73+
74+
if current_section in ("project.optional-dependencies", "dependency-groups"):
75+
dependency_lines.add(line_number)
76+
continue
77+
78+
if current_section != "project":
79+
continue
80+
81+
if in_project_dependencies:
82+
dependency_lines.add(line_number)
83+
if "]" in stripped:
84+
in_project_dependencies = False
85+
continue
86+
87+
if stripped.startswith("dependencies") and "=" in stripped:
88+
dependency_lines.add(line_number)
89+
if "[" in stripped and "]" not in stripped:
90+
in_project_dependencies = True
91+
92+
return dependency_lines
93+
94+
95+
def pyproject_dependency_changed(
96+
diff_text: str,
97+
*,
98+
base_sha: str | None,
99+
head_sha: str | None,
100+
) -> bool:
101+
import re
102+
103+
base_text = read_file_at(base_sha, "pyproject.toml")
104+
head_text = read_file_at(head_sha, "pyproject.toml")
105+
if base_text is None and head_text is None:
106+
return False
107+
108+
base_dependency_lines = dependency_lines_for_pyproject(base_text) if base_text else set()
109+
head_dependency_lines = dependency_lines_for_pyproject(head_text) if head_text else set()
110+
111+
in_pyproject = False
112+
base_line: int | None = None
113+
head_line: int | None = None
114+
hunk_re = re.compile(r"@@ -(\d+)(?:,\d+)? \+(\d+)(?:,\d+)? @@")
115+
116+
for line in diff_text.splitlines():
117+
if line.startswith("+++ b/"):
118+
current_file = line[len("+++ b/") :].strip()
119+
in_pyproject = current_file == "pyproject.toml"
120+
base_line = None
121+
head_line = None
122+
continue
123+
124+
if not in_pyproject:
125+
continue
126+
127+
if line.startswith("@@ "):
128+
match = hunk_re.match(line)
129+
if not match:
130+
continue
131+
base_line = int(match.group(1))
132+
head_line = int(match.group(2))
133+
continue
134+
135+
if base_line is None or head_line is None:
136+
continue
137+
138+
if line.startswith(" "):
139+
base_line += 1
140+
head_line += 1
141+
continue
142+
143+
if line.startswith("-"):
144+
if base_line in base_dependency_lines:
145+
return True
146+
base_line += 1
147+
continue
148+
149+
if line.startswith("+"):
150+
if head_line in head_dependency_lines:
151+
return True
152+
head_line += 1
153+
continue
154+
155+
return False
156+
157+
158+
def infer_specific_feature_labels(changed_files: Sequence[str]) -> set[str]:
159+
source_files = [path for path in changed_files if path.startswith("src/")]
160+
labels: set[str] = set()
161+
162+
for label, prefixes in SOURCE_FEATURE_PREFIXES.items():
163+
if any(path.startswith(prefix) for path in source_files for prefix in prefixes):
164+
labels.add(label)
165+
166+
if any(
167+
path.startswith(("src/agents/models/", "src/agents/extensions/models/"))
168+
and ("chatcmpl" in path or "chatcompletions" in path)
169+
for path in source_files
170+
):
171+
labels.add("feature:chat-completions")
172+
173+
if any(
174+
path.startswith(("src/agents/models/", "src/agents/extensions/models/"))
175+
and "litellm" in path
176+
for path in source_files
177+
):
178+
labels.add("feature:lite-llm")
179+
180+
return labels
181+
182+
183+
def infer_feature_labels(changed_files: Sequence[str]) -> set[str]:
184+
source_files = [path for path in changed_files if path.startswith("src/")]
185+
specific_labels = infer_specific_feature_labels(source_files)
186+
core_touched = any(
187+
path.startswith("src/agents/") and not path.startswith(CORE_EXCLUDED_PREFIXES)
188+
for path in source_files
189+
)
190+
191+
if core_touched and len(specific_labels) != 1:
192+
return {"feature:core"}
193+
return specific_labels
194+
195+
196+
def infer_fallback_labels(changed_files: Sequence[str]) -> set[str]:
197+
return infer_feature_labels(changed_files)
198+
199+
200+
def load_json(path: pathlib.Path) -> Any:
201+
return json.loads(path.read_text())
202+
203+
204+
def load_codex_labels(path: pathlib.Path) -> list[str]:
205+
if not path.exists():
206+
return []
207+
208+
raw = path.read_text().strip()
209+
if not raw:
210+
return []
211+
212+
try:
213+
payload = load_json(path)
214+
except json.JSONDecodeError:
215+
return []
216+
217+
if not isinstance(payload, dict):
218+
return []
219+
220+
labels = payload.get("labels", [])
221+
if not isinstance(labels, list):
222+
return []
223+
224+
return [label for label in labels if isinstance(label, str)]
225+
226+
227+
def fetch_existing_labels(pr_number: str) -> set[str]:
228+
result = subprocess.check_output(
229+
["gh", "pr", "view", pr_number, "--json", "labels", "--jq", ".labels[].name"],
230+
text=True,
231+
).strip()
232+
return {label for label in result.splitlines() if label}
233+
234+
235+
def compute_desired_labels(
236+
*,
237+
changed_files: Sequence[str],
238+
diff_text: str,
239+
codex_ran: bool,
240+
codex_labels: Sequence[str],
241+
base_sha: str | None,
242+
head_sha: str | None,
243+
) -> set[str]:
244+
desired: set[str] = set()
245+
246+
if "pyproject.toml" in changed_files:
247+
desired.add("project")
248+
249+
if any(path.startswith("docs/") for path in changed_files):
250+
desired.add("documentation")
251+
252+
dependencies_allowed = "uv.lock" in changed_files
253+
if "pyproject.toml" in changed_files and pyproject_dependency_changed(
254+
diff_text, base_sha=base_sha, head_sha=head_sha
255+
):
256+
dependencies_allowed = True
257+
if dependencies_allowed:
258+
desired.add("dependencies")
259+
260+
if codex_ran:
261+
for label in codex_labels:
262+
if label == "dependencies" and not dependencies_allowed:
263+
continue
264+
if label in ALLOWED_LABELS:
265+
desired.add(label)
266+
return desired
267+
268+
desired.update(infer_fallback_labels(changed_files))
269+
return desired
270+
271+
272+
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
273+
parser = argparse.ArgumentParser()
274+
parser.add_argument("--pr-number", default=os.environ.get("PR_NUMBER", ""))
275+
parser.add_argument("--base-sha", default=os.environ.get("PR_BASE_SHA", ""))
276+
parser.add_argument("--head-sha", default=os.environ.get("PR_HEAD_SHA", ""))
277+
parser.add_argument(
278+
"--codex-output-path",
279+
default=os.environ.get("CODEX_OUTPUT_PATH", ".tmp/codex/outputs/pr-labels.json"),
280+
)
281+
parser.add_argument("--codex-conclusion", default=os.environ.get("CODEX_CONCLUSION", ""))
282+
parser.add_argument(
283+
"--changed-files-path",
284+
default=os.environ.get("CHANGED_FILES_PATH", ".tmp/pr-labels/changed-files.txt"),
285+
)
286+
parser.add_argument(
287+
"--changes-diff-path",
288+
default=os.environ.get("CHANGES_DIFF_PATH", ".tmp/pr-labels/changes.diff"),
289+
)
290+
return parser.parse_args(argv)
291+
292+
293+
def main(argv: Sequence[str] | None = None) -> int:
294+
args = parse_args(argv)
295+
if not args.pr_number:
296+
raise SystemExit("Missing PR number.")
297+
298+
changed_files_path = pathlib.Path(args.changed_files_path)
299+
changes_diff_path = pathlib.Path(args.changes_diff_path)
300+
codex_output_path = pathlib.Path(args.codex_output_path)
301+
codex_conclusion = args.codex_conclusion.strip().lower()
302+
codex_ran = bool(codex_conclusion) and codex_conclusion != "skipped"
303+
304+
changed_files = []
305+
if changed_files_path.exists():
306+
changed_files = [
307+
line.strip() for line in changed_files_path.read_text().splitlines() if line.strip()
308+
]
309+
310+
diff_text = changes_diff_path.read_text() if changes_diff_path.exists() else ""
311+
desired = compute_desired_labels(
312+
changed_files=changed_files,
313+
diff_text=diff_text,
314+
codex_ran=codex_ran,
315+
codex_labels=load_codex_labels(codex_output_path),
316+
base_sha=args.base_sha or None,
317+
head_sha=args.head_sha or None,
318+
)
319+
320+
existing = fetch_existing_labels(args.pr_number)
321+
to_add = sorted(desired - existing)
322+
to_remove = sorted((existing & ALLOWED_LABELS) - desired)
323+
324+
if not to_add and not to_remove:
325+
print("Labels already up to date.")
326+
return 0
327+
328+
cmd = ["gh", "pr", "edit", args.pr_number]
329+
if to_add:
330+
cmd += ["--add-label", ",".join(to_add)]
331+
if to_remove:
332+
cmd += ["--remove-label", ",".join(to_remove)]
333+
subprocess.check_call(cmd)
334+
return 0
335+
336+
337+
if __name__ == "__main__":
338+
sys.exit(main())

0 commit comments

Comments
 (0)