Skip to content

Commit 3cfce41

Browse files
committed
Addressed Union type review
1 parent 67b1d75 commit 3cfce41

2 files changed

Lines changed: 14 additions & 2 deletions

File tree

fastapi_gcp_tasks/requester.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _body(self, *, values: Dict[str, Any]) -> bytes | None:
9090
got_body = body_field.get_default()
9191
body_type = body_field.field_info.annotation
9292
check_type = get_origin(body_type) or body_type
93-
if body_type is not None and check_type is not None and not isinstance(got_body, check_type):
93+
if body_type is not None and isinstance(check_type, type) and not isinstance(got_body, check_type):
9494
raise WrongTypeError(field=body_field.name, type=body_type)
9595
body = json.dumps(jsonable_encoder(got_body)).encode()
9696
return body

tests/test_requester.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Unit tests for Requester._body covering missing-param and generic-type bugs."""
22

33
# Standard Library Imports
4-
from typing import List
4+
from typing import List, Union
55

66
# Third Party Imports
77
import pytest
@@ -38,6 +38,11 @@ async def list_body_endpoint(items: List[Item]) -> None:
3838
"""Endpoint with a parameterized generic body."""
3939

4040

41+
@app.post("/union_body")
42+
async def union_body_endpoint(item: Union[Item, str] = "fallback") -> None:
43+
"""Endpoint with a Union-typed body."""
44+
45+
4146
def _get_route(path: str) -> APIRoute:
4247
for route in app.routes:
4348
if isinstance(route, APIRoute) and route.path == path:
@@ -88,3 +93,10 @@ def test_simple_body_wrong_type_raises(self) -> None:
8893
requester = Requester(route=route, base_url="http://localhost")
8994
with pytest.raises(WrongTypeError):
9095
requester._body(values={"item": "not an Item"})
96+
97+
def test_union_body_does_not_crash(self) -> None:
98+
"""Union-typed body should not raise TypeError on isinstance check."""
99+
route = _get_route("/union_body")
100+
requester = Requester(route=route, base_url="http://localhost")
101+
body = requester._body(values={"item": Item(name="test")})
102+
assert body is not None

0 commit comments

Comments
 (0)