Skip to content

Commit af4c8c1

Browse files
committed
add support for missing marker for bytes field in TypedDict
1 parent fe8c810 commit af4c8c1

3 files changed

Lines changed: 45 additions & 3 deletions

File tree

src/py_avro_schema/_schemas.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,17 @@ class TDMissingMarker(str):
8787
...
8888

8989

90+
class BytesTDMissingMarker(bytes):
    """
    Bytes-based counterpart of `TDMissingMarker`.

    Needed for TypedDict fields whose non-required value is `bytes`: Avro is unable to manage Unions of `string` and
    `bytes`, so the marker added to such fields has to be a `bytes` subtype rather than a `str` one.
    """
97+
98+
9099
# String sentinel used as the default of a field to flag a TypedDict key that is absent (as opposed to set to None).
TD_MISSING_MARKER = TDMissingMarker("__td_missing__")
100+
# Bytes sentinel with the same payload, for TypedDict fields typed as `bytes` (Avro cannot mix `string` and `bytes`).
BYTES_TD_MISSING_MARKER = BytesTDMissingMarker(b"__td_missing__")
91101

92102

93103
class TypeNotSupportedError(TypeError):
@@ -885,6 +895,9 @@ def _validate_union(args: tuple[Any, ...]) -> None:
885895
:return: None
886896
:raises: TypeError if the Union types are invalid
887897
"""
898+
if str in args and bytes in args:
899+
raise TypeError("Avro does not support Union of types bytes and string")
900+
888901
if type(None) not in args and TDMissingMarker not in args:
889902
if any(
890903
# Enum is treated as a Sequence
@@ -1146,7 +1159,8 @@ def __init__(
11461159
if self.default != dataclasses.MISSING:
11471160
if isinstance(self.schema, UnionSchema):
11481161
self.schema.sort_item_schemas(self.default)
1149-
typeguard.check_type("default_value", self.default, self.py_type)
1162+
if self.default != TD_MISSING_MARKER:
1163+
typeguard.check_type("default_value", self.default, self.py_type)
11501164
else:
11511165
if Option.DEFAULTS_MANDATORY in self.options:
11521166
raise TypeError(f"Default value for field {self} is missing")
@@ -1387,21 +1401,25 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
13871401
"""Return an Avro record field object for a given TypedDict field"""
13881402
aliases, actual_type = get_field_aliases_and_actual_type(py_field[1])
13891403

1404+
# Avro does not handle Unions of bytes and string
1405+
marker_type = BytesTDMissingMarker if _is_bytes(actual_type) else TDMissingMarker
1406+
13901407
default = dataclasses.MISSING
13911408
if Option.MARK_NON_TOTAL_TYPED_DICTS in self.options and not self.is_total:
13921409
# If a TypedDict is marked as total=False, it does not need to contain all the fields. However, we need to
13931410
# be able to distinguish between the fields that are missing from the ones that are present but set to None.
13941411
# To do that, we extend the original type with str. We will later add a special string
13951412
# (e.g., __td_missing__) as a marker at deserialization time.
1396-
actual_type = Union[actual_type, TDMissingMarker] # type: ignore
1413+
actual_type = Union[actual_type, marker_type] # type: ignore
13971414
if _is_optional(actual_type):
13981415
# Note: this works since this schema does not implement `make_default` and the base implementation
13991416
# simply returns the provided type (None in this case).
1417+
# We need to use the string TD_MISSING_MARKER as the schema cannot serialize bytes
14001418
default = TD_MISSING_MARKER # type: ignore
14011419
elif _is_not_required(actual_type):
14021420
# A field can be marked with typing.NotRequired even in a TypedDict which is not marked with total=False.
14031421
# Similarly as above, we extend the wrapped type with string.
1404-
actual_type = Union[_unwrap_not_required(actual_type), TDMissingMarker] # type: ignore
1422+
actual_type = Union[_unwrap_not_required(actual_type), marker_type] # type: ignore
14051423

14061424
field_obj = RecordField(
14071425
py_type=actual_type,
@@ -1433,6 +1451,14 @@ def _is_optional(py_type: Type) -> bool:
14331451
return False
14341452

14351453

1454+
def _is_bytes(py_type: Type) -> bool:
1455+
"""Given a Union of types, checks if bytes is one of those"""
1456+
try:
1457+
return py_type is bytes or bytes in get_args(py_type)
1458+
except Exception:
1459+
return False
1460+
1461+
14361462
def _is_not_required(py_type: Type) -> bool:
14371463
"""Checks if a type is marked with typing.NotRequired"""
14381464
return get_origin(py_type) is NotRequired # noqa

tests/test_primitives.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,15 @@ def test_literal_different_types():
430430
py_avro_schema._schemas.schema(py_type)
431431

432432

433+
def test_union_bytes_string():
    """A Union mixing ``str`` and ``bytes`` must be rejected with a TypeError."""
    expected_message = re.escape("Avro does not support Union of types bytes and string")
    with pytest.raises(TypeError, match=expected_message):
        py_avro_schema._schemas.schema(Union[str, bytes])
440+
441+
433442
def test_optional_str():
434443
py_type = Optional[str]
435444
expected = ["string", "null"]

tests/test_typed_dict.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ class PyType(TypedDict, total=False):
116116
age: int | None
117117
invalid: InvalidEnumSymbol | None
118118
valid: ValidEnumSymbol | None
119+
bytes_data: bytes
120+
bytes_data_nullable: bytes | None
119121

120122
expected = {
121123
"fields": [
@@ -144,6 +146,11 @@ class PyType(TypedDict, total=False):
144146
"null",
145147
],
146148
},
149+
{
150+
"name": "bytes_data",
151+
"type": "bytes",
152+
},
153+
{"default": "__td_missing__", "name": "bytes_data_nullable", "type": ["bytes", "null"]},
147154
],
148155
"name": "PyType",
149156
"type": "record",

0 commit comments

Comments
 (0)