|
1 | | -from collections.abc import Callable |
2 | 1 | from secrets import token_urlsafe |
3 | 2 | from time import time |
4 | | -from typing import Final |
| 3 | +from typing import Any, Final |
5 | 4 |
|
6 | | -from fastapi import Request, Response, status |
| 5 | +from fastapi import Request, status |
7 | 6 | from fastapi.responses import PlainTextResponse |
8 | 7 | from loguru import logger |
| 8 | +from starlette.types import ASGIApp, Message, Receive, Scope, Send |
9 | 9 |
|
10 | 10 | from . import config |
11 | 11 | from . import resources as res |
12 | 12 | from .resources import connection_ctx |
13 | 13 |
|
14 | 14 |
|
15 | | -async def log_request_middleware(request: Request, call_next: Callable) -> Response: |
| 15 | +class LogRequestMiddleware: |
16 | 16 | """ |
17 | 17 | Uniquely identify each request and logs its processing time. |
18 | 18 | """ |
19 | | - start_time = time() |
20 | | - request_id: str = token_urlsafe(config.REQUEST_ID_LENGTH) |
21 | | - exception = None |
22 | 19 |
|
23 | | - # keep the same request_id in the context of all subsequent calls to logger |
24 | | - with logger.contextualize(request_id=request_id): |
25 | | - try: |
26 | | - response = await call_next(request) |
27 | | - except Exception as exc: |
28 | | - exception = exc |
29 | | - response = PlainTextResponse('Internal Server Error', status_code=500) |
30 | | - final_time = time() |
31 | | - elapsed = final_time - start_time |
32 | | - response_length = request.headers.get('content-length', 0) |
33 | | - query_string = request['query_string'].decode() |
34 | | - path_with_qs = request['path'] + ('?' + query_string if query_string else '') |
35 | | - data = { |
36 | | - 'remote_ip': request.headers.get('x-forwarded-for') or request['client'], |
37 | | - 'schema': request.headers.get('x-forwarded-proto') or request['scheme'], |
38 | | - 'protocol': request.get('http_version', 'ws'), |
39 | | - 'method': request.get('method', 'GET'), |
40 | | - 'path_with_query': path_with_qs, |
41 | | - 'status_code': response.status_code, |
42 | | - 'response_length': response_length, |
43 | | - 'elapsed': elapsed, |
44 | | - 'referer': request.headers.get('referer', ''), |
45 | | - 'user_agent': request.headers.get('user-agent', ''), |
46 | | - } |
47 | | - if not exception: |
48 | | - logger.info('log request', **data) |
49 | | - else: |
50 | | - logger.opt(exception=exception).error('Unhandled exception', **data) |
51 | | - response.headers['X-Request-ID'] = request_id |
52 | | - response.headers['X-Processed-Time'] = str(elapsed) |
53 | | - return response |
| 20 | + def __init__(self, app: ASGIApp, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 |
| 21 | + self.app = app |
| 22 | + |
| 23 | + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
| 24 | + if scope['type'] != 'http': |
| 25 | + await self.app(scope, receive, send) |
| 26 | + return |
| 27 | + |
| 28 | + request = Request(scope, receive, send) |
| 29 | + start_time = time() |
| 30 | + request_id: str = token_urlsafe(config.REQUEST_ID_LENGTH) |
| 31 | + response_status = None |
| 32 | + response_length = 0 |
| 33 | + exception = None |
| 34 | + elapsed = 0.0 |
| 35 | + |
| 36 | + async def send_wrapper(message: Message) -> None: |
| 37 | + nonlocal elapsed, response_status, response_length |
| 38 | + |
| 39 | + if message['type'] == 'http.response.start': |
| 40 | + response_status = message['status'] |
| 41 | + headers = message.get('headers', []) |
| 42 | + elapsed = time() - start_time |
| 43 | + # Add our custom headers to the response |
| 44 | + headers.append((b'X-Request-ID', request_id.encode())) |
| 45 | + headers.append((b'X-Processed-Time', str(elapsed).encode())) |
| 46 | + message['headers'] = headers |
| 47 | + elif message['type'] == 'http.response.body': |
| 48 | + response_length += len(message.get('body', b'')) |
| 49 | + await send(message) |
| 50 | + |
| 51 | + # keep the same request_id in the context of all subsequent calls to logger |
| 52 | + with logger.contextualize(request_id=request_id): |
| 53 | + try: |
| 54 | + await self.app(scope, receive, send_wrapper) |
| 55 | + except Exception as exc: |
| 56 | + exception = exc |
| 57 | + # Send error response |
| 58 | + error_response = PlainTextResponse('Internal Server Error', status_code=500) |
| 59 | + await error_response(scope, receive, send) |
| 60 | + |
| 61 | + # Log the request after processing |
| 62 | + query_string = request['query_string'].decode() |
| 63 | + path_with_qs = f'{request["path"]}?{query_string}' if query_string else request['path'] |
| 64 | + data = { |
| 65 | + 'remote_ip': request.headers.get('x-forwarded-for') or request['client'], |
| 66 | + 'schema': request.headers.get('x-forwarded-proto') or request['scheme'], |
| 67 | + 'protocol': request.get('http_version', 'ws'), |
| 68 | + 'method': request.get('method', 'GET'), |
| 69 | + 'path_with_query': path_with_qs, |
| 70 | + 'status_code': response_status or 500, |
| 71 | + 'response_length': response_length, |
| 72 | + 'elapsed': elapsed, |
| 73 | + 'referer': request.headers.get('referer', ''), |
| 74 | + 'user_agent': request.headers.get('user-agent', ''), |
| 75 | + } |
| 76 | + if not exception: |
| 77 | + logger.info('log request', **data) |
| 78 | + else: |
| 79 | + logger.opt(exception=exception).error('Unhandled exception', **data) |
54 | 80 |
|
55 | 81 |
|
56 | 82 | COMMIT: Final[int] = 0 |
57 | 83 | ROLLBACK: Final[int] = 1 |
58 | 84 |
|
59 | | -async def database_connection_middleware(request: Request, call_next: Callable) -> Response: |
60 | | - ''' |
61 | | - Middleware that ensures that the database connection is closed after the request is processed. |
62 | | - ''' |
63 | | - if res.connection_ctx.get(): # the database connection will be managed elsewhere in tests |
64 | | - return await call_next(request) |
65 | | - db_action: int = ROLLBACK |
66 | | - connection = await res.engine.connect() |
67 | | - token = connection_ctx.set(connection) |
68 | | - try: |
69 | | - response = await call_next(request) |
70 | | - if response.status_code < status.HTTP_400_BAD_REQUEST: |
71 | | - db_action = COMMIT |
72 | | - finally: |
73 | | - if db_action == COMMIT: |
74 | | - await connection.commit() |
75 | | - else: |
76 | | - await connection.rollback() |
77 | | - await connection.close() |
78 | | - connection_ctx.reset(token) |
79 | | - return response |
| 85 | + |
| 86 | +class DatabaseConnectionMiddleware: |
| 87 | + """ |
| 88 | + Ensures that the database connection is closed after the request is processed. |
| 89 | + """ |
| 90 | + def __init__(self, app: ASGIApp, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 |
| 91 | + self.app = app |
| 92 | + |
| 93 | + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
| 94 | + if ( |
| 95 | + scope['type'] != 'http' |
| 96 | + or res.connection_ctx.get() # the database connection will be managed in tests |
| 97 | + ): |
| 98 | + await self.app(scope, receive, send) |
| 99 | + return |
| 100 | + |
| 101 | + db_action: int = ROLLBACK |
| 102 | + |
| 103 | + async def send_wrapper(message: Message) -> None: |
| 104 | + nonlocal db_action, send |
| 105 | + |
| 106 | + if message['type'] == 'http.response.start': |
| 107 | + status_code = message['status'] |
| 108 | + if status_code < status.HTTP_400_BAD_REQUEST: |
| 109 | + db_action = COMMIT |
| 110 | + await send(message) |
| 111 | + |
| 112 | + try: |
| 113 | + await self.app(scope, receive, send_wrapper) |
| 114 | + finally: |
| 115 | + if connection := res.connection_ctx.get(): |
| 116 | + if db_action == COMMIT: |
| 117 | + await connection.commit() |
| 118 | + else: |
| 119 | + await connection.rollback() |
| 120 | + await connection.close() |
| 121 | + connection_ctx.set(None) |
0 commit comments