1515import subprocess
1616import sys
1717import tempfile
18+ from collections .abc import Callable
1819from dataclasses import dataclass , field
1920from pathlib import Path
2021from typing import Any
@@ -28,23 +29,6 @@ class WorkflowError(RuntimeError):
2829 """A user-facing workflow failure."""
2930
3031
31- class CommandError (WorkflowError ):
32- def __init__ (self , cmd : list [str ], returncode : int , stdout : str , stderr : str ):
33- super ().__init__ (f"command failed ({ returncode } ): { format_cmd (cmd )} " )
34- self .cmd = cmd
35- self .returncode = returncode
36- self .stdout = stdout
37- self .stderr = stderr
38-
39-
40- @dataclass
41- class RunResult :
42- cmd : list [str ]
43- returncode : int
44- stdout : str
45- stderr : str
46-
47-
4832@dataclass
4933class Summary :
5034 pr : int
@@ -58,12 +42,8 @@ class Summary:
5842 push_result : str | None = None
5943 failures : list [str ] = field (default_factory = list )
6044 notes : list [str ] = field (default_factory = list )
61- commands : list [str ] = field (default_factory = list )
6245 temp_dir : str | None = None
6346
64- def add_command (self , cmd : list [str ]) -> None :
65- self .commands .append (format_cmd (cmd ))
66-
6747 def print_text (self ) -> None :
6848 print (f"PR: #{ self .pr } " )
6949 if self .original_branch :
@@ -97,9 +77,6 @@ def print_text(self) -> None:
9777 for note in self .notes :
9878 print (f"- { note } " )
9979
100- def print_json (self ) -> None :
101- print (json .dumps (self .__dict__ , indent = 2 , sort_keys = True ))
102-
10380
10481def format_cmd (cmd : list [str ]) -> str :
10582 return " " .join (shlex .quote (part ) for part in cmd )
@@ -116,30 +93,25 @@ def progress(message: str) -> None:
11693 print (f"[pr-triage] { message } " , flush = True )
11794
11895
119- def run (cmd : list [str ], summary : Summary | None = None , check : bool = True ) -> RunResult :
96+ def run (cmd : list [str ], summary : Summary | None = None , check : bool = True ) -> subprocess . CompletedProcess [ str ] :
12097 if summary is not None :
121- summary .add_command (cmd )
12298 progress (f"Running: { format_cmd (cmd )} " )
123- proc = subprocess .run (
99+ return subprocess .run (
124100 cmd ,
125101 cwd = REPO_ROOT ,
126102 capture_output = True ,
127103 text = True ,
128104 encoding = "utf-8" ,
129105 errors = "replace" ,
130- check = False ,
106+ check = check ,
131107 )
132- result = RunResult (cmd , proc .returncode , proc .stdout , proc .stderr )
133- if check and proc .returncode != 0 :
134- raise CommandError (cmd , proc .returncode , proc .stdout , proc .stderr )
135- return result
136108
137109
138- def git (args : list [str ], summary : Summary | None = None , check : bool = True ) -> RunResult :
110+ def git (args : list [str ], summary : Summary | None = None , check : bool = True ) -> subprocess . CompletedProcess [ str ] :
139111 return run (["git" , * args ], summary , check )
140112
141113
142- def gh (args : list [str ], summary : Summary | None = None , check : bool = True ) -> RunResult :
114+ def gh (args : list [str ], summary : Summary | None = None , check : bool = True ) -> subprocess . CompletedProcess [ str ] :
143115 return run (["gh" , * args ], summary , check )
144116
145117
@@ -245,6 +217,31 @@ def checkout_pr(pr: int, summary: Summary) -> dict[str, Any]:
245217 return ensure_pr_push_allowed (pr , summary )
246218
247219
220+ def checkout_pr_no_push_check (pr : int , summary : Summary ) -> None :
221+ progress (f"Checking out PR #{ pr } " )
222+ gh (["pr" , "checkout" , str (pr )], summary )
223+ summary .pr_branch = current_branch (summary )
224+
225+
226+ def run_pr_workflow (pr : int , body : Callable [[Summary ], int ], * , push_required : bool = True ) -> int :
227+ summary = Summary (pr = pr )
228+ try :
229+ require_clean_worktree (summary )
230+ summary .original_branch = current_branch (summary )
231+ if push_required :
232+ checkout_pr (pr , summary )
233+ else :
234+ checkout_pr_no_push_check (pr , summary )
235+ return body (summary )
236+ except Exception as e :
237+ summary .outcome = "failed"
238+ print_failure (e )
239+ return 1
240+ finally :
241+ restore_original_branch (summary )
242+ summary .print_text ()
243+
244+
248245def gradlew_cmd (task : str ) -> list [str ]:
249246 if os .name == "nt" :
250247 return [str (REPO_ROOT / "gradlew.bat" ), task ]
@@ -268,12 +265,9 @@ def commit_all_tracked(message: str | list[str], summary: Summary) -> str:
268265
269266def push (summary : Summary ) -> None :
270267 progress ("Pushing PR branch" )
271- result = git (["push" ], summary , check = False )
272- if result .returncode != 0 :
273- raise CommandError (result .cmd , result .returncode , result .stdout , result .stderr )
268+ git (["push" ], summary )
274269 summary .push_result = "pushed successfully"
275270
276-
277271def diff_check (summary : Summary ) -> None :
278272 git (["diff" , "--check" ], summary )
279273
@@ -294,21 +288,18 @@ def make_temp_dir(prefix: str, pr: int, keep_temp: bool) -> Path:
294288def download_actions_job_log (owner : str , repo : str , job_id : int , path : Path , summary : Summary ) -> None :
295289 api_path = f"repos/{ owner } /{ repo } /actions/jobs/{ job_id } /logs"
296290 cmd = ["gh" , "api" , "-H" , "Accept: application/vnd.github+json" , api_path ]
297- summary .add_command ([* cmd , ">" , str (path )])
298291 progress (f"Downloading Actions job log { job_id } to { path } " )
299292 with path .open ("wb" ) as output :
300- proc = subprocess .run (
293+ subprocess .run (
301294 cmd ,
302295 cwd = REPO_ROOT ,
303296 stdout = output ,
304297 stderr = subprocess .PIPE ,
305298 text = True ,
306299 encoding = "utf-8" ,
307300 errors = "replace" ,
308- check = False ,
301+ check = True ,
309302 )
310- if proc .returncode != 0 :
311- raise CommandError (cmd , proc .returncode , "" , proc .stderr )
312303
313304
314305def extract_job_id (check : dict [str , Any ]) -> int | None :
@@ -322,7 +313,6 @@ def extract_job_id(check: dict[str, Any]) -> int | None:
322313
323314def invoke_copilot (prompt : str , summary : Summary ) -> str :
324315 cmd = ["copilot" , "-p" , prompt , "--allow-all-tools" , "--model" , COPILOT_MODEL ]
325- summary .add_command (["copilot" , "-p" , "<generated prompt>" , "--allow-all-tools" , "--model" , COPILOT_MODEL ])
326316 progress (f"Handing off to Copilot CLI using { COPILOT_MODEL } ; streaming output below" )
327317 proc = subprocess .Popen (
328318 cmd ,
@@ -342,9 +332,9 @@ def invoke_copilot(prompt: str, summary: Summary) -> str:
342332 returncode = proc .wait ()
343333 output = "" .join (output_parts )
344334 if returncode != 0 :
345- raise CommandError (
346- ["copilot" , "-p" , "<generated prompt>" , "--allow-all-tools" , "--model" , COPILOT_MODEL ],
335+ raise subprocess .CalledProcessError (
347336 returncode ,
337+ ["copilot" , "-p" , "<generated prompt>" , "--allow-all-tools" , "--model" , COPILOT_MODEL ],
348338 output ,
349339 "" ,
350340 )
@@ -358,10 +348,10 @@ def write_json(path: Path, value: Any) -> None:
358348
359349def print_failure (error : Exception ) -> None :
360350 print (f"ERROR: { error } " , file = sys .stderr )
361- if isinstance (error , CommandError ):
362- if error .stdout .strip ():
351+ if isinstance (error , subprocess . CalledProcessError ):
352+ if error .stdout and error . stdout .strip ():
363353 print ("--- stdout ---" , file = sys .stderr )
364354 print (error .stdout , file = sys .stderr )
365- if error .stderr .strip ():
355+ if error .stderr and error . stderr .strip ():
366356 print ("--- stderr ---" , file = sys .stderr )
367357 print (error .stderr , file = sys .stderr )
0 commit comments