|
4 | 4 |
|
5 | 5 | from fastapi import FastAPI, HTTPException, Request |
6 | 6 | from fastapi.exception_handlers import http_exception_handler |
| 7 | +from fastapi.responses import JSONResponse |
7 | 8 | from fastapi.routing import APIRouter |
8 | 9 | from loguru import logger |
9 | 10 |
|
|
29 | 30 | from basic_memory.config import init_api_logging |
30 | 31 | from basic_memory.services.exceptions import EntityAlreadyExistsError |
31 | 32 | from basic_memory.services.initialization import initialize_app |
| 33 | +from basic_memory.workspace_context import ( |
| 34 | + WORKSPACE_SLUG_HEADER, |
| 35 | + WORKSPACE_TYPE_HEADER, |
| 36 | + workspace_permalink_context_validation_error, |
| 37 | + workspace_permalink_context, |
| 38 | +) |
32 | 39 |
|
33 | 40 |
|
34 | 41 | @asynccontextmanager |
@@ -87,6 +94,32 @@ async def lifespan(app: FastAPI): # pragma: no cover |
87 | 94 | lifespan=lifespan, |
88 | 95 | ) |
89 | 96 |
|
| 97 | + |
| 98 | +@app.middleware("http") |
| 99 | +async def workspace_permalink_context_middleware(request: Request, call_next): |
| 100 | + """Populate workspace permalink context from request headers.""" |
| 101 | + workspace_slug = request.headers.get(WORKSPACE_SLUG_HEADER) |
| 102 | + workspace_type = request.headers.get(WORKSPACE_TYPE_HEADER) |
| 103 | + |
| 104 | + validation_error = workspace_permalink_context_validation_error(workspace_slug, workspace_type) |
| 105 | + if validation_error is not None: |
| 106 | + return JSONResponse( |
| 107 | + status_code=400, |
| 108 | + content={"detail": validation_error}, |
| 109 | + ) |
| 110 | + |
| 111 | + if not workspace_slug: |
| 112 | + return await call_next(request) |
| 113 | + |
| 114 | + # ContextVar state remains active across the awaited downstream handler while |
| 115 | + # this context manager is open, so entity creation can see request metadata. |
| 116 | + with workspace_permalink_context( |
| 117 | + workspace_slug=workspace_slug, |
| 118 | + workspace_type=workspace_type, |
| 119 | + ): |
| 120 | + return await call_next(request) |
| 121 | + |
| 122 | + |
90 | 123 | # Include v2 routers FIRST (more specific paths must match before /{project} catch-all) |
91 | 124 | app.include_router(v2_knowledge, prefix="/v2/projects/{project_id}") |
92 | 125 | app.include_router(v2_memory, prefix="/v2/projects/{project_id}") |
@@ -146,4 +179,7 @@ async def exception_handler(request, exc): # pragma: no cover |
146 | 179 | error_type=type(exc).__name__, |
147 | 180 | error=str(exc), |
148 | 181 | ) |
149 | | - return await http_exception_handler(request, HTTPException(status_code=500, detail=str(exc))) |
| 182 | + return await http_exception_handler( |
| 183 | + request, |
| 184 | + HTTPException(status_code=500, detail="Internal server error"), |
| 185 | + ) |
0 commit comments