Skip to content

Commit a9a086e

Browse files
committed
Feat: rate limiting
1 parent 7d97a31 commit a9a086e

13 files changed

Lines changed: 340 additions & 256 deletions

File tree

.ruff.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
line-length = 100 # ideally I want this to be less than 100 but don't wanna test and change files with longer lines
1+
line-length = 120 # ideally I want this to be less than 100 but don't wanna test and change files with longer lines
22
target-version = "py313"
33
lint.select = [
44
"E", # pycodestyle errors

AGENTS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
@CLAUDE.md

src/kernelbot/api/api_utils.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ async def _handle_discord_oauth(code: str, redirect_uri: str) -> tuple[str, str]
7474
user_name = user_json.get("username")
7575

7676
if not user_id or not user_name:
77-
raise HTTPException(
78-
status_code=500, detail="Failed to retrieve user ID or username from Discord."
79-
)
77+
raise HTTPException(status_code=500, detail="Failed to retrieve user ID or username from Discord.")
8078

8179
return user_id, user_name
8280

@@ -135,16 +133,12 @@ async def _handle_github_oauth(code: str, redirect_uri: str) -> tuple[str, str]:
135133
user_name = user_json.get("login") # GitHub uses 'login' for username
136134

137135
if not user_id or not user_name:
138-
raise HTTPException(
139-
status_code=500, detail="Failed to retrieve user ID or username from GitHub."
140-
)
136+
raise HTTPException(status_code=500, detail="Failed to retrieve user ID or username from GitHub.")
141137

142138
return user_id, user_name
143139

144140

145-
async def _run_submission(
146-
submission: SubmissionRequest, mode: SubmissionMode, backend: KernelBackend
147-
):
141+
async def _run_submission(submission: SubmissionRequest, mode: SubmissionMode, backend: KernelBackend):
148142
try:
149143
req = prepare_submission(submission, backend)
150144
except Exception as e:
@@ -225,21 +219,6 @@ async def to_submit_info(
225219

226220
try:
227221
with db_context as db:
228-
# Per-user rate limit: max 1 submission per hour on Modal B200 for leaderboard 730
229-
if gpu_type == "B200":
230-
lb_id = db.get_leaderboard_id(leaderboard_name)
231-
if lb_id == 730:
232-
last_submission_time = db.check_user_rate_limit(user_id)
233-
if last_submission_time:
234-
raise HTTPException(
235-
status_code=429,
236-
detail=(
237-
f"Rate limit exceeded. You can submit once per hour. "
238-
f"Last submission: {last_submission_time.isoformat()}. "
239-
f"Consider using the NVIDIA runner instead of Modal for faster iteration."
240-
),
241-
)
242-
243222
leaderboard_item = db.get_leaderboard(leaderboard_name)
244223
gpus = leaderboard_item.get("gpu_types", [])
245224
if gpu_type not in gpus:

src/kernelbot/api/main.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
app = FastAPI()
4545

46+
4647
def json_serializer(obj):
4748
"""JSON serializer for objects not serializable by default json code"""
4849
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
@@ -185,9 +186,7 @@ def require_admin(
185186
@app.get("/auth/init")
186187
async def auth_init(provider: str, db_context=Depends(get_db)) -> dict:
187188
if provider not in ["discord", "github"]:
188-
raise HTTPException(
189-
status_code=400, detail="Invalid provider, must be 'discord' or 'github'"
190-
)
189+
raise HTTPException(status_code=400, detail="Invalid provider, must be 'discord' or 'github'")
191190

192191
"""
193192
Initialize authentication flow for the specified provider.
@@ -230,9 +229,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
230229
"""
231230

232231
if auth_provider not in ["discord", "github"]:
233-
raise HTTPException(
234-
status_code=400, detail="Invalid provider, must be 'discord' or 'github'"
235-
)
232+
raise HTTPException(status_code=400, detail="Invalid provider, must be 'discord' or 'github'")
236233

237234
if not code or not state:
238235
raise HTTPException(status_code=400, detail="Missing authorization code or state")
@@ -252,8 +249,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
252249
if not api_base_url:
253250
raise HTTPException(
254251
status_code=500,
255-
detail="Redirect URI base not configured."
256-
"Set HEROKU_APP_DEFAULT_DOMAIN_NAME or POPCORN_API_URL.",
252+
detail="Redirect URI base not configured.Set HEROKU_APP_DEFAULT_DOMAIN_NAME or POPCORN_API_URL.",
257253
)
258254
redirect_uri_base = api_base_url.rstrip("/")
259255
redirect_uri = f"https://{redirect_uri_base}/auth/cli/{auth_provider}"
@@ -275,7 +271,10 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
275271
raise HTTPException(status_code=500, detail=f"Error during {auth_provider} OAuth flow: {e}") from e
276272

277273
if not user_id or not user_name:
278-
raise HTTPException(status_code=500,detail="Failed to retrieve user ID or username from provider.",)
274+
raise HTTPException(
275+
status_code=500,
276+
detail="Failed to retrieve user ID or username from provider.",
277+
)
279278

280279
try:
281280
with db_context as db:
@@ -297,6 +296,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
297296
"is_reset": is_reset,
298297
}
299298

299+
300300
async def _stream_submission_response(
301301
submission_request: SubmissionRequest,
302302
submission_mode_enum: SubmissionMode,
@@ -315,18 +315,18 @@ async def _stream_submission_response(
315315

316316
while not task.done():
317317
elapsed_time = time.time() - start_time
318-
yield f"event: status\ndata: {json.dumps({'status': 'processing',
319-
'elapsed_time': round(elapsed_time, 2)},
320-
default=json_serializer)}\n\n"
318+
yield f"event: status\ndata: {
319+
json.dumps({'status': 'processing', 'elapsed_time': round(elapsed_time, 2)}, default=json_serializer)
320+
}\n\n"
321321

322322
try:
323323
await asyncio.wait_for(asyncio.shield(task), timeout=15.0)
324324
except asyncio.TimeoutError:
325325
continue
326326
except asyncio.CancelledError:
327-
yield f"event: error\ndata: {json.dumps(
328-
{'status': 'error', 'detail': 'Submission cancelled'},
329-
default=json_serializer)}\n\n"
327+
yield f"event: error\ndata: {
328+
json.dumps({'status': 'error', 'detail': 'Submission cancelled'}, default=json_serializer)
329+
}\n\n"
330330
return
331331

332332
result, reports = await task
@@ -360,6 +360,7 @@ async def _stream_submission_response(
360360
except asyncio.CancelledError:
361361
pass
362362

363+
363364
@app.post("/{leaderboard_name}/{gpu_type}/{submission_mode}")
364365
async def run_submission( # noqa: C901
365366
leaderboard_name: str,
@@ -398,27 +399,28 @@ async def run_submission( # noqa: C901
398399
)
399400
return StreamingResponse(generator, media_type="text/event-stream")
400401

402+
401403
async def enqueue_background_job(
402404
req: ProcessedSubmissionRequest,
403405
mode: SubmissionMode,
404406
backend: KernelBackend,
405407
manager: BackgroundSubmissionManager,
406408
):
407-
408409
# pre-create the submission for api returns
409410
with backend.db as db:
410411
sub_id = db.create_submission(
411412
leaderboard=req.leaderboard,
412413
file_name=req.file_name,
413414
code=req.code,
414415
user_id=req.user_id,
415-
time=datetime.datetime.now(),
416+
time=datetime.datetime.now(datetime.timezone.utc),
416417
user_name=req.user_name,
417418
)
418419
job_id = db.upsert_submission_job_status(sub_id, "initial", None)
419420
# put submission request in queue
420421
await manager.enqueue(req, mode, sub_id)
421-
return sub_id,job_id
422+
return sub_id, job_id
423+
422424

423425
@app.post("/submission/{leaderboard_name}/{gpu_type}/{submission_mode}")
424426
async def run_submission_async(
@@ -445,15 +447,13 @@ async def run_submission_async(
445447
JSONResponse: A JSON response containing job_id and and submission_id for the client to poll for status.
446448
"""
447449
try:
448-
449450
await simple_rate_limit()
450451
logger.info(f"Received submission request for {leaderboard_name} {gpu_type} {submission_mode}")
451452

452-
453453
# throw error if submission request is invalid
454454
try:
455455
submission_request, submission_mode_enum = await to_submit_info(
456-
user_info, submission_mode, file, leaderboard_name, gpu_type, db_context
456+
user_info, submission_mode, file, leaderboard_name, gpu_type, db_context
457457
)
458458

459459
req = prepare_submission(submission_request, backend_instance)
@@ -466,13 +466,13 @@ async def run_submission_async(
466466
raise HTTPException(status_code=400, detail="Invalid GPU type")
467467

468468
# put submission request to background manager to run in background
469-
sub_id,job_status_id = await enqueue_background_job(
469+
sub_id, job_status_id = await enqueue_background_job(
470470
req, submission_mode_enum, backend_instance, background_submission_manager
471471
)
472472

473473
return JSONResponse(
474474
status_code=202,
475-
content={"details":{"id": sub_id, "job_status_id": job_status_id}, "status": "accepted"},
475+
content={"details": {"id": sub_id, "job_status_id": job_status_id}, "status": "accepted"},
476476
)
477477
# Preserve FastAPI HTTPException as-is
478478
except HTTPException:
@@ -536,8 +536,7 @@ async def create_dev_leaderboard(
536536
# GPUs must be specified in task.yml
537537
if not definition.gpus:
538538
raise HTTPException(
539-
status_code=400,
540-
detail="No gpus specified in task.yml. Add 'gpus:' field with list of GPU types."
539+
status_code=400, detail="No gpus specified in task.yml. Add 'gpus:' field with list of GPU types."
541540
)
542541

543542
with db_context as db:
@@ -629,7 +628,7 @@ async def admin_update_problems(
629628
branch=branch,
630629
force=force,
631630
creator_id=0, # API-created
632-
forum_id=-1, # No Discord forum
631+
forum_id=-1, # No Discord forum
633632
)
634633
except ValueError as e:
635634
raise HTTPException(status_code=400, detail=str(e)) from e
@@ -643,6 +642,33 @@ async def admin_update_problems(
643642
}
644643

645644

645+
@app.get("/leaderboard/rate-limits/{leaderboard_name}")
646+
async def get_leaderboard_rate_limits(leaderboard_name: str, db_context=Depends(get_db)) -> dict:
647+
with db_context as db:
648+
rate_limits = db.get_leaderboard_rate_limits(leaderboard_name)
649+
return {"status": "ok", "rate_limits": rate_limits}
650+
651+
652+
@app.post("/leaderboard/rate-limits/{leaderboard_name}/{gpu_type}")
653+
async def set_leaderboard_gpu_rate_limit(
654+
leaderboard_name: str,
655+
gpu_type: str,
656+
rate_limit_seconds: int,
657+
_: Annotated[None, Depends(require_admin)],
658+
db_context=Depends(get_db),
659+
) -> dict:
660+
if rate_limit_seconds <= 0:
661+
rate_limit_seconds = None
662+
with db_context as db:
663+
db.set_leaderboard_gpu_rate_limit(leaderboard_name, gpu_type, rate_limit_seconds)
664+
return {
665+
"status": "ok",
666+
"leaderboard_name": leaderboard_name,
667+
"gpu_type": gpu_type,
668+
"rate_limit_seconds": rate_limit_seconds,
669+
}
670+
671+
646672
@app.get("/leaderboards")
647673
async def get_leaderboards(db_context=Depends(get_db)):
648674
"""An endpoint that returns all leaderboards.
@@ -692,9 +718,7 @@ async def get_submissions(
692718
try:
693719
with db_context as db:
694720
# Add validation for leaderboard and GPU? Might be redundant if DB handles it.
695-
return db.get_leaderboard_submissions(
696-
leaderboard_name, gpu_name, limit=limit, offset=offset
697-
)
721+
return db.get_leaderboard_submissions(leaderboard_name, gpu_name, limit=limit, offset=offset)
698722
except Exception as e:
699723
raise HTTPException(status_code=500, detail=f"Error fetching submissions: {e}") from e
700724

0 commit comments

Comments
 (0)