Skip to content

Commit 4e9f10c

Browse files
committed
fold GitHubRun into github launcher
1 parent 4067620 commit 4e9f10c

3 files changed

Lines changed: 159 additions & 162 deletions

File tree

src/discord-cluster-manager/bot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import consts
66
import discord
7+
import env
78
import uvicorn
89
from api.main import app, init_api
910
from cogs.admin_cog import AdminCog
@@ -78,7 +79,7 @@ async def setup_hook(self):
7879
# Load cogs
7980
submit_cog = SubmitCog(self)
8081
submit_cog.register_launcher(ModalLauncher(consts.MODAL_CUDA_INCLUDE_DIRS))
81-
submit_cog.register_launcher(GitHubLauncher())
82+
submit_cog.register_launcher(GitHubLauncher(env.GITHUB_REPO, env.GITHUB_TOKEN))
8283
await self.add_cog(submit_cog)
8384
await self.add_cog(BotManagerCog(self))
8485
await self.add_cog(LeaderboardCog(self))

src/discord-cluster-manager/github_runner.py

Lines changed: 0 additions & 156 deletions
This file was deleted.

src/discord-cluster-manager/launchers/github.py

Lines changed: 157 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,28 @@
1+
import asyncio
12
import datetime
23
import json
4+
import pprint
5+
import tempfile
6+
import zipfile
7+
from typing import Awaitable, Callable, Optional
38

9+
import requests
410
from consts import AMD_REQUIREMENTS, GPU, NVIDIA_REQUIREMENTS, GitHubGPU, GPUType
5-
from github_runner import GitHubRun
11+
from github import Github, UnknownObjectException, WorkflowRun
612
from report import RunProgressReporter
713
from run_eval import CompileResult, EvalResult, FullResult, RunResult, SystemInfo
8-
from utils import setup_logging
14+
from utils import get_github_branch_name, setup_logging
915

1016
from .launcher import Launcher
1117

1218
logger = setup_logging()
1319

1420

1521
class GitHubLauncher(Launcher):
16-
def __init__(self):
22+
def __init__(self, repo: str, token: str):
1723
super().__init__(name="GitHub", gpus=GitHubGPU)
24+
self.repo = repo
25+
self.token = token
1826

1927
async def run_submission(
2028
self, config: dict, gpu_type: GPU, status: RunProgressReporter
@@ -37,7 +45,7 @@ async def run_submission(
3745
logger.info(f"Running on {gpu_name} amd gpu")
3846

3947
workflow_file = selected_gpu.value
40-
run = GitHubRun(workflow_file)
48+
run = GitHubRun(self.repo, self.token, workflow_file)
4149

4250
payload = json.dumps(config)
4351

@@ -89,8 +97,152 @@ async def run_submission(
8997
system = SystemInfo(**data.get("system", {}))
9098
return FullResult(success=True, error="", runs=runs, system=system)
9199

92-
async def wait_callback(self, run: GitHubRun, status: RunProgressReporter):
100+
async def wait_callback(self, run: "GitHubRun", status: RunProgressReporter):
93101
await status.update(
94102
f"⏳ Workflow [{run.run_id}]({run.html_url}): {run.status} "
95103
f"({run.elapsed_time.total_seconds():.1f}s)"
96104
)
105+
106+
107+
class GitHubRun:
108+
def __init__(self, repo: str, token: str, workflow_file: str):
109+
gh = Github(token)
110+
self.repo = gh.get_repo(repo)
111+
self.token = token
112+
self.workflow_file = workflow_file
113+
self.run: Optional[WorkflowRun.WorkflowRun] = None
114+
self.start_time = None
115+
116+
@property
117+
def run_id(self):
118+
if self.run is None:
119+
return None
120+
return self.run.id
121+
122+
@property
123+
def html_url(self):
124+
if self.run is None:
125+
return None
126+
return self.run.html_url
127+
128+
@property
129+
def status(self):
130+
if self.run is None:
131+
return None
132+
return self.run.status
133+
134+
@property
135+
def elapsed_time(self):
136+
if self.start_time is None:
137+
return None
138+
return datetime.datetime.now(datetime.timezone.utc) - self.start_time
139+
140+
async def trigger(self, inputs: dict) -> bool:
141+
"""
142+
Trigger this run with the provided inputs.
143+
Sets `self.run` to the new WorkflowRun on success.
144+
145+
Returns: Whether the run was successfully triggered,
146+
"""
147+
trigger_time = datetime.datetime.now(datetime.timezone.utc)
148+
try:
149+
workflow = self.repo.get_workflow(self.workflow_file)
150+
except UnknownObjectException as e:
151+
logger.error(f"Could not find workflow {self.workflow_file}", exc_info=e)
152+
raise ValueError(f"Could not find workflow {self.workflow_file}") from e
153+
154+
logger.debug(
155+
"Dispatching workflow %s on branch %s with inputs %s",
156+
self.workflow_file,
157+
get_github_branch_name(),
158+
pprint.pformat(inputs),
159+
)
160+
success = workflow.create_dispatch(get_github_branch_name(), inputs=inputs)
161+
if success:
162+
await asyncio.sleep(2)
163+
runs = list(workflow.get_runs())
164+
165+
for run in runs:
166+
if run.created_at.replace(tzinfo=datetime.timezone.utc) > trigger_time:
167+
self.run = run
168+
return True
169+
return False
170+
171+
async def wait_for_completion(
172+
self, callback: Callable[["GitHubRun"], Awaitable[None]], timeout_minutes: int = 5
173+
):
174+
if self.run is None:
175+
raise ValueError("Run needs to be triggered before a status check!")
176+
177+
self.start_time = datetime.datetime.now(datetime.timezone.utc)
178+
timeout = datetime.timedelta(minutes=timeout_minutes)
179+
180+
while True:
181+
try:
182+
# update run status
183+
self.run = run = self.repo.get_workflow_run(self.run_id)
184+
185+
if self.elapsed_time > timeout:
186+
try:
187+
self.run.cancel()
188+
# Wait briefly to ensure cancellation is processed
189+
# And Verify the run was actually cancelled
190+
await asyncio.sleep(5)
191+
run = self.repo.get_workflow_run(self.run_id)
192+
if run.status != "completed":
193+
logger.warning(f"Failed to cancel workflow run {self.run_id}")
194+
except Exception as e:
195+
logger.error(f"Error cancelling workflow: {str(e)}", exc_info=e)
196+
raise
197+
198+
logger.warning(
199+
f"Workflow {self.run_id} cancelled - "
200+
f"exceeded {timeout_minutes} minute timeout"
201+
)
202+
raise TimeoutError(
203+
f"Workflow {self.run_id} cancelled - "
204+
f"exceeded {timeout_minutes} minute timeout"
205+
)
206+
207+
if run.status == "completed":
208+
return
209+
210+
await callback(self)
211+
await asyncio.sleep(3)
212+
except TimeoutError:
213+
raise
214+
except Exception as e:
215+
logger.error(f"Error waiting for GitHub run {self.run_id}: {e}", exc_info=e)
216+
raise
217+
218+
async def download_artifacts(self) -> dict:
219+
logger.info("Attempting to download artifacts for run %s", self.run_id)
220+
artifacts = self.run.get_artifacts()
221+
222+
extracted = {}
223+
224+
for artifact in artifacts:
225+
url = artifact.archive_download_url
226+
headers = {"Authorization": f"token {self.token}"}
227+
response = requests.get(url, headers=headers)
228+
229+
if response.status_code == 200:
230+
with tempfile.NamedTemporaryFile("w+b") as temp:
231+
temp.write(response.content)
232+
temp.flush()
233+
234+
with zipfile.ZipFile(temp.name) as z:
235+
artifact_dict = {}
236+
for file in z.namelist():
237+
with z.open(file) as f:
238+
artifact_dict[file] = f.read()
239+
240+
extracted[artifact.name] = artifact_dict
241+
else:
242+
raise RuntimeError(
243+
f"Failed to download artifact {artifact.name}. "
244+
f"Status code: {response.status_code}"
245+
)
246+
247+
logger.info("Download artifacts for run %s: %s", self.run_id, list(extracted.keys()))
248+
return extracted

0 commit comments

Comments
 (0)