forked from MemTensor/MemOS
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrequest_context.py
More file actions
102 lines (76 loc) · 3.15 KB
/
request_context.py
File metadata and controls
102 lines (76 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Request context middleware for automatic trace_id injection.
"""
import time
from collections.abc import Callable
from fastapi.responses import StreamingResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
import memos.log
from memos.context.context import RequestContext, generate_trace_id, set_request_context
async def _tee_stream(
original: StreamingResponse,
) -> StreamingResponse:
chunks = []
async for chunk in original.body_iterator:
chunks.append(chunk)
yield chunk
body_str = "".join(chunks).decode("utf-8", errors="replace")
logger.info(f"Response content: {body_str}")
logger = memos.log.get_logger(__name__)
def extract_trace_id_from_headers(request: Request) -> str | None:
"""Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
for header in ["g-trace-id", "x-trace-id", "trace-id"]:
if trace_id := request.headers.get(header):
return trace_id
return None
class RequestContextMiddleware(BaseHTTPMiddleware):
"""
Middleware to automatically inject request context for every HTTP request.
This middleware:
1. Extracts trace_id from headers or generates a new one
2. Creates a RequestContext and sets it globally
3. Ensures the context is available throughout the request lifecycle
"""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Extract or generate trace_id
trace_id = extract_trace_id_from_headers(request) or generate_trace_id()
env = request.headers.get("x-env")
user_type = request.headers.get("x-user-type")
user_name = request.headers.get("x-user-name")
start_time = time.time()
# Create and set request context
context = RequestContext(
trace_id=trace_id,
api_path=request.url.path,
env=env,
user_type=user_type,
user_name=user_name,
)
set_request_context(context)
# Log request start with parameters
params_log = {}
# Get query parameters
if request.query_params:
params_log["query_params"] = dict(request.query_params)
logger.info(f"Request started, params: {params_log}, headers: {request.headers}")
# Process the request
try:
response = await call_next(request)
end_time = time.time()
if response.status_code == 200:
logger.info(
f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
)
else:
logger.error(
f"Request Failed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
)
except Exception as e:
end_time = time.time()
logger.error(
f"Request Exception Error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms"
)
raise e
return response