Skip to content

Commit f9bc84f

Browse files
committed
redirect handling in middleware
1 parent 5e2ebd7 commit f9bc84f

2 files changed

Lines changed: 39 additions & 1 deletion

File tree

src/api/__init__.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
import logging.config
44
import re
55
import sys
6+
from collections.abc import MutableMapping
67
from importlib.resources import files
78
from typing import Any, Literal
9+
from urllib.parse import urlparse, urlunparse
810

911
from alembic import command
1012
from alembic.config import Config
1113
from fastapi import FastAPI
1214
from fastapi.middleware.cors import CORSMiddleware
1315
from fastapi.routing import APIRoute
1416
from pydantic import BaseModel
17+
from starlette.datastructures import MutableHeaders
18+
from starlette.types import ASGIApp, Receive, Scope, Send
1519

1620
from ..deployment.monitors.health import vm_monitor
1721
from .backup import router as backup_router
@@ -95,6 +99,39 @@ def _logging_config() -> dict[str, Any]:
9599
logger = logging.getLogger(__name__)
96100

97101

102+
class _RootPathRedirectMiddleware:
103+
"""Fix Starlette's automatic redirect responses to include root_path.
104+
105+
Starlette's Router generates trailing-slash redirects (307) using
106+
URL(scope=...) which only reads scope["path"] and ignores
107+
scope["root_path"]. This middleware rewrites the Location header on
108+
redirect responses so the root_path prefix is included.
109+
"""
110+
111+
def __init__(self, app: ASGIApp, root_path: str = "") -> None:
112+
self.app = app
113+
self.root_path = root_path
114+
115+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
116+
if scope["type"] != "http" or not self.root_path:
117+
await self.app(scope, receive, send)
118+
return
119+
120+
root_path = self.root_path
121+
122+
async def send_wrapper(message: MutableMapping[str, Any]) -> None:
123+
if message["type"] == "http.response.start" and 300 <= message.get("status", 200) < 400:
124+
headers = MutableHeaders(scope=message)
125+
location = headers.get("location")
126+
if location:
127+
parsed = urlparse(location)
128+
if parsed.path and not parsed.path.startswith(root_path):
129+
headers["location"] = urlunparse(parsed._replace(path=root_path + parsed.path))
130+
await send(message)
131+
132+
await self.app(scope, receive, send_wrapper)
133+
134+
98135
class _FastAPI(FastAPI):
99136
def openapi(self) -> dict[str, Any]:
100137
if self.openapi_schema:
@@ -220,6 +257,7 @@ def _use_route_names_as_operation_ids(app: FastAPI) -> None:
220257

221258
app = _FastAPI(openapi_tags=_tags, root_path=get_settings().root_path)
222259

260+
app.add_middleware(_RootPathRedirectMiddleware, root_path=get_settings().root_path)
223261
app.add_middleware(
224262
CORSMiddleware,
225263
allow_origins=get_settings().cors_origins,

src/api/_util/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ def wrapper(*args, **kwargs):
3030

3131

3232
def url_path_for(request: Request, name: str, **kwargs) -> str:
33-
return request.scope.get("root_path") + request.app.url_path_for(name, **kwargs)
33+
return request.url_for(name, **kwargs).path

0 commit comments

Comments
 (0)