diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index 8ecd499..eee82bb 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -131,6 +131,12 @@ class Option(enum.Flag): #: See https://docs.pydantic.dev/dev/api/fields/#pydantic.fields.Field USE_FIELD_ALIAS = enum.auto() + #: A TypedDict marked with ``total=False`` is still valid when a field is missing. When one of the fields is also + #: optional, we need a way to distinguish between a `None` value and a field that is not set. With this option, + #: the type of each field is extended with `string`. This way, clients can add markers (e.g., `__td_missing__`) + #: to discern the two cases. + MARK_NON_TOTAL_TYPED_DICTS = enum.auto() + JSON_OPTIONS = [opt for opt in Option if opt.name and opt.name.startswith("JSON_")] @@ -804,7 +810,18 @@ def data(self, names: NamesType) -> JSONType: schemas = (item_schema.data(names=names) for item_schema in self.item_schemas) # We need to deduplicate the schemas **after** rendering. This is because **different** Python types might # result in the **same** Avro schema. Preserving order as order may be significant in an Avro schema. - unique_schemas = list(more_itertools.unique_everseen(schemas)) + + def normalize_string_duplicates(_schema): + """We might have cases in which we have a schema both for ``StrSubclassSchema`` (e.g., a ``StrEnum`` with + invalid names is represented as a ``StrSubclassSchema``) and a string. 
These are technically duplicates, + but ``unique_everseen`` won't remove them by default.""" + if _schema == "string": + return "string" + elif isinstance(_schema, dict) and _schema.get("type") == "string": + return "string" + return _schema + + unique_schemas = list(more_itertools.unique_everseen(schemas, key=normalize_string_duplicates)) if len(unique_schemas) > 1: return unique_schemas else: @@ -1168,6 +1185,8 @@ def handles_type(cls, py_type: Type) -> bool: not dataclasses.is_dataclass(py_type) # Pydantic models are handled above and not hasattr(py_type, "__pydantic_private__") + # TypedDicts are handled separately + and not is_typeddict(py_type) # If we are subclassing a string, used the "named string" approach and (inspect.isclass(py_type) and not issubclass(py_type, str)) # and any other class with typed annotations @@ -1240,12 +1259,21 @@ def __init__(self, py_type: Type, namespace: str | None = None, options: Option """ super().__init__(py_type, namespace=namespace, options=options) py_type = _type_from_annotated(py_type) + self.is_total = py_type.__dict__.get("__total__", True) self.py_fields: dict[str, Type] = get_type_hints(py_type, include_extras=True) self.record_fields = [self._record_field(field) for field in self.py_fields.items()] def _record_field(self, py_field: tuple[str, Type]) -> RecordField: """Return an Avro record field object for a given TypedDict field""" aliases, actual_type = get_field_aliases_and_actual_type(py_field[1]) + + if Option.MARK_NON_TOTAL_TYPED_DICTS in self.options and not self.is_total: + # If a TypedDict is marked as total=False, it does not need to contain all the fields. However, we need to + # be able to distinguish between the fields that are missing from the ones that are present but set to None. + # To do that, we extend the original type with str. We will later add a special string + # (e.g., __td_missing__) as a marker at deserialization time. 
+ actual_type = Union[actual_type, str] # type: ignore + field_obj = RecordField( py_type=actual_type, name=py_field[0], diff --git a/tests/test_typed_dict.py b/tests/test_typed_dict.py index acbbeb7..3e1248b 100644 --- a/tests/test_typed_dict.py +++ b/tests/test_typed_dict.py @@ -1,5 +1,7 @@ +from enum import StrEnum from typing import Annotated, TypedDict +import py_avro_schema as pas from py_avro_schema._alias import Alias, register_type_alias from py_avro_schema._testing import assert_schema @@ -85,3 +87,29 @@ class User(TypedDict): } assert_schema(User, expected) + + +def test_non_total_typed_dict(): + class Opt(StrEnum): + val = "invalid-val" + + class PyType(TypedDict, total=False): + name: str + nickname: str | None + age: int | None + opt: Opt | None + + expected = { + "type": "record", + "name": "PyType", + "fields": [ + { + "name": "name", + "type": "string", + }, + {"name": "nickname", "type": ["string", "null"]}, + {"name": "age", "type": ["long", "null", "string"]}, + {"name": "opt", "type": [{"namedString": "Opt", "type": "string"}, "null"]}, + ], + } + assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS)