Skip to content

Commit d8465a6

Browse files
committed
Improve flaky PR job reruns
1 parent 6617883 commit d8465a6

File tree

1 file changed

+91
-31
lines changed

1 file changed

+91
-31
lines changed

.github/scripts/rerun-flaky-pr-jobs.py

Lines changed: 91 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#!/usr/bin/env python3
22
"""Rerun up to two failed jobs for recent pull request CI runs.
33
4-
Scans recent failed runs of build-pull-request.yml, ignores the synthetic
5-
required-status-check job, and reruns eligible failed jobs once.
4+
Scans recent failed pull request workflow runs, ignores the synthetic
5+
required-status-check job, and reruns eligible failed jobs up to two times.
66
"""
77

88
from __future__ import annotations
@@ -16,7 +16,8 @@
1616
from pathlib import Path
1717

1818
LOOKBACK_HOURS = 2
19-
MAX_ELIGIBLE_FAILURES = 2
19+
MAX_FAILED_JOBS_PER_WORKFLOW_RUN = 2
20+
MAX_RERUN_ATTEMPTS = 2
2021

2122

2223
def main() -> None:
@@ -25,7 +26,7 @@ def main() -> None:
2526
lookback_cutoff = datetime.now(timezone.utc).timestamp() - LOOKBACK_HOURS * 60 * 60
2627

2728
recent_runs = list_recent_pull_request_runs(owner, repo, lookback_cutoff)
28-
latest_run_by_pull_request = latest_run_per_pull_request(recent_runs)
29+
latest_run_by_pull_request = latest_run_per_pull_request_workflow(owner, repo, recent_runs)
2930

3031
processed_runs: list[str] = []
3132
rerun_jobs: list[str] = []
@@ -34,8 +35,11 @@ def main() -> None:
3435
if run["status"] != "completed" or run.get("conclusion") != "failure":
3536
continue
3637

37-
if run["run_attempt"] > 1:
38-
processed_runs.append(f"Skipped {format_run(run)}: already rerun once.")
38+
rerun_attempts = run["run_attempt"] - 1
39+
if rerun_attempts > MAX_RERUN_ATTEMPTS:
40+
processed_runs.append(
41+
f"Skipped {format_run(run)}: already rerun {rerun_attempts} times."
42+
)
3943
continue
4044

4145
jobs = list_jobs_for_run(owner, repo, run["id"])
@@ -49,21 +53,20 @@ def main() -> None:
4953
processed_runs.append(f"Skipped {format_run(run)}: only synthetic jobs failed.")
5054
continue
5155

52-
if len(failed_real_jobs) > MAX_ELIGIBLE_FAILURES:
56+
if len(failed_real_jobs) > MAX_FAILED_JOBS_PER_WORKFLOW_RUN:
5357
processed_runs.append(
54-
f"Skipped {format_run(run)}: {len(failed_real_jobs)} failed jobs exceeded limit {MAX_ELIGIBLE_FAILURES}."
58+
f"Skipped {format_run(run)}: {len(failed_real_jobs)} failed jobs exceeded limit {MAX_FAILED_JOBS_PER_WORKFLOW_RUN}."
5559
)
5660
continue
5761

58-
for job in failed_real_jobs:
59-
try:
60-
github_request("POST", f"/repos/{owner}/{repo}/actions/jobs/{job['id']}/rerun")
61-
rerun_jobs.append(f"{format_run(run)}: reran {job['name']} ({job['id']}).")
62-
except subprocess.CalledProcessError as e:
63-
message = read_process_error(e)
64-
processed_runs.append(
65-
f"Failed rerun for {format_run(run)} job {job['name']} ({job['id']}): {message}"
66-
)
62+
try:
63+
github_request("POST", f"/repos/{owner}/{repo}/actions/runs/{run['id']}/rerun-failed-jobs")
64+
rerun_jobs.append(f"{format_run(run)}: reran failed jobs {format_jobs(failed_real_jobs)}.")
65+
except subprocess.CalledProcessError as e:
66+
message = read_process_error(e)
67+
processed_runs.append(
68+
f"Failed rerun for {format_run(run)} jobs {format_jobs(failed_real_jobs)}: {message}"
69+
)
6770

6871
if not processed_runs and not rerun_jobs:
6972
processed_runs.append("No recent failed PR runs matched the rerun policy.")
@@ -86,9 +89,9 @@ def list_recent_pull_request_runs(owner: str, repo: str, lookback_cutoff: float)
8689
page = 1
8790

8891
while True:
89-
response = github_request(
92+
response = github_request_object(
9093
"GET",
91-
f"/repos/{owner}/{repo}/actions/workflows/build-pull-request.yml/runs",
94+
f"/repos/{owner}/{repo}/actions/runs",
9295
{
9396
"event": "pull_request",
9497
"per_page": str(per_page),
@@ -114,19 +117,21 @@ def list_recent_pull_request_runs(owner: str, repo: str, lookback_cutoff: float)
114117
return [run for run in runs if parse_github_time(run["created_at"]).timestamp() >= lookback_cutoff]
115118

116119

117-
def latest_run_per_pull_request(runs: list[dict]) -> dict[int, dict]:
118-
latest_by_pr: dict[int, dict] = {}
120+
def latest_run_per_pull_request_workflow(owner: str, repo: str, runs: list[dict]) -> dict[tuple[int, int], dict]:
121+
latest_by_pr_workflow: dict[tuple[int, int], dict] = {}
122+
branch_cache: dict[tuple[str, str], int | None] = {}
119123

120124
for run in runs:
121-
pr_number = get_pr_number(run)
125+
pr_number = resolve_pr_number(owner, repo, run, branch_cache)
122126
if pr_number is None:
123127
continue
124128

125-
existing = latest_by_pr.get(pr_number)
129+
key = (pr_number, run["workflow_id"])
130+
existing = latest_by_pr_workflow.get(key)
126131
if existing is None or parse_github_time(run["created_at"]) > parse_github_time(existing["created_at"]):
127-
latest_by_pr[pr_number] = run
132+
latest_by_pr_workflow[key] = run
128133

129-
return latest_by_pr
134+
return latest_by_pr_workflow
130135

131136

132137
def list_jobs_for_run(owner: str, repo: str, run_id: int) -> list[dict]:
@@ -135,7 +140,7 @@ def list_jobs_for_run(owner: str, repo: str, run_id: int) -> list[dict]:
135140
page = 1
136141

137142
while True:
138-
response = github_request(
143+
response = github_request_object(
139144
"GET",
140145
f"/repos/{owner}/{repo}/actions/runs/{run_id}/jobs",
141146
{
@@ -153,7 +158,7 @@ def list_jobs_for_run(owner: str, repo: str, run_id: int) -> list[dict]:
153158
page += 1
154159

155160

156-
def github_request(method: str, path: str, query: dict[str, str] | None = None) -> dict:
161+
def github_request(method: str, path: str, query: dict[str, str] | None = None) -> dict | list[dict]:
157162
url = path.removeprefix("/")
158163
if query:
159164
url += "?" + urllib.parse.urlencode(query)
@@ -180,6 +185,20 @@ def github_request(method: str, path: str, query: dict[str, str] | None = None)
180185
return json.loads(result.stdout)
181186

182187

188+
def github_request_object(method: str, path: str, query: dict[str, str] | None = None) -> dict:
189+
response = github_request(method, path, query)
190+
if isinstance(response, list):
191+
raise TypeError(f"Expected object response for {path}")
192+
return response
193+
194+
195+
def github_request_list(method: str, path: str, query: dict[str, str] | None = None) -> list[dict]:
196+
response = github_request(method, path, query)
197+
if not isinstance(response, list):
198+
raise TypeError(f"Expected list response for {path}")
199+
return response
200+
201+
183202
def read_process_error(error: subprocess.CalledProcessError) -> str:
184203
return error.stderr.strip() or error.stdout.strip() or f"exit code {error.returncode}"
185204

@@ -194,19 +213,60 @@ def get_current_repository() -> str:
194213
return result.stdout.strip()
195214

196215

197-
def get_pr_number(run: dict) -> int | None:
216+
def resolve_pr_number(
217+
owner: str, repo: str, run: dict, branch_cache: dict[tuple[str, str], int | None]
218+
) -> int | None:
219+
cached_pr_number = run.get("resolved_pr_number")
220+
if isinstance(cached_pr_number, int):
221+
return cached_pr_number
222+
198223
pull_requests = run.get("pull_requests") or []
199-
if not pull_requests:
224+
if pull_requests:
225+
pr_number = pull_requests[0].get("number")
226+
run["resolved_pr_number"] = pr_number
227+
return pr_number
228+
229+
head_repository = run.get("head_repository") or {}
230+
head_owner = (head_repository.get("owner") or {}).get("login")
231+
head_branch = run.get("head_branch")
232+
if not head_owner or not head_branch:
233+
run["resolved_pr_number"] = None
200234
return None
201-
return pull_requests[0].get("number")
235+
236+
cache_key = (head_owner, head_branch)
237+
if cache_key not in branch_cache:
238+
pull_requests = github_request_list(
239+
"GET",
240+
f"/repos/{owner}/{repo}/pulls",
241+
{
242+
"state": "open",
243+
"head": f"{head_owner}:{head_branch}",
244+
"per_page": "5",
245+
},
246+
)
247+
pr_number = None
248+
for pull_request in pull_requests:
249+
if pull_request.get("head", {}).get("sha") == run.get("head_sha"):
250+
pr_number = pull_request["number"]
251+
break
252+
if pr_number is None and pull_requests:
253+
pr_number = pull_requests[0]["number"]
254+
branch_cache[cache_key] = pr_number
255+
256+
run["resolved_pr_number"] = branch_cache[cache_key]
257+
return branch_cache[cache_key]
202258

203259

204260
def format_run(run: dict) -> str:
205-
pr_number = get_pr_number(run)
261+
pr_number = run.get("resolved_pr_number")
206262
pr_label = f"PR #{pr_number}" if pr_number is not None else "PR unknown"
207263
return f"{pr_label}, run {run['id']}, attempt {run['run_attempt']}"
208264

209265

266+
def format_jobs(jobs: list[dict]) -> str:
267+
return ", ".join(f"{job['name']} ({job['id']})" for job in jobs)
268+
269+
210270
def parse_github_time(value: str) -> datetime:
211271
return datetime.fromisoformat(value.replace("Z", "+00:00"))
212272

0 commit comments

Comments
 (0)