Skip to content

Commit 3e49ba1

Browse files
authored
disallow TypedDict/Sequence/Dict unions (#24)
1 parent 469d815 commit 3e49ba1

File tree

6 files changed

+127
-53
lines changed

6 files changed

+127
-53
lines changed

src/py_avro_schema/_schemas.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@
7777
SYMBOL_REGEX = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
7878

7979

80+
class TDMissingMarker(str):
81+
"""
82+
Custom Typed Dict missing marker to indicate values that are in the annotations but not present at runtime.
83+
We are using a custom subclass string type to be able to differentiate them when creating schemas.
84+
See `py_avro_schema._schemas.TypedDictSchema._record_field` and `UnionSchema._validate_union`
85+
"""
86+
87+
...
88+
89+
90+
TD_MISSING_MARKER = TDMissingMarker("__td_missing__")
91+
92+
8093
class TypeNotSupportedError(TypeError):
8194
"""Error raised when a Avro schema cannot be generated for a given Python type"""
8295

@@ -312,20 +325,14 @@ def _wrap_as_record(self, inner_schema: JSONObj, names: NamesType) -> JSONType:
312325
if fullname in names:
313326
return fullname
314327
names.append(fullname)
315-
316-
fields = [
317-
{"name": REF_ID_KEY, "type": ["null", "long"], "default": None},
318-
{"name": REF_DATA_KEY, "type": inner_schema},
319-
]
320-
if Option.ADD_RUNTIME_TYPE_FIELD in self.options:
321-
fields.append({"name": RUNTIME_TYPE_KEY, "type": ["null", "string"]})
322-
323328
record_schema = {
324329
"type": "record",
325330
"name": record_name,
326-
"fields": fields,
331+
"fields": [
332+
{"name": REF_ID_KEY, "type": ["null", "long"], "default": None},
333+
{"name": REF_DATA_KEY, "type": inner_schema},
334+
],
327335
}
328-
329336
if self.namespace:
330337
record_schema["namespace"] = self.namespace
331338
return record_schema
@@ -864,8 +871,34 @@ def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, o
864871
super().__init__(py_type, namespace=namespace, options=options)
865872
py_type = _type_from_annotated(py_type)
866873
args = get_args(py_type)
874+
self._validate_union(args)
867875
self.item_schemas = [_schema_obj(arg, namespace=namespace, options=options) for arg in args]
868876

877+
@staticmethod
878+
def _validate_union(args: tuple[Any, ...]) -> None:
879+
"""
880+
Validate that the arguments of the Union are possible to deal with. At runtime, we cannot get the runtime type
881+
of TypedDict instances, as they are just regular dicts.
882+
Same for sequences like List and Set, we would have to scan them to know all the runtime types of the values
883+
they contain.
884+
:param args: list of types of the Union
885+
:return: None
886+
:raises: TypeError if the Union types are invalid
887+
"""
888+
if type(None) not in args and TDMissingMarker not in args:
889+
if any(
890+
# Enum is treated as a Sequence
891+
not EnumSchema.handles_type(arg)
892+
and (
893+
is_typeddict(arg)
894+
or SequenceSchema.handles_type(arg)
895+
or DictSchema.handles_type(arg)
896+
or SetSchema.handles_type(arg)
897+
)
898+
for arg in args
899+
):
900+
raise TypeError(f"Union of types {args} is not supported. Python cannot detect proper type at runtime")
901+
869902
def data(self, names: NamesType) -> JSONType:
870903
"""Return the schema data"""
871904
# Render the item schemas
@@ -1302,23 +1335,13 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti
13021335
self.py_fields: list[tuple[str, type]] = []
13031336
for k, v in type_hints.items():
13041337
self.py_fields.append((k, v))
1305-
# We store __init__ parameters with default values. They can be used as defaults for the record.
1306-
self.signature_fields = {
1307-
param.name: (param.annotation, param.default)
1308-
for param in list(inspect.signature(py_type.__init__).parameters.values())[1:]
1309-
if param.default is not inspect._empty
1310-
}
13111338
self.record_fields = [self._record_field(field) for field in self.py_fields]
13121339

13131340
def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
13141341
"""Return an Avro record field object for a given Python instance attribute"""
13151342
aliases, actual_type = get_field_aliases_and_actual_type(py_field[1])
13161343
name = py_field[0]
13171344
default = dataclasses.MISSING
1318-
if field := self.signature_fields.get(name):
1319-
_annotation, _default = field
1320-
if actual_type is _annotation:
1321-
default = _default or dataclasses.MISSING
13221345
field_obj = RecordField(
13231346
py_type=actual_type,
13241347
name=name,
@@ -1370,15 +1393,15 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField:
13701393
# be able to distinguish between the fields that are missing from the ones that are present but set to None.
13711394
# To do that, we extend the original type with str. We will later add a special string
13721395
# (e.g., __td_missing__) as a marker at deserialization time.
1373-
actual_type = Union[actual_type, str] # type: ignore
1396+
actual_type = Union[actual_type, TDMissingMarker] # type: ignore
13741397
if _is_optional(actual_type):
13751398
# Note: this works since this schema does not implement `make_default` and the base implementation
13761399
# simply return the provided type (None in this case).
1377-
default = "__td_missing__" # type: ignore
1400+
default = TD_MISSING_MARKER # type: ignore
13781401
elif _is_not_required(actual_type):
13791402
# A field can be marked with typing.NotRequired even in a TypedDict with is not marked with total=False.
13801403
# Similarly as above, we extend the wrapped type with string.
1381-
actual_type = Union[_unwrap_not_required(actual_type), str] # type: ignore
1404+
actual_type = Union[_unwrap_not_required(actual_type), TDMissingMarker] # type: ignore
13821405

13831406
field_obj = RecordField(
13841407
py_type=actual_type,

src/py_avro_schema/_testing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
import py_avro_schema._schemas
2525

2626

27-
def assert_schema(py_type: Type, expected_schema: Union[str, Dict[str, str], List[str]], **kwargs) -> None:
27+
def assert_schema(
28+
py_type: Type, expected_schema: Union[str, Dict[str, str], List[str | Dict[str, str]]], **kwargs
29+
) -> None:
2830
"""Test that the given Python type results in the correct Avro schema"""
2931
if not kwargs.pop("do_auto_namespace", False):
3032
kwargs["options"] = kwargs.get("options", py_avro_schema.Option(0)) | py_avro_schema.Option.NO_AUTO_NAMESPACE

tests/test_avro_schema.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,3 @@ class PyType:
8080
],
8181
}
8282
assert_schema(PyType, expected, options=pas.Option.ADD_RUNTIME_TYPE_FIELD)
83-
84-
85-
def test_add_type_field_on_wrapped_record():
86-
py_type = list[str]
87-
expected = {
88-
"type": "record",
89-
"name": "StrList",
90-
"fields": [
91-
{"name": "__id", "type": ["null", "long"], "default": None},
92-
{"name": "__data", "type": {"type": "array", "items": "string"}},
93-
{"name": "_runtime_type", "type": ["null", "string"]},
94-
],
95-
}
96-
assert_schema(py_type, expected, options=pas.Option.WRAP_INTO_RECORDS | pas.Option.ADD_RUNTIME_TYPE_FIELD)

tests/test_plain_class.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def __init__(
5151
{
5252
"name": "country",
5353
"type": "string",
54-
"default": "NLD",
5554
},
5655
{
5756
"name": "latitude",

tests/test_primitives.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,36 @@ def test_union_str_str():
385385
assert_schema(py_type, expected)
386386

387387

388+
def test_union_str_list_str_error():
389+
py_type = Union[str, list[str]]
390+
with pytest.raises(TypeError):
391+
py_avro_schema._schemas.schema(py_type)
392+
393+
394+
def test_union_str_dict_str_error():
395+
py_type = Union[str, dict[str, str]]
396+
with pytest.raises(TypeError):
397+
py_avro_schema._schemas.schema(py_type)
398+
399+
400+
def test_union_str_set_str_error():
401+
py_type = Union[str, set[str]]
402+
with pytest.raises(TypeError):
403+
py_avro_schema._schemas.schema(py_type)
404+
405+
406+
def test_union_str_tuple_str_error():
407+
py_type = Union[str, tuple[str, ...]]
408+
with pytest.raises(TypeError):
409+
py_avro_schema._schemas.schema(py_type)
410+
411+
412+
def test_union_str_list_str_with_marker():
413+
py_type = Union[list[str], py_avro_schema._schemas.TDMissingMarker]
414+
expected = [{"items": "string", "type": "array"}, {"namedString": "TDMissingMarker", "type": "string"}]
415+
assert_schema(py_type, expected)
416+
417+
388418
def test_union_str_annotated_str():
389419
py_type = Union[str, Annotated[str, ...]]
390420
expected = "string"

tests/test_typed_dict.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
1+
# Copyright 2022 J.P. Morgan Chase & Co.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
# the License. You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
9+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
# specific language governing permissions and limitations under the License.
11+
112
from enum import StrEnum
2-
from typing import Annotated, NotRequired, TypedDict
13+
from typing import Annotated, NotRequired, TypedDict, Union
14+
15+
import pytest
316

17+
import py_avro_schema
418
import py_avro_schema as pas
519
from py_avro_schema._alias import Alias, register_type_alias
620
from py_avro_schema._testing import assert_schema
@@ -104,27 +118,35 @@ class PyType(TypedDict, total=False):
104118
valid: ValidEnumSymbol | None
105119

106120
expected = {
107-
"type": "record",
108-
"name": "PyType",
109121
"fields": [
110-
{"name": "name", "type": "string"},
111-
{"name": "nickname", "type": ["string", "null"], "default": "__td_missing__"},
112-
{"name": "age", "type": ["string", "long", "null"], "default": "__td_missing__"},
122+
{"name": "name", "type": {"namedString": "TDMissingMarker", "type": "string"}},
113123
{
114-
"name": "invalid",
115-
"type": [{"namedString": "InvalidEnumSymbol", "type": "string"}, "null"],
116124
"default": "__td_missing__",
125+
"name": "nickname",
126+
"type": ["null", {"namedString": "TDMissingMarker", "type": "string"}],
127+
},
128+
{
129+
"default": "__td_missing__",
130+
"name": "age",
131+
"type": [{"namedString": "TDMissingMarker", "type": "string"}, "long", "null"],
132+
},
133+
{
134+
"default": "__td_missing__",
135+
"name": "invalid",
136+
"type": [{"namedString": "TDMissingMarker", "type": "string"}, "null"],
117137
},
118138
{
119139
"default": "__td_missing__",
120140
"name": "valid",
121141
"type": [
122-
"string",
142+
{"namedString": "TDMissingMarker", "type": "string"},
123143
{"default": "valid_val", "name": "ValidEnumSymbol", "symbols": ["valid_val"], "type": "enum"},
124144
"null",
125145
],
126146
},
127147
],
148+
"name": "PyType",
149+
"type": "record",
128150
}
129151
assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS)
130152

@@ -137,14 +159,14 @@ class PyType(TypedDict):
137159
nullable_value: NotRequired[str | None]
138160

139161
expected = {
140-
"type": "record",
141-
"name": "PyType",
142162
"fields": [
143163
{"name": "name", "type": "string"},
144-
{"name": "value", "type": "string"},
145-
{"name": "value_int", "type": ["long", "string"]},
146-
{"name": "nullable_value", "type": ["string", "null"]},
164+
{"name": "value", "type": {"namedString": "TDMissingMarker", "type": "string"}},
165+
{"name": "value_int", "type": ["long", {"namedString": "TDMissingMarker", "type": "string"}]},
166+
{"name": "nullable_value", "type": ["null", {"namedString": "TDMissingMarker", "type": "string"}]},
147167
],
168+
"name": "PyType",
169+
"type": "record",
148170
}
149171

150172
assert_schema(PyType, expected, options=pas.Option.MARK_NON_TOTAL_TYPED_DICTS)
@@ -170,3 +192,15 @@ class PyType(TypedDict):
170192
],
171193
}
172194
assert_schema(PyType, expected, options=pas.Option.ADD_REFERENCE_ID)
195+
196+
197+
def test_union_typed_dict_error():
198+
class PyType(TypedDict):
199+
var: str
200+
201+
class PyType2(TypedDict):
202+
var: str
203+
204+
py_type = Union[PyType, PyType2]
205+
with pytest.raises(TypeError):
206+
py_avro_schema._schemas.schema(py_type)

0 commit comments

Comments
 (0)