|
1 | 1 | import asyncio |
| 2 | +import base64 |
| 3 | +import os |
2 | 4 | import time |
3 | 5 | from dataclasses import asdict |
4 | 6 |
|
| 7 | +import requests |
5 | 8 | from cogs.submit_cog import SubmitCog |
6 | 9 | from consts import _GPU_LOOKUP, SubmissionMode, get_gpu_by_name |
7 | 10 | from discord import app_commands |
| 11 | +from env import CLI_DISCORD_CLIENT_ID, CLI_DISCORD_CLIENT_SECRET, CLI_TOKEN_URL |
8 | 12 | from fastapi import FastAPI, HTTPException, UploadFile |
9 | 13 | from utils import LeaderboardItem, build_task_config |
10 | 14 |
|
@@ -53,6 +57,75 @@ async def update(self, message: str): |
53 | 57 | pass |
54 | 58 |
|
55 | 59 |
|
| 60 | +@app.get("/auth/cli") |
| 61 | +async def cli_auth(code: str, state: str = None): |
| 62 | + """ |
| 63 | + Handle Discord OAuth redirect. This endpoint receives the authorization code |
| 64 | + and state parameter from Discord's OAuth flow. |
| 65 | +
|
| 66 | + Args: |
| 67 | + code (str): Authorization code from Discord OAuth |
| 68 | + state (str): Base64 encoded client ID from CLI |
| 69 | + """ |
| 70 | + |
| 71 | + if not code or not state: |
| 72 | + raise HTTPException(status_code=400, detail="Missing authorization code or state") |
| 73 | + |
| 74 | + client_id = CLI_DISCORD_CLIENT_ID |
| 75 | + client_secret = CLI_DISCORD_CLIENT_SECRET |
| 76 | + redirect_uri = os.environ.get("HEROKU_APP_DEFAULT_DOMAIN_NAME") or os.getenv("POPCORN_API_URL") |
| 77 | + token_url = CLI_TOKEN_URL |
| 78 | + |
| 79 | + if not client_id or not client_secret: |
| 80 | + raise HTTPException(status_code=500, detail="Discord client ID or secret not configured.") |
| 81 | + |
| 82 | + if not token_url: |
| 83 | + raise HTTPException(status_code=500, detail="Discord token URL not configured.") |
| 84 | + |
| 85 | + if not redirect_uri: |
| 86 | + raise HTTPException( |
| 87 | + status_code=500, |
| 88 | + detail="Redirect URI not configured. " |
| 89 | + "If running locally, set env variable `POPCORN_API_URL` to your local API URL.", |
| 90 | + ) |
| 91 | + |
| 92 | + token_data = { |
| 93 | + "client_id": client_id, |
| 94 | + "client_secret": client_secret, |
| 95 | + "grant_type": "authorization_code", |
| 96 | + "code": code, |
| 97 | + "redirect_uri": redirect_uri + "/auth/cli", |
| 98 | + } |
| 99 | + |
| 100 | + token_response = requests.post(token_url, data=token_data) |
| 101 | + if token_response.status_code != 200: |
| 102 | + raise HTTPException( |
| 103 | + status_code=401, detail=f"Failed to authenticate with Discord: {token_response.text}" |
| 104 | + ) |
| 105 | + |
| 106 | + token_json = token_response.json() |
| 107 | + access_token = token_json.get("access_token") |
| 108 | + |
| 109 | + user_url = "https://discord.com/api/users/@me" |
| 110 | + headers = {"Authorization": f"Bearer {access_token}"} |
| 111 | + |
| 112 | + user_response = requests.get(user_url, headers=headers) |
| 113 | + if user_response.status_code != 200: |
| 114 | + raise HTTPException(status_code=401, detail="Failed to retrieve user information") |
| 115 | + |
| 116 | + user_json = user_response.json() |
| 117 | + user_id = user_json.get("id") |
| 118 | + |
| 119 | + cli_id = "" |
| 120 | + if state: |
| 121 | + try: |
| 122 | + cli_id = base64.b64decode(state).decode("utf-8") |
| 123 | + except Exception as e: |
| 124 | + raise HTTPException(status_code=400, detail=f"Invalid state parameter: {str(e)}") from e |
| 125 | + |
| 126 | + return {"status": "success", "user_id": user_id, "cli_id": cli_id} |
| 127 | + |
| 128 | + |
56 | 129 | @app.post("/{leaderboard_name}/{gpu_type}/{submission_mode}") |
57 | 130 | async def run_submission( |
58 | 131 | leaderboard_name: str, gpu_type: str, submission_mode: str, file: UploadFile |
|
0 commit comments