@@ -87,7 +87,17 @@ class TDMissingMarker(str):
8787 ...
8888
8989
class BytesTDMissingMarker(bytes):
    """
    Bytes-based counterpart of `TDMissingMarker`.

    Needed for TypedDict fields whose non-required value type is `bytes`, because Avro is unable to manage Unions of
    `string` and `bytes`.
    """
97+
98+
# Marker values used to flag a TypedDict field as missing (as opposed to present but set to None).
# The bytes variant is the counterpart for fields whose type includes `bytes`.
9099TD_MISSING_MARKER = TDMissingMarker ("__td_missing__" )
100+ BYTES_TD_MISSING_MARKER = BytesTDMissingMarker (b"__td_missing__" )
91101
92102
93103class TypeNotSupportedError (TypeError ):
@@ -885,6 +895,9 @@ def _validate_union(args: tuple[Any, ...]) -> None:
885895 :return: None
886896 :raises: TypeError if the Union types are invalid
887897 """
898+ if str in args and bytes in args :
899+ raise TypeError ("Avro does not support Union of types bytes and string" )
900+
888901 if type (None ) not in args and TDMissingMarker not in args :
889902 if any (
890903 # Enum is treated as a Sequence
@@ -1146,7 +1159,8 @@ def __init__(
11461159 if self .default != dataclasses .MISSING :
11471160 if isinstance (self .schema , UnionSchema ):
11481161 self .schema .sort_item_schemas (self .default )
1149- typeguard .check_type ("default_value" , self .default , self .py_type )
1162+ if self .default != TD_MISSING_MARKER :
1163+ typeguard .check_type ("default_value" , self .default , self .py_type )
11501164 else :
11511165 if Option .DEFAULTS_MANDATORY in self .options :
11521166 raise TypeError (f"Default value for field { self } is missing" )
@@ -1387,21 +1401,25 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
13871401 """Return an Avro record field object for a given TypedDict field"""
13881402 aliases , actual_type = get_field_aliases_and_actual_type (py_field [1 ])
13891403
1404+ # Avro does not handle Unions of bytes and string
1405+ marker_type = BytesTDMissingMarker if _is_bytes (actual_type ) else TDMissingMarker
1406+
13901407 default = dataclasses .MISSING
13911408 if Option .MARK_NON_TOTAL_TYPED_DICTS in self .options and not self .is_total :
13921409 # If a TypedDict is marked as total=False, it does not need to contain all the fields. However, we need to
13931410 # be able to distinguish between the fields that are missing from the ones that are present but set to None.
13941411 # To do that, we extend the original type with str. We will later add a special string
13951412 # (e.g., __td_missing__) as a marker at deserialization time.
1396- actual_type = Union [actual_type , TDMissingMarker ] # type: ignore
1413+ actual_type = Union [actual_type , marker_type ] # type: ignore
13971414 if _is_optional (actual_type ):
13981415 # Note: this works since this schema does not implement `make_default` and the base implementation
13991416 # simply return the provided type (None in this case).
1417+ # We need to use the string TD_MISSING_MARKER as the schema cannot serialize bytes
14001418 default = TD_MISSING_MARKER # type: ignore
14011419 elif _is_not_required (actual_type ):
14021420 # A field can be marked with typing.NotRequired even in a TypedDict which is not marked with total=False.
14031421 # Similarly as above, we extend the wrapped type with string.
1404- actual_type = Union [_unwrap_not_required (actual_type ), TDMissingMarker ] # type: ignore
1422+ actual_type = Union [_unwrap_not_required (actual_type ), marker_type ] # type: ignore
14051423
14061424 field_obj = RecordField (
14071425 py_type = actual_type ,
@@ -1433,6 +1451,14 @@ def _is_optional(py_type: Type) -> bool:
14331451 return False
14341452
14351453
1454+ def _is_bytes (py_type : Type ) -> bool :
1455+ """Given a Union of types, checks if bytes is one of those"""
1456+ try :
1457+ return py_type is bytes or bytes in get_args (py_type )
1458+ except Exception :
1459+ return False
1460+
1461+
14361462def _is_not_required (py_type : Type ) -> bool :
14371463 """Checks if a type is marked with typing.NotRequired"""
14381464 return get_origin (py_type ) is NotRequired # noqa
0 commit comments