|
1 | | -import secrets |
2 | | -from typing import Optional |
| 1 | +from dataclasses import dataclass, field |
| 2 | +from functools import lru_cache |
| 3 | +from typing import Any, Dict, Optional |
3 | 4 |
|
4 | | -from fastapi import Header, HTTPException |
| 5 | +import jwt |
| 6 | +from cryptography.hazmat.primitives import serialization |
| 7 | +from fastapi import Header, HTTPException, Request |
| 8 | +from loguru import logger |
5 | 9 |
|
6 | 10 | from app.shared.config import get_settings |
7 | 11 |
|
8 | 12 |
|
9 | | -async def verify_api_key(x_api_key: Optional[str] = Header(None)) -> None: |
10 | | - """Require a matching x-api-key header when an API key is configured. |
| 13 | +@dataclass(frozen=True) |
| 14 | +class AuthContext: |
| 15 | + """Verified identity from a LibreChat code-API JWT.""" |
11 | 16 |
|
12 | | - Auth is disabled when ``API_KEY`` is unset. Settings are read at request |
13 | | - time so the key can be toggled in tests. |
| 17 | + enabled: bool # False when auth is unconfigured (open mode) |
| 18 | + sub: Optional[str] = None # LibreChat user id (trustworthy identity) |
| 19 | + tenant_id: Optional[str] = None |
| 20 | + role: Optional[str] = None |
| 21 | + claims: Dict[str, Any] = field(default_factory=dict) |
| 22 | + |
| 23 | + |
| 24 | +@lru_cache(maxsize=4) |
| 25 | +def _load_public_key(pem: str): |
| 26 | + """Parse a PEM public key; cached on the PEM string so settings swaps in tests just work.""" |
| 27 | + return serialization.load_pem_public_key(pem.encode("utf-8")) |
| 28 | + |
| 29 | + |
| 30 | +def _unauthorized() -> HTTPException: |
| 31 | + return HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Bearer"}) |
| 32 | + |
| 33 | + |
| 34 | +async def verify_jwt(request: Request, authorization: Optional[str] = Header(None)) -> AuthContext: |
| 35 | + """Verify the LibreChat-minted Bearer JWT on the request. |
| 36 | +
|
| 37 | + Auth is disabled when no public key is configured. Settings are read at |
| 38 | + request time so the key can be toggled in tests. The verified claims are |
| 39 | + attached to ``request.state.auth`` for route handlers. |
14 | 40 | """ |
15 | 41 | settings = get_settings() |
16 | | - if settings.API_KEY is None: |
17 | | - return |
18 | | - if x_api_key is None or not secrets.compare_digest(x_api_key, settings.API_KEY): |
19 | | - raise HTTPException(status_code=401, detail="Unauthorized") |
| 42 | + pem = settings.JWT_PUBLIC_KEY_PEM |
| 43 | + if pem is None: |
| 44 | + context = AuthContext(enabled=False) |
| 45 | + request.state.auth = context |
| 46 | + return context |
| 47 | + |
| 48 | + if authorization is None: |
| 49 | + raise _unauthorized() |
| 50 | + scheme, _, token = authorization.partition(" ") |
| 51 | + if scheme.lower() != "bearer" or not token.strip(): |
| 52 | + raise _unauthorized() |
| 53 | + token = token.strip() |
| 54 | + |
| 55 | + try: |
| 56 | + key = _load_public_key(pem) |
| 57 | + except ValueError: |
| 58 | + logger.error("CODEAPI_JWT_PUBLIC_KEY is not a valid PEM public key") |
| 59 | + raise HTTPException(status_code=500, detail="Server authentication misconfigured") |
| 60 | + |
| 61 | + try: |
| 62 | + kid = jwt.get_unverified_header(token).get("kid") |
| 63 | + logger.debug(f"Verifying code-API token with kid={kid}") |
| 64 | + except jwt.InvalidTokenError: |
| 65 | + raise _unauthorized() |
| 66 | + |
| 67 | + try: |
| 68 | + payload = jwt.decode( |
| 69 | + token, |
| 70 | + key, |
| 71 | + algorithms=settings.CODEAPI_JWT_ALGORITHMS, |
| 72 | + audience=settings.CODEAPI_JWT_AUDIENCE, |
| 73 | + issuer=settings.CODEAPI_JWT_ISSUER, |
| 74 | + leeway=settings.CODEAPI_JWT_LEEWAY, |
| 75 | + options={"require": ["exp", "iat", "sub"]}, |
| 76 | + ) |
| 77 | + except (jwt.InvalidTokenError, TypeError, ValueError) as exc: |
| 78 | + # PyJWT raises TypeError/ValueError (not InvalidTokenError) when the |
| 79 | + # token's alg does not match the configured key type |
| 80 | + logger.warning(f"Rejected code-API token: {exc}") |
| 81 | + raise _unauthorized() |
| 82 | + |
| 83 | + context = AuthContext( |
| 84 | + enabled=True, |
| 85 | + sub=payload["sub"], |
| 86 | + tenant_id=payload.get("tenant_id"), |
| 87 | + role=payload.get("role"), |
| 88 | + claims=payload, |
| 89 | + ) |
| 90 | + request.state.auth = context |
| 91 | + return context |
0 commit comments