Skip to content

Commit b97265f

Browse files
authored
Support for typing.NotRequired (#17)
1 parent 5196d22 commit b97265f

2 files changed

Lines changed: 39 additions & 6 deletions

File tree

src/py_avro_schema/_schemas.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
ForwardRef,
3838
List,
3939
Literal,
40+
NotRequired,
4041
Optional,
4142
Tuple,
4243
Type,
@@ -1289,11 +1290,15 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
12891290
aliases, actual_type = get_field_aliases_and_actual_type(py_field[1])
12901291

12911292
if Option.MARK_NON_TOTAL_TYPED_DICTS in self.options and not self.is_total:
1292-
# If a TypedDict is marked as total=None, it does not need to contain all the field. However, we need to
1293+
# If a TypedDict is marked as total=False, it does not need to contain all the field. However, we need to
12931294
# be able to distinguish between the fields that are missing from the ones that are present but set to None.
12941295
# To do that, we extend the original type with str. We will later add a special string
12951296
# (e.g., __td_missing__) as a marker at deserialization time.
12961297
actual_type = Union[actual_type, str] # type: ignore
1298+
elif _is_not_required(actual_type):
1299+
# A field can be marked with typing.NotRequired even in a TypedDict with is not marked with total=False.
1300+
# Similarly as above, we extend the wrapped type with string.
1301+
actual_type = Union[_unwrap_not_required(actual_type), str] # type: ignore
12971302

12981303
field_obj = RecordField(
12991304
py_type=actual_type,
@@ -1316,6 +1321,16 @@ def _doc_for_class(py_type: Type) -> str:
13161321
return ""
13171322

13181323

1324+
def _is_not_required(py_type: Type) -> bool:
1325+
"""Checks if a type is marked with typing.NotRequired"""
1326+
return get_origin(py_type) is NotRequired # noqa
1327+
1328+
1329+
def _unwrap_not_required(py_type: Type) -> type:
1330+
"""Returns the wrapped type for typing.NotRequired"""
1331+
return get_args(py_type)[0]
1332+
1333+
13191334
def _is_dict_str_any(py_type: Type) -> bool:
13201335
"""Return whether a given type is ``Dict[str, Any]``"""
13211336
origin = get_origin(py_type)

tests/test_typed_dict.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import StrEnum
2-
from typing import Annotated, TypedDict
2+
from typing import Annotated, NotRequired, TypedDict
33

44
import py_avro_schema as pas
55
from py_avro_schema._alias import Alias, register_type_alias
@@ -103,13 +103,31 @@ class PyType(TypedDict, total=False):
103103
"type": "record",
104104
"name": "PyType",
105105
"fields": [
106-
{
107-
"name": "name",
108-
"type": "string",
109-
},
106+
{"name": "name", "type": "string"},
110107
{"name": "nickname", "type": ["string", "null"]},
111108
{"name": "age", "type": ["long", "null", "string"]},
112109
{"name": "opt", "type": [{"namedString": "Opt", "type": "string"}, "null"]},
113110
],
114111
}
115112
assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS)
113+
114+
115+
def test_non_required_keyword():
116+
class PyType(TypedDict):
117+
name: str
118+
value: NotRequired[str]
119+
value_int: NotRequired[int]
120+
nullable_value: NotRequired[str | None]
121+
122+
expected = {
123+
"type": "record",
124+
"name": "PyType",
125+
"fields": [
126+
{"name": "name", "type": "string"},
127+
{"name": "value", "type": "string"},
128+
{"name": "value_int", "type": ["long", "string"]},
129+
{"name": "nullable_value", "type": ["string", "null"]},
130+
],
131+
}
132+
133+
assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS)

0 commit comments

Comments
 (0)