Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/py_avro_schema/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ class Option(enum.Flag):
#: See https://docs.pydantic.dev/dev/api/fields/#pydantic.fields.Field
USE_FIELD_ALIAS = enum.auto()

#: A TypedDict declared with ``total=False`` is still a valid structure when a field is missing. When a field is
#: also optional, we need a way to distinguish between a ``None`` value and a field that was not set. With this
#: option, the type of each field is extended with ``string``. This way, clients can add markers (e.g.,
#: ``__td_missing__``) to discern the two cases.
MARK_NON_TOTAL_TYPED_DICTS = enum.auto()


JSON_OPTIONS = [opt for opt in Option if opt.name and opt.name.startswith("JSON_")]

Expand Down Expand Up @@ -804,7 +810,18 @@ def data(self, names: NamesType) -> JSONType:
schemas = (item_schema.data(names=names) for item_schema in self.item_schemas)
# We need to deduplicate the schemas **after** rendering. This is because **different** Python types might
# result in the **same** Avro schema. Preserving order as order may be significant in an Avro schema.
unique_schemas = list(more_itertools.unique_everseen(schemas))

def normalize_string_duplicates(_schema):
    """Return the key used to deduplicate a rendered union member.

    A union may contain both a ``StrSubclassSchema`` rendering (e.g., a ``StrEnum`` with invalid names is
    represented as a ``StrSubclassSchema``) and a plain string. These are technically duplicates, but
    ``unique_everseen`` would not remove them by default because the rendered values differ. Both forms are
    therefore mapped to the same key, ``"string"``; any other schema is returned unchanged.
    """
    is_plain_string = _schema == "string"
    is_string_typed_dict = isinstance(_schema, dict) and _schema.get("type") == "string"
    return "string" if (is_plain_string or is_string_typed_dict) else _schema

unique_schemas = list(more_itertools.unique_everseen(schemas, key=normalize_string_duplicates))
if len(unique_schemas) > 1:
return unique_schemas
else:
Expand Down Expand Up @@ -1168,6 +1185,8 @@ def handles_type(cls, py_type: Type) -> bool:
not dataclasses.is_dataclass(py_type)
# Pydantic models are handled above
and not hasattr(py_type, "__pydantic_private__")
# typed_dict handled separately
and not is_typeddict(py_type)
# If we are subclassing a string, use the "named string" approach
and (inspect.isclass(py_type) and not issubclass(py_type, str))
# and any other class with typed annotations
Expand Down Expand Up @@ -1240,12 +1259,21 @@ def __init__(self, py_type: Type, namespace: str | None = None, options: Option
"""
super().__init__(py_type, namespace=namespace, options=options)
py_type = _type_from_annotated(py_type)
self.is_total = py_type.__dict__.get("__total__", True)
self.py_fields: dict[str, Type] = get_type_hints(py_type, include_extras=True)
self.record_fields = [self._record_field(field) for field in self.py_fields.items()]

def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
"""Return an Avro record field object for a given TypedDict field"""
aliases, actual_type = get_field_aliases_and_actual_type(py_field[1])

if Option.MARK_NON_TOTAL_TYPED_DICTS in self.options and not self.is_total:
# If a TypedDict is marked as total=False, it does not need to contain all the fields. However, we need to
# be able to distinguish between the fields that are missing and the ones that are present but set to None.
# To do that, we extend the original type with str. We will later add a special string
# (e.g., __td_missing__) as a marker at deserialization time.
actual_type = Union[actual_type, str] # type: ignore

field_obj = RecordField(
py_type=actual_type,
name=py_field[0],
Expand Down
28 changes: 28 additions & 0 deletions tests/test_typed_dict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from enum import StrEnum
from typing import Annotated, TypedDict

import py_avro_schema as pas
from py_avro_schema._alias import Alias, register_type_alias
from py_avro_schema._testing import assert_schema

Expand Down Expand Up @@ -85,3 +87,29 @@ class User(TypedDict):
}

assert_schema(User, expected)


def test_non_total_typed_dict():
class Opt(StrEnum):
val = "invalid-val"

class PyType(TypedDict, total=False):
name: str
nickname: str | None
age: int | None
opt: Opt | None
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be worth adding a `str | None` field here, just to verify what happens when we do Union[str, None, str]?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in c514ae0. Union can't have duplicates, so it would just be Union[str, None]. If you actually try to create a type alias like this, you'd get an Optional[str] :)


expected = {
"type": "record",
"name": "PyType",
"fields": [
{
"name": "name",
"type": "string",
},
{"name": "nickname", "type": ["string", "null"]},
{"name": "age", "type": ["long", "null", "string"]},
{"name": "opt", "type": [{"namedString": "Opt", "type": "string"}, "null"]},
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: here the regular str is not added, so would it still work if we add the special marker to a field of the Opt enum type? It seems like yes, going by your comment in the deduplication logic.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it would. {"namedString": "Opt", "type": "string"} is totally equivalent to "type": "string" for Avro.

],
}
assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS)