|
| 1 | +import asyncio |
| 2 | +import mimetypes |
| 3 | +import os |
| 4 | + |
| 5 | +from fastapi import FastAPI, HTTPException |
| 6 | +from fastapi.datastructures import UploadFile |
| 7 | +from fastapi.param_functions import Depends, File |
| 8 | +from fastapi_throttle import RateLimiter |
| 9 | +from random_name import generate_name |
| 10 | +from starlette.middleware.authentication import AuthenticationMiddleware |
| 11 | +from starlette.middleware.cors import CORSMiddleware |
| 12 | +from starlette.responses import JSONResponse |
| 13 | + |
| 14 | +from ..constants import DATA_DIR |
| 15 | +from ..utils import _download_file_to_agentfs, handle_prompt |
| 16 | +from .auth import LobsterXAuthentication, on_auth_error |
| 17 | +from .shared import ( |
| 18 | + GetTaskResponse, |
| 19 | + TaskRequest, |
| 20 | + TaskResponse, |
| 21 | + UploadFileResponse, |
| 22 | + get_server_key_from_env, |
| 23 | + validate_api_key, |
| 24 | +) |
| 25 | +from .task_manager import get_task_manager |
| 26 | + |
| 27 | +DEFAULT_FILE_DOWNLOADS_PER_MINUTE = 300 |
| 28 | +DEFAULT_TASKS_PER_MINUTE = 60 |
| 29 | +DEFAULT_DELETE_TASKS_PER_MINUTE = 60 |
| 30 | +DEFAULT_POLL_TASKS_PER_MINUTE = 300 |
| 31 | + |
| 32 | + |
| 33 | +def _get_file_name(document: UploadFile) -> str: |
| 34 | + extension = ( |
| 35 | + mimetypes.guess_extension(document.content_type or "application/pdf") or ".pdf" |
| 36 | + ) |
| 37 | + if document.filename is None: |
| 38 | + return generate_name() + extension |
| 39 | + else: |
| 40 | + if document.filename.endswith(extension): |
| 41 | + return document.filename |
| 42 | + return document.filename + extension |
| 43 | + |
| 44 | + |
| 45 | +def create_api_app( |
| 46 | + allow_origins: list[str], |
| 47 | + file_downloads_per_minute: int | None, |
| 48 | + create_tasks_per_minute: int | None, |
| 49 | + delete_tasks_per_minute: int | None, |
| 50 | + poll_tasks_per_minute: int | None, |
| 51 | + server_api_key: str | None, |
| 52 | +) -> FastAPI: |
| 53 | + app = FastAPI() |
| 54 | + |
| 55 | + file_downloads_per_minute = ( |
| 56 | + file_downloads_per_minute or DEFAULT_FILE_DOWNLOADS_PER_MINUTE |
| 57 | + ) |
| 58 | + tasks_per_minute = create_tasks_per_minute or DEFAULT_TASKS_PER_MINUTE |
| 59 | + delete_tasks_per_minute = delete_tasks_per_minute or DEFAULT_DELETE_TASKS_PER_MINUTE |
| 60 | + poll_tasks_per_minute = poll_tasks_per_minute or DEFAULT_POLL_TASKS_PER_MINUTE |
| 61 | + api_key = server_api_key or get_server_key_from_env() |
| 62 | + if api_key is None: |
| 63 | + raise ValueError( |
| 64 | + "API key not provided and `LOBSTERX_SERVER_KEY` not found within the current environment" |
| 65 | + ) |
| 66 | + |
| 67 | + if not validate_api_key(api_key): |
| 68 | + raise ValueError( |
| 69 | + "API key should be an a string of letters, numbers, hyphens and underscores, with a minimum length of 32." |
| 70 | + ) |
| 71 | + |
| 72 | + app.add_middleware( |
| 73 | + CORSMiddleware, # type: ignore[invalid-argument-type] |
| 74 | + allow_origins=allow_origins, |
| 75 | + allow_methods=["POST", "GET", "DELETE"], |
| 76 | + allow_headers=["Content-Type", "Authorization"], |
| 77 | + ) |
| 78 | + |
| 79 | + app.add_middleware( |
| 80 | + AuthenticationMiddleware, # type: ignore[invalid-argument-type] |
| 81 | + backend=LobsterXAuthentication(api_key=api_key), |
| 82 | + on_error=on_auth_error, |
| 83 | + ) |
| 84 | + |
| 85 | + @app.post( |
| 86 | + "/files", |
| 87 | + dependencies=[ |
| 88 | + Depends(RateLimiter(times=file_downloads_per_minute, seconds=60)) |
| 89 | + ], |
| 90 | + ) |
| 91 | + async def download_file( |
| 92 | + file: UploadFile = File(...), |
| 93 | + ) -> UploadFileResponse: |
| 94 | + file_name = _get_file_name(file) |
| 95 | + file_content = await file.read() |
| 96 | + path = os.path.join(DATA_DIR, file_name) |
| 97 | + await _download_file_to_agentfs(path, file_content) |
| 98 | + return UploadFileResponse(new_file_path=path) |
| 99 | + |
| 100 | + @app.post( |
| 101 | + "/tasks", |
| 102 | + dependencies=[Depends(RateLimiter(times=tasks_per_minute, seconds=60))], |
| 103 | + ) |
| 104 | + async def create_task(request: TaskRequest) -> TaskResponse: |
| 105 | + task_manager = get_task_manager() |
| 106 | + task = asyncio.create_task(handle_prompt(request.prompt)) |
| 107 | + task_id = await task_manager.add_task(task) |
| 108 | + return TaskResponse(task_id=task_id) |
| 109 | + |
| 110 | + @app.delete( |
| 111 | + "/tasks/{task_id}", |
| 112 | + dependencies=[Depends(RateLimiter(times=delete_tasks_per_minute, seconds=60))], |
| 113 | + ) |
| 114 | + async def cancel_task(task_id: str) -> JSONResponse: |
| 115 | + task_manager = get_task_manager() |
| 116 | + await task_manager.cancel_task(task_id) |
| 117 | + return JSONResponse(status_code=204, content={}) |
| 118 | + |
| 119 | + @app.get( |
| 120 | + "/tasks/{task_id}", |
| 121 | + dependencies=[Depends(RateLimiter(times=poll_tasks_per_minute, seconds=60))], |
| 122 | + ) |
| 123 | + async def get_task(task_id: str) -> GetTaskResponse: |
| 124 | + task_manager = get_task_manager() |
| 125 | + task = await task_manager.check_task(task_id) |
| 126 | + if task is None: |
| 127 | + raise HTTPException( |
| 128 | + status_code=404, detail=f"Task {task_id} does not exist" |
| 129 | + ) |
| 130 | + return GetTaskResponse.from_dataclass(task) |
| 131 | + |
| 132 | + return app |
0 commit comments