|
5 | 5 | import sys |
6 | 6 | from importlib.resources import files |
7 | 7 | from typing import Any, Literal |
| 8 | +from urllib.parse import urlparse, urlunparse |
8 | 9 |
|
9 | 10 | from alembic import command |
10 | 11 | from alembic.config import Config |
11 | 12 | from fastapi import FastAPI |
12 | 13 | from fastapi.middleware.cors import CORSMiddleware |
13 | 14 | from fastapi.routing import APIRoute |
14 | 15 | from pydantic import BaseModel |
| 16 | +from starlette.datastructures import MutableHeaders |
| 17 | +from starlette.types import ASGIApp, Receive, Scope, Send |
15 | 18 |
|
16 | 19 | from ..deployment.monitors.health import vm_monitor |
17 | 20 | from .backup import router as backup_router |
@@ -90,6 +93,39 @@ def _logging_config() -> dict[str, Any]: |
90 | 93 | logger = logging.getLogger(__name__) |
91 | 94 |
|
92 | 95 |
|
| 96 | +class _RootPathRedirectMiddleware: |
| 97 | + """Fix Starlette's automatic redirect responses to include root_path. |
| 98 | +
|
| 99 | + Starlette's Router generates trailing-slash redirects (307) using |
| 100 | + URL(scope=...) which only reads scope["path"] and ignores |
| 101 | + scope["root_path"]. This middleware rewrites the Location header on |
| 102 | + redirect responses so the root_path prefix is included. |
| 103 | + """ |
| 104 | + |
| 105 | + def __init__(self, app: ASGIApp, root_path: str = "") -> None: |
| 106 | + self.app = app |
| 107 | + self.root_path = root_path |
| 108 | + |
| 109 | + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
| 110 | + if scope["type"] != "http" or not self.root_path: |
| 111 | + await self.app(scope, receive, send) |
| 112 | + return |
| 113 | + |
| 114 | + root_path = self.root_path |
| 115 | + |
| 116 | + async def send_wrapper(message: dict) -> None: |
| 117 | + if message["type"] == "http.response.start" and 300 <= message.get("status", 200) < 400: |
| 118 | + headers = MutableHeaders(scope=message) |
| 119 | + location = headers.get("location") |
| 120 | + if location: |
| 121 | + parsed = urlparse(location) |
| 122 | + if parsed.path and not parsed.path.startswith(root_path): |
| 123 | + headers["location"] = urlunparse(parsed._replace(path=root_path + parsed.path)) |
| 124 | + await send(message) |
| 125 | + |
| 126 | + await self.app(scope, receive, send_wrapper) |
| 127 | + |
| 128 | + |
93 | 129 | class _FastAPI(FastAPI): |
94 | 130 | def openapi(self) -> dict[str, Any]: |
95 | 131 | if self.openapi_schema: |
@@ -215,6 +251,7 @@ def _use_route_names_as_operation_ids(app: FastAPI) -> None: |
215 | 251 |
|
216 | 252 | app = _FastAPI(openapi_tags=_tags, root_path=get_settings().root_path) |
217 | 253 |
|
| 254 | +app.add_middleware(_RootPathRedirectMiddleware, root_path=get_settings().root_path) |
218 | 255 | app.add_middleware( |
219 | 256 | CORSMiddleware, |
220 | 257 | allow_origins=get_settings().cors_origins, |
|
0 commit comments