Skip to content

Commit 1ea0bcc

Browse files
committed
redirect handling in middleware
1 parent 90490b9 commit 1ea0bcc

2 files changed

Lines changed: 39 additions & 2 deletions

File tree

src/api/__init__.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,17 @@
44
import re
55
import sys
66
from importlib.resources import files
7-
from typing import Any, Literal
7+
from typing import Any, Literal, MutableMapping
8+
from urllib.parse import urlparse, urlunparse
89

910
from alembic import command
1011
from alembic.config import Config
1112
from fastapi import FastAPI
1213
from fastapi.middleware.cors import CORSMiddleware
1314
from fastapi.routing import APIRoute
1415
from pydantic import BaseModel
16+
from starlette.datastructures import MutableHeaders
17+
from starlette.types import ASGIApp, Receive, Scope, Send
1518

1619
from ..deployment.monitors.health import vm_monitor
1720
from .backup import router as backup_router
@@ -95,6 +98,39 @@ def _logging_config() -> dict[str, Any]:
9598
logger = logging.getLogger(__name__)
9699

97100

101+
class _RootPathRedirectMiddleware:
102+
"""Fix Starlette's automatic redirect responses to include root_path.
103+
104+
Starlette's Router generates trailing-slash redirects (307) using
105+
URL(scope=...) which only reads scope["path"] and ignores
106+
scope["root_path"]. This middleware rewrites the Location header on
107+
redirect responses so the root_path prefix is included.
108+
"""
109+
110+
def __init__(self, app: ASGIApp, root_path: str = "") -> None:
111+
self.app = app
112+
self.root_path = root_path
113+
114+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
115+
if scope["type"] != "http" or not self.root_path:
116+
await self.app(scope, receive, send)
117+
return
118+
119+
root_path = self.root_path
120+
121+
async def send_wrapper(message: MutableMapping[str, Any]) -> None:
122+
if message["type"] == "http.response.start" and 300 <= message.get("status", 200) < 400:
123+
headers = MutableHeaders(scope=message)
124+
location = headers.get("location")
125+
if location:
126+
parsed = urlparse(location)
127+
if parsed.path and not parsed.path.startswith(root_path):
128+
headers["location"] = urlunparse(parsed._replace(path=root_path + parsed.path))
129+
await send(message)
130+
131+
await self.app(scope, receive, send_wrapper)
132+
133+
98134
class _FastAPI(FastAPI):
99135
def openapi(self) -> dict[str, Any]:
100136
if self.openapi_schema:
@@ -220,6 +256,7 @@ def _use_route_names_as_operation_ids(app: FastAPI) -> None:
220256

221257
app = _FastAPI(openapi_tags=_tags, root_path=get_settings().root_path)
222258

259+
app.add_middleware(_RootPathRedirectMiddleware, root_path=get_settings().root_path)
223260
app.add_middleware(
224261
CORSMiddleware,
225262
allow_origins=get_settings().cors_origins,

src/api/_util/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ def wrapper(*args, **kwargs):
3030

3131

3232
def url_path_for(request: Request, name: str, **kwargs) -> str:
33-
return request.scope.get("root_path") + request.app.url_path_for(name, **kwargs)
33+
return request.url_for(name, **kwargs).path

0 commit comments

Comments
 (0)