Skip to content

Commit f52607f

Browse files
authored
Add __td_missing__ marker as default value for non-total TypedDict (#19)
1 parent 3c3c48e commit f52607f

File tree

2 files changed

+43
-6
lines changed

2 files changed

+43
-6
lines changed

src/py_avro_schema/_schemas.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, o
817817
def data(self, names: NamesType) -> JSONType:
818818
"""Return the schema data"""
819819
# Render the item schemas
820-
schemas = (item_schema.data(names=names) for item_schema in self.item_schemas)
820+
schemas = list(item_schema.data(names=names) for item_schema in self.item_schemas)
821821
# We need to deduplicate the schemas **after** rendering. This is because **different** Python types might
822822
# result in the **same** Avro schema. Preserving order as order may be significant in an Avro schema.
823823

@@ -831,6 +831,12 @@ def normalize_string_duplicates(_schema):
831831
return "string"
832832
return _schema
833833

834+
# If a namedString schema (str subclass with extra metadata) and a plain "string" are both present,
835+
# remove the plain "string" so the more informative namedString is preserved after deduplication.
836+
has_named_string = any(isinstance(s, dict) and "namedString" in s for s in schemas)
837+
if has_named_string:
838+
schemas = [s for s in schemas if s != "string"]
839+
834840
unique_schemas = list(more_itertools.unique_everseen(schemas, key=normalize_string_duplicates))
835841
if len(unique_schemas) > 1:
836842
return unique_schemas
@@ -1295,12 +1301,17 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
12951301
"""Return an Avro record field object for a given TypedDict field"""
12961302
aliases, actual_type = get_field_aliases_and_actual_type(py_field[1])
12971303

1304+
default = dataclasses.MISSING
12981305
if Option.MARK_NON_TOTAL_TYPED_DICTS in self.options and not self.is_total:
12991306
# If a TypedDict is marked as total=False, it does not need to contain all the field. However, we need to
13001307
# be able to distinguish between the fields that are missing from the ones that are present but set to None.
13011308
# To do that, we extend the original type with str. We will later add a special string
13021309
# (e.g., __td_missing__) as a marker at deserialization time.
13031310
actual_type = Union[actual_type, str] # type: ignore
1311+
if _is_optional(actual_type):
1312+
# Note: this works since this schema does not implement `make_default` and the base implementation
1313+
# simply return the provided type (None in this case).
1314+
default = "__td_missing__" # type: ignore
13041315
elif _is_not_required(actual_type):
13051316
# A field can be marked with typing.NotRequired even in a TypedDict with is not marked with total=False.
13061317
# Similarly as above, we extend the wrapped type with string.
@@ -1311,6 +1322,7 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
13111322
name=py_field[0],
13121323
namespace=self.namespace_override,
13131324
aliases=aliases,
1325+
default=default,
13141326
options=self.options,
13151327
)
13161328
return field_obj
@@ -1327,6 +1339,14 @@ def _doc_for_class(py_type: Type) -> str:
13271339
return ""
13281340

13291341

1342+
def _is_optional(py_type: Type) -> bool:
1343+
"""Given a Union of types, checks if None is one of those"""
1344+
try:
1345+
return type(None) in get_args(py_type)
1346+
except Exception:
1347+
return False
1348+
1349+
13301350
def _is_not_required(py_type: Type) -> bool:
13311351
"""Checks if a type is marked with typing.NotRequired"""
13321352
return get_origin(py_type) is NotRequired # noqa

tests/test_typed_dict.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,23 +90,40 @@ class User(TypedDict):
9090

9191

9292
def test_non_total_typed_dict():
93-
class Opt(StrEnum):
93+
class InvalidEnumSymbol(StrEnum):
9494
val = "invalid-val"
9595

96+
class ValidEnumSymbol(StrEnum):
97+
val = "valid_val"
98+
9699
class PyType(TypedDict, total=False):
97100
name: str
98101
nickname: str | None
99102
age: int | None
100-
opt: Opt | None
103+
invalid: InvalidEnumSymbol | None
104+
valid: ValidEnumSymbol | None
101105

102106
expected = {
103107
"type": "record",
104108
"name": "PyType",
105109
"fields": [
106110
{"name": "name", "type": "string"},
107-
{"name": "nickname", "type": ["string", "null"]},
108-
{"name": "age", "type": ["long", "null", "string"]},
109-
{"name": "opt", "type": [{"namedString": "Opt", "type": "string"}, "null"]},
111+
{"name": "nickname", "type": ["string", "null"], "default": "__td_missing__"},
112+
{"name": "age", "type": ["string", "long", "null"], "default": "__td_missing__"},
113+
{
114+
"name": "invalid",
115+
"type": [{"namedString": "InvalidEnumSymbol", "type": "string"}, "null"],
116+
"default": "__td_missing__",
117+
},
118+
{
119+
"default": "__td_missing__",
120+
"name": "valid",
121+
"type": [
122+
"string",
123+
{"default": "valid_val", "name": "ValidEnumSymbol", "symbols": ["valid_val"], "type": "enum"},
124+
"null",
125+
],
126+
},
110127
],
111128
}
112129
assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS)

0 commit comments

Comments
 (0)