Skip to content

Commit 1ee054e

Browse files
committed
redirect handling in middleware
1 parent ab64849 commit 1ee054e

2 files changed

Lines changed: 38 additions & 1 deletion

File tree

src/api/__init__.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
import sys
66
from importlib.resources import files
77
from typing import Any, Literal
8+
from urllib.parse import urlparse, urlunparse
89

910
from alembic import command
1011
from alembic.config import Config
1112
from fastapi import FastAPI
1213
from fastapi.middleware.cors import CORSMiddleware
1314
from fastapi.routing import APIRoute
1415
from pydantic import BaseModel
16+
from starlette.datastructures import MutableHeaders
17+
from starlette.types import ASGIApp, Receive, Scope, Send
1518

1619
from ..deployment.monitors.health import vm_monitor
1720
from .backup import router as backup_router
@@ -90,6 +93,39 @@ def _logging_config() -> dict[str, Any]:
9093
logger = logging.getLogger(__name__)
9194

9295

96+
class _RootPathRedirectMiddleware:
97+
"""Fix Starlette's automatic redirect responses to include root_path.
98+
99+
Starlette's Router generates trailing-slash redirects (307) using
100+
URL(scope=...) which only reads scope["path"] and ignores
101+
scope["root_path"]. This middleware rewrites the Location header on
102+
redirect responses so the root_path prefix is included.
103+
"""
104+
105+
def __init__(self, app: ASGIApp, root_path: str = "") -> None:
106+
self.app = app
107+
self.root_path = root_path
108+
109+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
110+
if scope["type"] != "http" or not self.root_path:
111+
await self.app(scope, receive, send)
112+
return
113+
114+
root_path = self.root_path
115+
116+
async def send_wrapper(message: dict) -> None:
117+
if message["type"] == "http.response.start" and 300 <= message.get("status", 200) < 400:
118+
headers = MutableHeaders(scope=message)
119+
location = headers.get("location")
120+
if location:
121+
parsed = urlparse(location)
122+
if parsed.path and not parsed.path.startswith(root_path):
123+
headers["location"] = urlunparse(parsed._replace(path=root_path + parsed.path))
124+
await send(message)
125+
126+
await self.app(scope, receive, send_wrapper)
127+
128+
93129
class _FastAPI(FastAPI):
94130
def openapi(self) -> dict[str, Any]:
95131
if self.openapi_schema:
@@ -215,6 +251,7 @@ def _use_route_names_as_operation_ids(app: FastAPI) -> None:
215251

216252
app = _FastAPI(openapi_tags=_tags, root_path=get_settings().root_path)
217253

254+
app.add_middleware(_RootPathRedirectMiddleware, root_path=get_settings().root_path)
218255
app.add_middleware(
219256
CORSMiddleware,
220257
allow_origins=get_settings().cors_origins,

src/api/_util/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ def wrapper(*args, **kwargs):
3030

3131

3232
def url_path_for(request: Request, name: str, **kwargs) -> str:
33-
return request.scope.get("root_path") + request.app.url_path_for(name, **kwargs)
33+
return request.url_for(name, **kwargs).path

0 commit comments

Comments
 (0)