Skip to content

Commit 25adbb8

Browse files
committed
fixes #10 |Replace middleware based on BaseHTTPMiddleware class
1 parent 2dca70d commit 25adbb8

File tree

3 files changed

+107
-67
lines changed

3 files changed

+107
-67
lines changed

{{cookiecutter.project_slug}}/tests/test_logging.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@ async def logging_client() -> AsyncIterable[AsyncClient]:
2626
"""
2727
from {{cookiecutter.project_slug}}.logging import init_loguru
2828
from {{cookiecutter.project_slug}}.main import (
29-
BaseHTTPMiddleware,
3029
RequestValidationError,
31-
log_request_middleware,
30+
LogRequestMiddleware,
3231
request_validation_exception_handler,
3332
)
3433

@@ -44,7 +43,7 @@ async def divide(a: int, b: int) -> float:
4443

4544
app = FastAPI()
4645
app.include_router(router)
47-
app.add_middleware(BaseHTTPMiddleware, dispatch=log_request_middleware)
46+
app.add_middleware(LogRequestMiddleware)
4847
# type annotation issue. See: https://github.com/encode/starlette/pull/2403
4948
app.add_exception_handler(RequestValidationError, request_validation_exception_handler) # type: ignore
5049

{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/main.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from fastapi import FastAPI
22
from fastapi.exceptions import RequestValidationError
3-
from starlette.middleware.base import BaseHTTPMiddleware
43

54
from . import config # noqa: F401
65
from .exception_handlers import request_validation_exception_handler
76
from .resources import lifespan
87
from .routers import hello, user
98
from .middleware import (
10-
database_connection_middleware,
11-
log_request_middleware,
9+
DatabaseConnectionMiddleware,
10+
LogRequestMiddleware,
1211
)
1312

1413
app = FastAPI(
@@ -24,7 +23,7 @@
2423
for router in routers:
2524
app.include_router(router)
2625

27-
app.add_middleware(BaseHTTPMiddleware, dispatch=database_connection_middleware)
28-
app.add_middleware(BaseHTTPMiddleware, dispatch=log_request_middleware)
26+
app.add_middleware(DatabaseConnectionMiddleware)
27+
app.add_middleware(LogRequestMiddleware)
2928
# type annotation problem. See: https://github.com/encode/starlette/pull/2403
3029
app.add_exception_handler(RequestValidationError, request_validation_exception_handler) # type: ignore
Lines changed: 101 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,121 @@
1-
from collections.abc import Callable
21
from secrets import token_urlsafe
32
from time import time
4-
from typing import Final
3+
from typing import Any, Final
54

6-
from fastapi import Request, Response, status
5+
from fastapi import Request, status
76
from fastapi.responses import PlainTextResponse
87
from loguru import logger
8+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
99

1010
from . import config
1111
from . import resources as res
1212
from .resources import connection_ctx
1313

1414

15-
async def log_request_middleware(request: Request, call_next: Callable) -> Response:
15+
class LogRequestMiddleware:
1616
"""
1717
Uniquely identify each request and logs its processing time.
1818
"""
19-
start_time = time()
20-
request_id: str = token_urlsafe(config.REQUEST_ID_LENGTH)
21-
exception = None
2219

23-
# keep the same request_id in the context of all subsequent calls to logger
24-
with logger.contextualize(request_id=request_id):
25-
try:
26-
response = await call_next(request)
27-
except Exception as exc:
28-
exception = exc
29-
response = PlainTextResponse('Internal Server Error', status_code=500)
30-
final_time = time()
31-
elapsed = final_time - start_time
32-
response_length = request.headers.get('content-length', 0)
33-
query_string = request['query_string'].decode()
34-
path_with_qs = request['path'] + ('?' + query_string if query_string else '')
35-
data = {
36-
'remote_ip': request.headers.get('x-forwarded-for') or request['client'],
37-
'schema': request.headers.get('x-forwarded-proto') or request['scheme'],
38-
'protocol': request.get('http_version', 'ws'),
39-
'method': request.get('method', 'GET'),
40-
'path_with_query': path_with_qs,
41-
'status_code': response.status_code,
42-
'response_length': response_length,
43-
'elapsed': elapsed,
44-
'referer': request.headers.get('referer', ''),
45-
'user_agent': request.headers.get('user-agent', ''),
46-
}
47-
if not exception:
48-
logger.info('log request', **data)
49-
else:
50-
logger.opt(exception=exception).error('Unhandled exception', **data)
51-
response.headers['X-Request-ID'] = request_id
52-
response.headers['X-Processed-Time'] = str(elapsed)
53-
return response
20+
def __init__(self, app: ASGIApp, *args: Any, **kwargs: Any) -> None: # noqa: ARG002
21+
self.app = app
22+
23+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
24+
if scope['type'] != 'http':
25+
await self.app(scope, receive, send)
26+
return
27+
28+
request = Request(scope, receive, send)
29+
start_time = time()
30+
request_id: str = token_urlsafe(config.REQUEST_ID_LENGTH)
31+
response_status = None
32+
response_length = 0
33+
exception = None
34+
elapsed = 0.0
35+
36+
async def send_wrapper(message: Message) -> None:
37+
nonlocal elapsed, response_status, response_length
38+
39+
if message['type'] == 'http.response.start':
40+
response_status = message['status']
41+
headers = message.get('headers', [])
42+
elapsed = time() - start_time
43+
# Add our custom headers to the response
44+
headers.append((b'X-Request-ID', request_id.encode()))
45+
headers.append((b'X-Processed-Time', str(elapsed).encode()))
46+
message['headers'] = headers
47+
elif message['type'] == 'http.response.body':
48+
response_length += len(message.get('body', b''))
49+
await send(message)
50+
51+
# keep the same request_id in the context of all subsequent calls to logger
52+
with logger.contextualize(request_id=request_id):
53+
try:
54+
await self.app(scope, receive, send_wrapper)
55+
except Exception as exc:
56+
exception = exc
57+
# Send error response
58+
error_response = PlainTextResponse('Internal Server Error', status_code=500)
59+
await error_response(scope, receive, send)
60+
61+
# Log the request after processing
62+
query_string = request['query_string'].decode()
63+
path_with_qs = f'{request["path"]}?{query_string}' if query_string else request['path']
64+
data = {
65+
'remote_ip': request.headers.get('x-forwarded-for') or request['client'],
66+
'schema': request.headers.get('x-forwarded-proto') or request['scheme'],
67+
'protocol': request.get('http_version', 'ws'),
68+
'method': request.get('method', 'GET'),
69+
'path_with_query': path_with_qs,
70+
'status_code': response_status or 500,
71+
'response_length': response_length,
72+
'elapsed': elapsed,
73+
'referer': request.headers.get('referer', ''),
74+
'user_agent': request.headers.get('user-agent', ''),
75+
}
76+
if not exception:
77+
logger.info('log request', **data)
78+
else:
79+
logger.opt(exception=exception).error('Unhandled exception', **data)
5480

5581

5682
COMMIT: Final[int] = 0
5783
ROLLBACK: Final[int] = 1
5884

59-
async def database_connection_middleware(request: Request, call_next: Callable) -> Response:
60-
'''
61-
Middleware that ensures that the database connection is closed after the request is processed.
62-
'''
63-
if res.connection_ctx.get(): # the database connection will be managed elsewhere in tests
64-
return await call_next(request)
65-
db_action: int = ROLLBACK
66-
connection = await res.engine.connect()
67-
token = connection_ctx.set(connection)
68-
try:
69-
response = await call_next(request)
70-
if response.status_code < status.HTTP_400_BAD_REQUEST:
71-
db_action = COMMIT
72-
finally:
73-
if db_action == COMMIT:
74-
await connection.commit()
75-
else:
76-
await connection.rollback()
77-
await connection.close()
78-
connection_ctx.reset(token)
79-
return response
85+
86+
class DatabaseConnectionMiddleware:
87+
"""
88+
Ensures that the database connection is closed after the request is processed.
89+
"""
90+
def __init__(self, app: ASGIApp, *args: Any, **kwargs: Any) -> None: # noqa: ARG002
91+
self.app = app
92+
93+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
94+
if (
95+
scope['type'] != 'http'
96+
or res.connection_ctx.get() # the database connection will be managed in tests
97+
):
98+
await self.app(scope, receive, send)
99+
return
100+
101+
db_action: int = ROLLBACK
102+
103+
async def send_wrapper(message: Message) -> None:
104+
nonlocal db_action, send
105+
106+
if message['type'] == 'http.response.start':
107+
status_code = message['status']
108+
if status_code < status.HTTP_400_BAD_REQUEST:
109+
db_action = COMMIT
110+
await send(message)
111+
112+
try:
113+
await self.app(scope, receive, send_wrapper)
114+
finally:
115+
if connection := res.connection_ctx.get():
116+
if db_action == COMMIT:
117+
await connection.commit()
118+
else:
119+
await connection.rollback()
120+
await connection.close()
121+
connection_ctx.set(None)

0 commit comments

Comments
 (0)