Skip to content

Commit 4a50e32

Browse files
committed
Handle ValueError safely on websockets to avoid ASGI crash
1 parent 20caa11 commit 4a50e32

2 files changed

Lines changed: 24 additions & 6 deletions

File tree

src/api/error_handling.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from http import HTTPStatus
55
from uuid import uuid4
66

7-
from fastapi import FastAPI, HTTPException, Request
7+
from fastapi import FastAPI, HTTPException, Request, WebSocket
88
from fastapi.exceptions import RequestValidationError
99
from fastapi.responses import JSONResponse
1010
from sqlalchemy.exc import OperationalError
@@ -14,7 +14,7 @@
1414
logger = logging.getLogger("api.errors")
1515

1616

17-
def _resolve_request_id(request: Request) -> str:
17+
def _resolve_request_id(request: Request | WebSocket) -> str:
1818
request_id = getattr(request.state, "request_id", None)
1919
if isinstance(request_id, str) and request_id:
2020
return request_id
@@ -128,17 +128,23 @@ async def validation_exception_handler(
128128
return JSONResponse(status_code=422, content=body)
129129

130130
@app.exception_handler(ValueError)
131-
async def value_error_handler(request: Request, exc: ValueError) -> JSONResponse:
131+
async def value_error_handler(request: Request | WebSocket, exc: ValueError) -> JSONResponse | None:
132132
request_id = _resolve_request_id(request)
133+
path = request.url.path if request.url is not None else "<unknown>"
134+
method = request.method if isinstance(request, Request) else "WEBSOCKET"
133135
logger.warning(
134136
"value_error",
135137
extra={
136138
"request_id": request_id,
137-
"method": request.method,
138-
"path": request.url.path,
139+
"method": method,
140+
"path": path,
139141
"error_detail": str(exc),
140142
},
141143
)
144+
if isinstance(request, WebSocket):
145+
reason = str(exc) or "Invalid value"
146+
await request.close(code=1008, reason=reason[:120])
147+
return None
142148
body = _build_error_body(
143149
status_code=422,
144150
request_id=request_id,

tests/test_api_error_handling.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import unittest
55
from pathlib import Path
66

7-
from fastapi import FastAPI, HTTPException
7+
from fastapi import FastAPI, HTTPException, WebSocket
88
from fastapi.testclient import TestClient
99
from pydantic import BaseModel
1010
from sqlalchemy.exc import DBAPIError, OperationalError
11+
from starlette.websockets import WebSocketDisconnect
1112

1213
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
1314

@@ -46,6 +47,11 @@ def crash() -> None:
4647
def bad_value() -> None:
4748
raise ValueError("legacy game row has invalid enum")
4849

50+
@app.websocket("/ws-bad-value")
51+
async def ws_bad_value(websocket: WebSocket) -> None:
52+
await websocket.accept()
53+
raise ValueError("invalid ws payload")
54+
4955
return app
5056

5157

@@ -123,6 +129,12 @@ def test_value_error_maps_to_422_standard_shape(self) -> None:
123129
self.assertEqual(payload["detail"], "legacy game row has invalid enum")
124130
self.assertIn("request_id", payload)
125131

132+
def test_value_error_on_websocket_closes_with_policy_violation(self) -> None:
133+
client = TestClient(build_test_app(), raise_server_exceptions=False)
134+
with self.assertRaises(WebSocketDisconnect) as ctx, client.websocket_connect("/ws-bad-value") as websocket:
135+
websocket.receive_json()
136+
self.assertEqual(ctx.exception.code, 1008)
137+
126138

127139
if __name__ == "__main__":
128140
unittest.main()

0 commit comments

Comments
 (0)