Skip to content

Commit 3a705c7

Browse files
authored
Improve schemas for non-total TypedDict (#11)
1 parent fd4ec45 commit 3a705c7

2 files changed

Lines changed: 57 additions & 1 deletion

File tree

src/py_avro_schema/_schemas.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@ class Option(enum.Flag):
131131
#: See https://docs.pydantic.dev/dev/api/fields/#pydantic.fields.Field
132132
USE_FIELD_ALIAS = enum.auto()
133133

134+
#: TypedDict marked with ``total=False`` are valid structures when a field is missing. When of the field is also
135+
# optional, we need to have a way to distinguish between a `None` and a non-set field. With this option, the type
136+
# of each field is extended with `string`. This way, clients can add markers (e.g., `__td_missing__`) to discern
137+
# the two cases.
138+
MARK_NON_TOTAL_TYPED_DICTS = enum.auto()
139+
134140

135141
JSON_OPTIONS = [opt for opt in Option if opt.name and opt.name.startswith("JSON_")]
136142

@@ -804,7 +810,18 @@ def data(self, names: NamesType) -> JSONType:
804810
schemas = (item_schema.data(names=names) for item_schema in self.item_schemas)
805811
# We need to deduplicate the schemas **after** rendering. This is because **different** Python types might
806812
# result in the **same** Avro schema. Preserving order as order may be significant in an Avro schema.
807-
unique_schemas = list(more_itertools.unique_everseen(schemas))
813+
814+
def normalize_string_duplicates(_schema):
815+
"""We might have cases in which we have a schema both for ``StrSubclassSchema`` (e.g., a ``StrEnum`` with
816+
invalid names is represented as a ``StrSubclassSchema``) and a string. These are technically duplicates,
817+
but ``unique_everseen`` won't remove them by default."""
818+
if _schema == "string":
819+
return "string"
820+
elif isinstance(_schema, dict) and _schema.get("type") == "string":
821+
return "string"
822+
return _schema
823+
824+
unique_schemas = list(more_itertools.unique_everseen(schemas, key=normalize_string_duplicates))
808825
if len(unique_schemas) > 1:
809826
return unique_schemas
810827
else:
@@ -1168,6 +1185,8 @@ def handles_type(cls, py_type: Type) -> bool:
11681185
not dataclasses.is_dataclass(py_type)
11691186
# Pydantic models are handled above
11701187
and not hasattr(py_type, "__pydantic_private__")
1188+
# typed_dict handled separately
1189+
and not is_typeddict(py_type)
11711190
# If we are subclassing a string, used the "named string" approach
11721191
and (inspect.isclass(py_type) and not issubclass(py_type, str))
11731192
# and any other class with typed annotations
@@ -1240,12 +1259,21 @@ def __init__(self, py_type: Type, namespace: str | None = None, options: Option
12401259
"""
12411260
super().__init__(py_type, namespace=namespace, options=options)
12421261
py_type = _type_from_annotated(py_type)
1262+
self.is_total = py_type.__dict__.get("__total__", True)
12431263
self.py_fields: dict[str, Type] = get_type_hints(py_type, include_extras=True)
12441264
self.record_fields = [self._record_field(field) for field in self.py_fields.items()]
12451265

12461266
def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
12471267
"""Return an Avro record field object for a given TypedDict field"""
12481268
aliases, actual_type = get_field_aliases_and_actual_type(py_field[1])
1269+
1270+
if Option.MARK_NON_TOTAL_TYPED_DICTS in self.options and not self.is_total:
1271+
# If a TypedDict is marked as total=None, it does not need to contain all the field. However, we need to
1272+
# be able to distinguish between the fields that are missing from the ones that are present but set to None.
1273+
# To do that, we extend the original type with str. We will later add a special string
1274+
# (e.g., __td_missing__) as a marker at deserialization time.
1275+
actual_type = Union[actual_type, str] # type: ignore
1276+
12491277
field_obj = RecordField(
12501278
py_type=actual_type,
12511279
name=py_field[0],

tests/test_typed_dict.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from enum import StrEnum
12
from typing import Annotated, TypedDict
23

4+
import py_avro_schema as pas
35
from py_avro_schema._alias import Alias, register_type_alias
46
from py_avro_schema._testing import assert_schema
57

@@ -85,3 +87,29 @@ class User(TypedDict):
8587
}
8688

8789
assert_schema(User, expected)
90+
91+
92+
def test_non_total_typed_dict():
93+
class Opt(StrEnum):
94+
val = "invalid-val"
95+
96+
class PyType(TypedDict, total=False):
97+
name: str
98+
nickname: str | None
99+
age: int | None
100+
opt: Opt | None
101+
102+
expected = {
103+
"type": "record",
104+
"name": "PyType",
105+
"fields": [
106+
{
107+
"name": "name",
108+
"type": "string",
109+
},
110+
{"name": "nickname", "type": ["string", "null"]},
111+
{"name": "age", "type": ["long", "null", "string"]},
112+
{"name": "opt", "type": [{"namedString": "Opt", "type": "string"}, "null"]},
113+
],
114+
}
115+
assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS)

0 commit comments

Comments
 (0)