1+ import asyncio
12import datetime
23import json
4+ import pprint
5+ import tempfile
6+ import zipfile
7+ from typing import Awaitable , Callable , Optional
38
9+ import requests
410from consts import AMD_REQUIREMENTS , GPU , NVIDIA_REQUIREMENTS , GitHubGPU , GPUType
5- from github_runner import GitHubRun
11+ from github import Github , UnknownObjectException , WorkflowRun
612from report import RunProgressReporter
713from run_eval import CompileResult , EvalResult , FullResult , RunResult , SystemInfo
8- from utils import setup_logging
14+ from utils import get_github_branch_name , setup_logging
915
1016from .launcher import Launcher
1117
1218logger = setup_logging ()
1319
1420
1521class GitHubLauncher (Launcher ):
16- def __init__ (self ):
22+ def __init__ (self , repo : str , token : str ):
1723 super ().__init__ (name = "GitHub" , gpus = GitHubGPU )
24+ self .repo = repo
25+ self .token = token
1826
1927 async def run_submission (
2028 self , config : dict , gpu_type : GPU , status : RunProgressReporter
@@ -37,7 +45,7 @@ async def run_submission(
3745 logger .info (f"Running on { gpu_name } amd gpu" )
3846
3947 workflow_file = selected_gpu .value
40- run = GitHubRun (workflow_file )
48+ run = GitHubRun (self . repo , self . token , workflow_file )
4149
4250 payload = json .dumps (config )
4351
@@ -89,8 +97,152 @@ async def run_submission(
8997 system = SystemInfo (** data .get ("system" , {}))
9098 return FullResult (success = True , error = "" , runs = runs , system = system )
9199
92- async def wait_callback (self , run : GitHubRun , status : RunProgressReporter ):
100+ async def wait_callback (self , run : " GitHubRun" , status : RunProgressReporter ):
93101 await status .update (
94102 f"⏳ Workflow [{ run .run_id } ]({ run .html_url } ): { run .status } "
95103 f"({ run .elapsed_time .total_seconds ():.1f} s)"
96104 )
105+
106+
107+ class GitHubRun :
108+ def __init__ (self , repo : str , token : str , workflow_file : str ):
109+ gh = Github (token )
110+ self .repo = gh .get_repo (repo )
111+ self .token = token
112+ self .workflow_file = workflow_file
113+ self .run : Optional [WorkflowRun .WorkflowRun ] = None
114+ self .start_time = None
115+
116+ @property
117+ def run_id (self ):
118+ if self .run is None :
119+ return None
120+ return self .run .id
121+
122+ @property
123+ def html_url (self ):
124+ if self .run is None :
125+ return None
126+ return self .run .html_url
127+
128+ @property
129+ def status (self ):
130+ if self .run is None :
131+ return None
132+ return self .run .status
133+
134+ @property
135+ def elapsed_time (self ):
136+ if self .start_time is None :
137+ return None
138+ return datetime .datetime .now (datetime .timezone .utc ) - self .start_time
139+
140+ async def trigger (self , inputs : dict ) -> bool :
141+ """
142+ Trigger this run with the provided inputs.
143+ Sets `self.run` to the new WorkflowRun on success.
144+
145+ Returns: Whether the run was successfully triggered,
146+ """
147+ trigger_time = datetime .datetime .now (datetime .timezone .utc )
148+ try :
149+ workflow = self .repo .get_workflow (self .workflow_file )
150+ except UnknownObjectException as e :
151+ logger .error (f"Could not find workflow { self .workflow_file } " , exc_info = e )
152+ raise ValueError (f"Could not find workflow { self .workflow_file } " ) from e
153+
154+ logger .debug (
155+ "Dispatching workflow %s on branch %s with inputs %s" ,
156+ self .workflow_file ,
157+ get_github_branch_name (),
158+ pprint .pformat (inputs ),
159+ )
160+ success = workflow .create_dispatch (get_github_branch_name (), inputs = inputs )
161+ if success :
162+ await asyncio .sleep (2 )
163+ runs = list (workflow .get_runs ())
164+
165+ for run in runs :
166+ if run .created_at .replace (tzinfo = datetime .timezone .utc ) > trigger_time :
167+ self .run = run
168+ return True
169+ return False
170+
171+ async def wait_for_completion (
172+ self , callback : Callable [["GitHubRun" ], Awaitable [None ]], timeout_minutes : int = 5
173+ ):
174+ if self .run is None :
175+ raise ValueError ("Run needs to be triggered before a status check!" )
176+
177+ self .start_time = datetime .datetime .now (datetime .timezone .utc )
178+ timeout = datetime .timedelta (minutes = timeout_minutes )
179+
180+ while True :
181+ try :
182+ # update run status
183+ self .run = run = self .repo .get_workflow_run (self .run_id )
184+
185+ if self .elapsed_time > timeout :
186+ try :
187+ self .run .cancel ()
188+ # Wait briefly to ensure cancellation is processed
189+ # And Verify the run was actually cancelled
190+ await asyncio .sleep (5 )
191+ run = self .repo .get_workflow_run (self .run_id )
192+ if run .status != "completed" :
193+ logger .warning (f"Failed to cancel workflow run { self .run_id } " )
194+ except Exception as e :
195+ logger .error (f"Error cancelling workflow: { str (e )} " , exc_info = e )
196+ raise
197+
198+ logger .warning (
199+ f"Workflow { self .run_id } cancelled - "
200+ f"exceeded { timeout_minutes } minute timeout"
201+ )
202+ raise TimeoutError (
203+ f"Workflow { self .run_id } cancelled - "
204+ f"exceeded { timeout_minutes } minute timeout"
205+ )
206+
207+ if run .status == "completed" :
208+ return
209+
210+ await callback (self )
211+ await asyncio .sleep (3 )
212+ except TimeoutError :
213+ raise
214+ except Exception as e :
215+ logger .error (f"Error waiting for GitHub run { self .run_id } : { e } " , exc_info = e )
216+ raise
217+
218+ async def download_artifacts (self ) -> dict :
219+ logger .info ("Attempting to download artifacts for run %s" , self .run_id )
220+ artifacts = self .run .get_artifacts ()
221+
222+ extracted = {}
223+
224+ for artifact in artifacts :
225+ url = artifact .archive_download_url
226+ headers = {"Authorization" : f"token { self .token } " }
227+ response = requests .get (url , headers = headers )
228+
229+ if response .status_code == 200 :
230+ with tempfile .NamedTemporaryFile ("w+b" ) as temp :
231+ temp .write (response .content )
232+ temp .flush ()
233+
234+ with zipfile .ZipFile (temp .name ) as z :
235+ artifact_dict = {}
236+ for file in z .namelist ():
237+ with z .open (file ) as f :
238+ artifact_dict [file ] = f .read ()
239+
240+ extracted [artifact .name ] = artifact_dict
241+ else :
242+ raise RuntimeError (
243+ f"Failed to download artifact { artifact .name } . "
244+ f"Status code: { response .status_code } "
245+ )
246+
247+ logger .info ("Download artifacts for run %s: %s" , self .run_id , list (extracted .keys ()))
248+ return extracted
0 commit comments