-
Notifications
You must be signed in to change notification settings - Fork 681
Expand file tree
/
Copy pathresponse_middleware.py
More file actions
108 lines (92 loc) · 3.97 KB
/
response_middleware.py
File metadata and controls
108 lines (92 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import json
from redis import typing
from starlette.exceptions import HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from common.core.config import settings
from common.utils.utils import SQLBotLogUtil
class ResponseMiddleware(BaseHTTPMiddleware):
    """Post-process outgoing responses.

    * A ``200`` response whose ``content-type`` is exactly
      ``application/json`` is re-emitted as a ``JSONResponse``. If the decoded
      body is already a dict containing the keys ``code``, ``data`` and
      ``msg`` it is passed through unchanged; otherwise it is wrapped as
      ``{"code": 0, "data": <body>, "msg": None}``.
    * Responses whose content type mentions HTML, JavaScript, TypeScript or
      CSS get a ``Content-Security-Policy: frame-ancestors ...`` header built
      from ``self.allow_origins``.
    * ``JSONResponse`` instances, non-200 responses and the paths listed in
      ``direct_paths`` are returned untouched.
    """

    # Headers that ``JSONResponse`` recomputes for the new body; they must not
    # be copied over from the original response.
    _RECOMPUTED_HEADERS = ("content-length", "content-type")

    def __init__(self, app):
        """Store the allowed ``frame-ancestors`` sources and wrap ``app``."""
        self.allow_origins = ["'self'"]
        super().__init__(app)

    @classmethod
    def _rebuild_json(cls, source, content, status_code):
        """Build a ``JSONResponse`` carrying ``content`` plus ``source``'s headers.

        Headers are appended one by one instead of being passed as a dict so
        that repeated header names (for example several ``Set-Cookie``
        entries) are all preserved rather than collapsed into the last one.
        """
        rebuilt = JSONResponse(content=content, status_code=status_code)
        for name, value in source.headers.items():
            if name.lower() not in cls._RECOMPUTED_HEADERS:
                rebuilt.headers.append(name, value)
        return rebuilt

    async def dispatch(self, request: Request, call_next):
        """Run the downstream app and normalise its response.

        Args:
            request: The incoming request.
            call_next: Callable that forwards ``request`` down the stack and
                returns the response.

        Returns:
            Either the original response (possibly with a CSP header added)
            or a rebuilt ``JSONResponse``. If reading or re-encoding the JSON
            body fails, a ``500`` ``JSONResponse`` whose content is the
            error text is returned.
        """
        response = await call_next(request)

        direct_paths = (
            f"{settings.API_V1_STR}/mcp/mcp_question",
            f"{settings.API_V1_STR}/mcp/mcp_assistant",
            "/openapi.json",
            "/docs",
            "/redoc",
        )

        route = request.scope.get("route")
        # Declared path pattern of the matched route, e.g. '/items/{item_id}'.
        path_pattern = route.path_format if route else ''

        # NOTE(review): responses produced by ``call_next`` may be a streaming
        # wrapper rather than the handler's own object, in which case the
        # isinstance check would rarely match - confirm.
        if (isinstance(response, JSONResponse)
                or request.url.path == f"{settings.API_V1_STR}/openapi.json"
                or path_pattern in direct_paths):
            return response

        if response.status_code != 200:
            return response

        # NOTE(review): exact comparison; a value such as
        # "application/json; charset=utf-8" is not wrapped - confirm intent.
        if response.headers.get("content-type") == "application/json":
            try:
                # Drain the body stream once; join avoids repeated bytes copies.
                body = b"".join([chunk async for chunk in response.body_iterator])
                raw_data = json.loads(body.decode())

                already_wrapped = (
                    isinstance(raw_data, dict)
                    and all(key in raw_data for key in ("code", "data", "msg"))
                )
                if already_wrapped:
                    payload = raw_data
                else:
                    payload = {
                        "code": 0,
                        "data": raw_data,
                        "msg": None,
                    }
                return self._rebuild_json(response, payload, response.status_code)
            except Exception as e:
                # The original body stream has been consumed at this point, so
                # the original response cannot be replayed; report the failure.
                SQLBotLogUtil.error(f"Response processing error: {str(e)}", exc_info=True)
                return self._rebuild_json(response, str(e), 500)

        content_type = response.headers.get("content-type", "")
        static_content_types = ["text/html", "javascript", "typescript", "css"]
        if any(ct in content_type for ct in static_content_types):
            if self.allow_origins:
                frame_ancestors_value = " ".join(self.allow_origins)
                response.headers["Content-Security-Policy"] = f"frame-ancestors {frame_ancestors_value};"

        return response
class exception_handler():
    """Namespace for exception handlers that answer with JSON.

    Both handlers log the error through ``SQLBotLogUtil`` and return a
    ``JSONResponse`` carrying ``Access-Control-Allow-Origin: *``.
    """

    @staticmethod
    async def http_exception_handler(request: Request, exc: HTTPException):
        """Turn an ``HTTPException`` into a JSON response.

        Args:
            request: The request being handled (unused).
            exc: The raised exception; its ``status_code`` and ``detail``
                become the response status and content.

        Returns:
            A ``JSONResponse`` with the exception's status code and detail.
            Any headers attached to the exception are forwarded, and
            ``Access-Control-Allow-Origin`` is always set to ``*``.
        """
        SQLBotLogUtil.error(f"HTTP Exception: {exc.detail}", exc_info=True)

        cors_header = "Access-Control-Allow-Origin"
        # Forward headers carried by the exception (e.g. WWW-Authenticate),
        # skipping any spelling of the CORS header so it is emitted only once.
        headers = {
            name: value
            for name, value in (getattr(exc, "headers", None) or {}).items()
            if name.lower() != cors_header.lower()
        }
        headers[cors_header] = "*"

        return JSONResponse(
            status_code=exc.status_code,
            content=exc.detail,
            headers=headers,
        )

    @staticmethod
    async def global_exception_handler(request: Request, exc: Exception):
        """Turn any unhandled exception into a ``500`` JSON response.

        Args:
            request: The request being handled (unused).
            exc: The unhandled exception; its string form is the content.

        Returns:
            A ``JSONResponse`` with status ``500`` and
            ``Access-Control-Allow-Origin: *``.
        """
        SQLBotLogUtil.error(f"Unhandled Exception: {str(exc)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content=str(exc),
            headers={"Access-Control-Allow-Origin": "*"},
        )