Skip to content

Commit 85c2d4f

Browse files
authored
Merge pull request #27 from AstraBert/feat/lobsterx-api
wip: lobsterx api
2 parents 0204e25 + 203096a commit 85c2d4f

20 files changed

Lines changed: 1240 additions & 13 deletions

File tree

.github/workflows/docker-image.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
name: Build and Publish Docker Image
22

33
on:
4-
push:
5-
branches: [main]
64
workflow_dispatch:
75

86
env:

packages/lobsterx/config.api.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"allow_origins": [],
3+
"file_downloads_per_minute": 300,
4+
"create_tasks_per_minute": 60,
5+
"delete_tasks_per_minute": 60,
6+
"poll_tasks_per_minute": 300,
7+
"host": "0.0.0.0",
8+
"port": 9000
9+
}

packages/lobsterx/pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,19 @@ build-backend = "uv_build"
44

55
[project]
66
name = "lobsterx"
7-
version = "0.1.1-beta"
7+
version = "0.2.0-beta"
88
description = "Background AI assistant working as a Telegram bot, built specifically for document-related use cases"
99
readme = "README.md"
1010
requires-python = ">=3.11"
1111
dependencies = [
1212
"aiofiles>=25.1.0",
1313
"diskcache>=5.6.3",
14+
"fastapi>=0.129.0",
15+
"fastapi-throttle>=0.1.8",
16+
"httpx>=0.28.1",
1417
"llama-cloud>=1.2.0,<1.3",
1518
"python-dotenv>=1.2.1",
19+
"python-multipart>=0.0.21",
1620
"python-telegram-bot>=22.6",
1721
"random-name>=0.1.1",
1822
"workflows-acp",

packages/lobsterx/src/lobsterx/api/__init__.py

Whitespace-only changes.
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import asyncio
2+
import mimetypes
3+
import os
4+
5+
from fastapi import FastAPI, HTTPException
6+
from fastapi.datastructures import UploadFile
7+
from fastapi.param_functions import Depends, File
8+
from fastapi_throttle import RateLimiter
9+
from random_name import generate_name
10+
from starlette.middleware.authentication import AuthenticationMiddleware
11+
from starlette.middleware.cors import CORSMiddleware
12+
from starlette.responses import JSONResponse
13+
14+
from ..constants import DATA_DIR
15+
from ..utils import _download_file_to_agentfs, handle_prompt
16+
from .auth import LobsterXAuthentication, on_auth_error
17+
from .shared import (
18+
GetTaskResponse,
19+
TaskRequest,
20+
TaskResponse,
21+
UploadFileResponse,
22+
get_server_key_from_env,
23+
validate_api_key,
24+
)
25+
from .task_manager import get_task_manager
26+
27+
DEFAULT_FILE_DOWNLOADS_PER_MINUTE = 300
28+
DEFAULT_TASKS_PER_MINUTE = 60
29+
DEFAULT_DELETE_TASKS_PER_MINUTE = 60
30+
DEFAULT_POLL_TASKS_PER_MINUTE = 300
31+
32+
33+
def _get_file_name(document: UploadFile) -> str:
34+
extension = (
35+
mimetypes.guess_extension(document.content_type or "application/pdf") or ".pdf"
36+
)
37+
if document.filename is None:
38+
return generate_name() + extension
39+
else:
40+
if document.filename.endswith(extension):
41+
return document.filename
42+
return document.filename + extension
43+
44+
45+
def create_api_app(
46+
allow_origins: list[str],
47+
file_downloads_per_minute: int | None,
48+
create_tasks_per_minute: int | None,
49+
delete_tasks_per_minute: int | None,
50+
poll_tasks_per_minute: int | None,
51+
server_api_key: str | None,
52+
) -> FastAPI:
53+
app = FastAPI()
54+
55+
file_downloads_per_minute = (
56+
file_downloads_per_minute or DEFAULT_FILE_DOWNLOADS_PER_MINUTE
57+
)
58+
tasks_per_minute = create_tasks_per_minute or DEFAULT_TASKS_PER_MINUTE
59+
delete_tasks_per_minute = delete_tasks_per_minute or DEFAULT_DELETE_TASKS_PER_MINUTE
60+
poll_tasks_per_minute = poll_tasks_per_minute or DEFAULT_POLL_TASKS_PER_MINUTE
61+
api_key = server_api_key or get_server_key_from_env()
62+
if api_key is None:
63+
raise ValueError(
64+
"API key not provided and `LOBSTERX_SERVER_KEY` not found within the current environment"
65+
)
66+
67+
if not validate_api_key(api_key):
68+
raise ValueError(
69+
"API key should be an a string of letters, numbers, hyphens and underscores, with a minimum length of 32."
70+
)
71+
72+
app.add_middleware(
73+
CORSMiddleware, # type: ignore[invalid-argument-type]
74+
allow_origins=allow_origins,
75+
allow_methods=["POST", "GET", "DELETE"],
76+
allow_headers=["Content-Type", "Authorization"],
77+
)
78+
79+
app.add_middleware(
80+
AuthenticationMiddleware, # type: ignore[invalid-argument-type]
81+
backend=LobsterXAuthentication(api_key=api_key),
82+
on_error=on_auth_error,
83+
)
84+
85+
@app.post(
86+
"/files",
87+
dependencies=[
88+
Depends(RateLimiter(times=file_downloads_per_minute, seconds=60))
89+
],
90+
)
91+
async def download_file(
92+
file: UploadFile = File(...),
93+
) -> UploadFileResponse:
94+
file_name = _get_file_name(file)
95+
file_content = await file.read()
96+
path = os.path.join(DATA_DIR, file_name)
97+
await _download_file_to_agentfs(path, file_content)
98+
return UploadFileResponse(new_file_path=path)
99+
100+
@app.post(
101+
"/tasks",
102+
dependencies=[Depends(RateLimiter(times=tasks_per_minute, seconds=60))],
103+
)
104+
async def create_task(request: TaskRequest) -> TaskResponse:
105+
task_manager = get_task_manager()
106+
task = asyncio.create_task(handle_prompt(request.prompt))
107+
task_id = await task_manager.add_task(task)
108+
return TaskResponse(task_id=task_id)
109+
110+
@app.delete(
111+
"/tasks/{task_id}",
112+
dependencies=[Depends(RateLimiter(times=delete_tasks_per_minute, seconds=60))],
113+
)
114+
async def cancel_task(task_id: str) -> JSONResponse:
115+
task_manager = get_task_manager()
116+
await task_manager.cancel_task(task_id)
117+
return JSONResponse(status_code=204, content={})
118+
119+
@app.get(
120+
"/tasks/{task_id}",
121+
dependencies=[Depends(RateLimiter(times=poll_tasks_per_minute, seconds=60))],
122+
)
123+
async def get_task(task_id: str) -> GetTaskResponse:
124+
task_manager = get_task_manager()
125+
task = await task_manager.check_task(task_id)
126+
if task is None:
127+
raise HTTPException(
128+
status_code=404, detail=f"Task {task_id} does not exist"
129+
)
130+
return GetTaskResponse.from_dataclass(task)
131+
132+
return app
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from starlette.authentication import (
2+
AuthCredentials,
3+
AuthenticationBackend,
4+
AuthenticationError,
5+
BaseUser,
6+
)
7+
from starlette.requests import HTTPConnection
8+
from starlette.responses import PlainTextResponse, Response
9+
10+
from .shared import get_auth_header_pattern
11+
12+
DEFAULT_DISPLAY_NAME = "lobsterx"
13+
DEFAULT_IDENTITY = "user"
14+
15+
16+
class LobsterXUser(BaseUser):
17+
def __init__(self, authenticated: bool) -> None:
18+
self._authenticated = authenticated
19+
20+
@property
21+
def is_authenticated(self) -> bool:
22+
return self._authenticated
23+
24+
@property
25+
def identity(self) -> str:
26+
return DEFAULT_IDENTITY
27+
28+
@property
29+
def display_name(self) -> str:
30+
return DEFAULT_DISPLAY_NAME
31+
32+
33+
class LobsterXAuthentication(AuthenticationBackend):
34+
def __init__(self, api_key: str) -> None:
35+
self.api_key = api_key
36+
37+
async def authenticate(
38+
self, conn: HTTPConnection
39+
) -> tuple[AuthCredentials, BaseUser] | None:
40+
auth_header = conn.headers.get("Authorization", None)
41+
if auth_header is None:
42+
raise AuthenticationError("No authorization header in request")
43+
matches = get_auth_header_pattern().findall(auth_header)
44+
try:
45+
assert len(matches) == 1, "Should only provide one bearer token"
46+
except AssertionError as e:
47+
raise AuthenticationError("Should only provide one bearer token") from e
48+
api_key = matches[0]
49+
if api_key == self.api_key:
50+
return AuthCredentials(scopes=["http"]), LobsterXUser(authenticated=True)
51+
raise AuthenticationError("API key not authorized")
52+
53+
54+
def on_auth_error(conn: HTTPConnection, exc: AuthenticationError) -> Response:
55+
return PlainTextResponse(str(exc), status_code=401)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import os
2+
from mimetypes import guess_type
3+
from typing import Literal
4+
5+
from httpx import AsyncClient
6+
7+
from .shared import (
8+
GetTaskResponse,
9+
TaskRequest,
10+
TaskResponse,
11+
UploadFileResponse,
12+
get_server_key_from_env,
13+
validate_api_key,
14+
)
15+
16+
17+
class LobsterXClient:
18+
def __init__(
19+
self,
20+
api_key: str | None,
21+
host: str,
22+
port: int,
23+
protocol: Literal["http", "https"],
24+
) -> None:
25+
self.base_url = f"{protocol}://{host}:{port}"
26+
self.api_key = api_key or get_server_key_from_env()
27+
if self.api_key is None:
28+
raise ValueError(
29+
"API key not provided and `LOBSTERX_SERVER_KEY` not found within the current environment"
30+
)
31+
if not validate_api_key(self.api_key):
32+
raise ValueError(
33+
"API key should be an a string of letters, numbers, hyphens and underscores, with a minimum length of 32."
34+
)
35+
36+
async def upload_file(self, file_path: str) -> str:
37+
async with AsyncClient(
38+
base_url=self.base_url,
39+
headers={"Authorization": f"Bearer {self.api_key}"},
40+
timeout=600,
41+
) as client:
42+
with open(file_path, "rb") as f:
43+
mimetype, _ = guess_type(file_path)
44+
file_type = mimetype or "application/pdf"
45+
file = (os.path.basename(file_path), f, file_type)
46+
response = await client.post("/files", files={"file": file})
47+
response.raise_for_status()
48+
payload = response.json()
49+
validated = UploadFileResponse.model_validate(payload)
50+
return validated.new_file_path
51+
52+
async def create_task(self, prompt: str) -> str:
53+
async with AsyncClient(
54+
base_url=self.base_url,
55+
headers={"Authorization": f"Bearer {self.api_key}"},
56+
timeout=600,
57+
) as client:
58+
payload = TaskRequest(prompt=prompt).model_dump()
59+
response = await client.post("/tasks", json=payload)
60+
response.raise_for_status()
61+
json_response = response.json()
62+
validated = TaskResponse.model_validate(json_response)
63+
return validated.task_id
64+
65+
async def get_task(self, task_id: str) -> GetTaskResponse:
66+
async with AsyncClient(
67+
base_url=self.base_url,
68+
headers={"Authorization": f"Bearer {self.api_key}"},
69+
timeout=600,
70+
) as client:
71+
response = await client.get(f"/tasks/{task_id}")
72+
response.raise_for_status()
73+
json_response = response.json()
74+
return GetTaskResponse.model_validate(json_response)
75+
76+
async def cancel_task(self, task_id: str) -> None:
77+
async with AsyncClient(
78+
base_url=self.base_url,
79+
headers={"Authorization": f"Bearer {self.api_key}"},
80+
timeout=600,
81+
) as client:
82+
response = await client.delete(f"/tasks/{task_id}")
83+
response.raise_for_status()
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import functools
2+
import json
3+
import os
4+
import re
5+
from dataclasses import dataclass
6+
from typing import Any, Literal
7+
8+
from dotenv import load_dotenv
9+
from pydantic import BaseModel
10+
11+
from .task_manager import StatusEnum, TaskRepr
12+
13+
14+
@dataclass
15+
class LobsterXApiConfig:
16+
allow_origins: list[str]
17+
file_downloads_per_minute: int | None = None
18+
create_tasks_per_minute: int | None = None
19+
delete_tasks_per_minute: int | None = None
20+
poll_tasks_per_minute: int | None = None
21+
server_api_key: str | None = None
22+
host: str | None = None
23+
port: int | None = None
24+
protocol: Literal["http", "https"] = "http"
25+
26+
@classmethod
27+
def load_from_config(cls, config_file: str) -> "LobsterXApiConfig":
28+
with open(config_file, "r") as f:
29+
config = json.load(f)
30+
return cls(**config)
31+
32+
def to_args(self) -> dict[str, Any]:
33+
return {
34+
"allow_origins": self.allow_origins,
35+
"file_downloads_per_minute": self.file_downloads_per_minute,
36+
"create_tasks_per_minute": self.create_tasks_per_minute,
37+
"delete_tasks_per_minute": self.delete_tasks_per_minute,
38+
"poll_tasks_per_minute": self.poll_tasks_per_minute,
39+
"server_api_key": self.server_api_key,
40+
}
41+
42+
43+
@functools.lru_cache(maxsize=1)
44+
def get_api_key_pattern() -> re.Pattern:
45+
return re.compile(r"[a-zA-Z0-9_-]{32,}")
46+
47+
48+
@functools.lru_cache(maxsize=1)
49+
def get_auth_header_pattern() -> re.Pattern:
50+
return re.compile(r"Bearer\s([a-zA-Z0-9_-]{32,})")
51+
52+
53+
def validate_api_key(api_key: str) -> bool:
54+
pattern = get_api_key_pattern()
55+
return pattern.match(api_key) is not None
56+
57+
58+
@functools.lru_cache(maxsize=1)
59+
def get_server_key_from_env() -> str | None:
60+
load_dotenv(".env")
61+
return os.getenv("LOBSTERX_SERVER_KEY")
62+
63+
64+
class TaskRequest(BaseModel):
65+
prompt: str
66+
67+
68+
class TaskResponse(BaseModel):
69+
task_id: str
70+
71+
72+
class GetTaskResponse(BaseModel):
73+
status: StatusEnum
74+
output: tuple[str, str] | None = None
75+
error: str | None = None
76+
77+
@classmethod
78+
def from_dataclass(cls, task_repr: TaskRepr) -> "GetTaskResponse":
79+
return cls(
80+
status=task_repr.status, output=task_repr.output, error=task_repr.error
81+
)
82+
83+
84+
class UploadFileResponse(BaseModel):
85+
new_file_path: str

0 commit comments

Comments
 (0)