Skip to content

Commit ffa54ac

Browse files
authored
Add option to wrap lists and dicts inside Avro records (#21)
1 parent 70728f6 commit ffa54ac

File tree

3 files changed

+384
-12
lines changed

3 files changed

+384
-12
lines changed

src/py_avro_schema/_schemas.py

Lines changed: 87 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373

7474
RUNTIME_TYPE_KEY = "_runtime_type"
7575
REF_ID_KEY = "__id"
76+
REF_DATA_KEY = "__data"
7677
SYMBOL_REGEX = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
7778

7879

@@ -151,6 +152,9 @@ class Option(enum.Flag):
151152
# factories might be a problem when comparing schemas, as they change every time a schema is generated by definition.
152153
DETERMINISTIC_DEFAULTS = enum.auto()
153154

155+
#: Wraps lists and maps into a record type
156+
WRAP_INTO_RECORDS = enum.auto()
157+
154158

155159
JSON_OPTIONS = [opt for opt in Option if opt.name and opt.name.startswith("JSON_")]
156160

@@ -298,6 +302,28 @@ def make_default(self, py_default: Any) -> Any:
298302
"""
299303
return py_default
300304

305+
def _wrap_as_record(self, inner_schema: JSONObj, names: NamesType) -> JSONType:
    """
    Wrap a container schema (array or map) into an Avro record with ``__id`` and ``__data`` fields.

    The record name is derived from this schema's Python type. The first time a given full name is seen, it is
    appended to ``names`` and the complete record definition is returned. On later calls only the full name is
    returned, so the record is defined once and referenced afterwards.

    :param inner_schema: The array or map schema to store under the ``__data`` field.
    :param names:        Full names of the records already emitted; mutated in place.
    :return:             The record definition, or just its full name if it has been emitted before.
    """
    wrapper_name = _avro_name_for_type(_type_from_annotated(self.py_type))
    namespace = self.namespace
    qualified_name = wrapper_name if not namespace else f"{namespace}.{wrapper_name}"

    # Already defined earlier in this schema: refer to it by name only.
    if qualified_name in names:
        return qualified_name
    names.append(qualified_name)

    id_field = {"name": REF_ID_KEY, "type": ["null", "long"], "default": None}
    data_field = {"name": REF_DATA_KEY, "type": inner_schema}
    wrapper = {"type": "record", "name": wrapper_name, "fields": [id_field, data_field]}
    # The namespace key is only emitted when a namespace is set, and always after the other keys.
    if namespace:
        wrapper["namespace"] = namespace
    return wrapper
326+
301327

302328
@register_schema
303329
class PrimitiveSchema(Schema):
@@ -712,19 +738,22 @@ def __init__(
712738
args = get_args(py_type) # TODO: validate if args has exactly 1 item?
713739
self.items_schema = _schema_obj(args[0], namespace=namespace, options=options)
714740

715-
def data(self, names: NamesType) -> JSONObj:
741+
def data(self, names: NamesType) -> JSONType:
716742
"""Return the schema data"""
717-
return {
718-
"type": "array",
719-
"items": self.items_schema.data(names=names),
720-
}
743+
array_schema = {"type": "array", "items": self.items_schema.data(names=names)}
744+
if Option.WRAP_INTO_RECORDS not in self.options:
745+
return array_schema
746+
return self._wrap_as_record(array_schema, names)
721747

722-
def make_default(self, py_default: collections.abc.Sequence) -> JSONArray:
748+
def make_default(self, py_default: collections.abc.Sequence) -> JSONType:
723749
"""Return an Avro schema compliant default value for a given Python Sequence
724750
725751
:param py_default: The Python sequence to generate a default value for.
726752
"""
727-
return [self.items_schema.make_default(item) for item in py_default]
753+
list_default = [self.items_schema.make_default(item) for item in py_default]
754+
if Option.WRAP_INTO_RECORDS in self.options:
755+
return {REF_ID_KEY: None, REF_DATA_KEY: list_default}
756+
return list_default
728757

729758

730759
@register_schema
@@ -787,12 +816,18 @@ def __init__(
787816
raise TypeError(f"Cannot generate Avro mapping schema for Python dictionary {py_type} with non-string keys")
788817
self.values_schema = _schema_obj(args[1], namespace=namespace, options=options)
789818

790-
def data(self, names: NamesType) -> JSONObj:
819+
def data(self, names: NamesType) -> JSONType:
791820
"""Return the schema data"""
792-
return {
793-
"type": "map",
794-
"values": self.values_schema.data(names=names),
795-
}
821+
map_schema = {"type": "map", "values": self.values_schema.data(names=names)}
822+
if Option.WRAP_INTO_RECORDS not in self.options:
823+
return map_schema
824+
return self._wrap_as_record(map_schema, names)
825+
826+
def make_default(self, py_default: Any) -> JSONType:
    """Return an Avro schema compliant default value for a given Python mapping

    Without :attr:`Option.WRAP_INTO_RECORDS` the value is returned unchanged. With the option set, the default is
    shaped like the wrapper record: ``__id`` is ``None`` and ``__data`` holds the mapping, with every value converted
    by the values schema.

    :param py_default: The Python mapping to generate a default value for.
    """
    if Option.WRAP_INTO_RECORDS not in self.options:
        return py_default
    data_default = py_default
    if isinstance(py_default, collections.abc.Mapping):
        # The values schema is created with the same options as this schema, so nested containers are wrapper
        # records as well. Their defaults must be converted too, mirroring what the sequence schema does for its
        # items; otherwise the default would not match the generated schema.
        data_default = {key: self.values_schema.make_default(value) for key, value in py_default.items()}
    return {REF_ID_KEY: None, REF_DATA_KEY: data_default}
796831

797832

798833
@register_schema
@@ -1429,3 +1464,43 @@ def _type_from_annotated(py_type: Type) -> Type:
14291464
return args[0]
14301465
else:
14311466
return py_type
1467+
1468+
1469+
def _avro_name_for_type(py_type: Type) -> str:
    """
    Generate an Avro-compatible name for a given Python type. It is used when wrapping container types (mostly lists
    and maps) into Avro records.

    Naming rules:

    - ``None`` / ``NoneType`` becomes ``Null``.
    - A plain class becomes its capitalized class name. Unless the class lives in ``builtins``, the name is prefixed
      with its module path in CamelCase (split on ``.`` and ``_``), e.g. ``pkg.mod_a.MyClass`` -> ``PkgModAMyClass``.
    - A union becomes the sorted member names joined with ``Or``, so the result does not depend on member order.
    - A parameterized mutable set, sequence or mapping becomes the name of its element type (for mappings: the value
      type) followed by ``Set``, ``List`` or ``Map``.

    The module path is part of the name so that ``list[ClassA]`` for two different classes that are both called
    ``ClassA`` but live in different modules do not collide. Hashing the fully qualified names was considered as
    well, but as Avro does not seem to have a maximum length for record names, spelling them out is more readable.
    See the ``test_avro_name_for_type`` test suite.

    :param py_type: The Python type (optionally wrapped in ``Annotated``) to generate a name for.
    :return:        The generated name.
    :raises TypeNotSupportedError: If no name can be derived for the type.
    """
    py_type = _type_from_annotated(py_type)
    if py_type is None or py_type is type(None):
        return "Null"
    origin = get_origin(py_type)
    args = get_args(py_type)
    # Only un-parameterized types are handled as plain classes. On Python 3.9/3.10 ``inspect.isclass(list[str])`` is
    # true, which would name ``list[str]`` just ``List`` and drop the element type, so generic aliases must fall
    # through to the origin-based branches below.
    if origin is None and inspect.isclass(py_type):
        if not (name := py_type.__name__):
            raise TypeNotSupportedError(
                f"Cannot generate a wrapper record name for Python type {py_type}: empty class name"
            )
        name = name[0].upper() + name[1:]
        module = py_type.__module__
        if module and module != "builtins":
            mod_prefix = "".join(
                word[0].upper() + word[1:] for part in module.split(".") for word in part.split("_") if word
            )
            return mod_prefix + name
        return name
    if origin is not None and args:
        # ``types.UnionType`` (the ``X | Y`` syntax) is looked up defensively as it only exists on Python >= 3.10.
        union_type = getattr(types, "UnionType", None)
        if origin is Union or (union_type and origin is union_type):
            return "Or".join(sorted(_avro_name_for_type(arg) for arg in args))
        if _is_class(origin, collections.abc.MutableSet):
            return _avro_name_for_type(args[0]) + "Set"
        if _is_class(origin, collections.abc.Sequence):
            return _avro_name_for_type(args[0]) + "List"
        if _is_class(origin, collections.abc.Mapping):
            # Only the value type is used for the name; the key type is ignored.
            return _avro_name_for_type(args[1]) + "Map"
    raise TypeNotSupportedError(f"Cannot generate a wrapper record name for Python type {py_type}")

tests/test_avro_name_for_type.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""
2+
Set of unit tests for the _avro_name_for_type function, as this is a pretty crucial component of our design with
3+
wrapped records.
4+
"""
5+
6+
import typing
7+
8+
import pytest
9+
10+
from py_avro_schema._schemas import TypeNotSupportedError, _avro_name_for_type
11+
12+
13+
class ClassA:
    """Empty helper class used as a custom element type."""


class ClassB:
    """Second empty helper class, used together with ``ClassA`` in unions."""


# Sequences


def test_list_str():
    """A list of a builtin type is named after the element type plus ``List``."""
    generated = _avro_name_for_type(list[str])
    assert generated == "StrList"


def test_nested_list():
    """Nested lists repeat the ``List`` suffix."""
    generated = _avro_name_for_type(list[list[str]])
    assert generated == "StrListList"


def test_list_of_custom_class():
    """Custom classes are prefixed with their CamelCased module name."""
    generated = _avro_name_for_type(list[ClassA])
    assert generated == "TestAvroNameForTypeClassAList"


def test_list_of_union():
    """Union members are sorted and joined with ``Or``."""
    generated = _avro_name_for_type(list[str | int])
    assert generated == "IntOrStrList"


def test_list_of_union_two_custom_classes():
    """A union of two custom classes inside a list."""
    generated = _avro_name_for_type(list[ClassA | ClassB])
    assert generated == "TestAvroNameForTypeClassAOrTestAvroNameForTypeClassBList"


def test_list_of_optional_custom_class():
    """``None`` in a union is rendered as ``Null``."""
    generated = _avro_name_for_type(list[ClassA | None])
    assert generated == "NullOrTestAvroNameForTypeClassAList"


def test_list_of_dict_with_custom_class():
    """A list of maps combines the ``Map`` and ``List`` suffixes."""
    generated = _avro_name_for_type(list[dict[str, ClassA]])
    assert generated == "TestAvroNameForTypeClassAMapList"


# Sets


def test_set_str():
    """A set of a builtin type is named after the element type plus ``Set``."""
    generated = _avro_name_for_type(set[str])
    assert generated == "StrSet"


def test_set_custom_class():
    """A set of a custom class."""
    generated = _avro_name_for_type(set[ClassA])
    assert generated == "TestAvroNameForTypeClassASet"


# Maps


def test_dict_str_str():
    """A map is named after its value type plus ``Map``."""
    generated = _avro_name_for_type(dict[str, str])
    assert generated == "StrMap"


def test_dict_custom_class_value():
    """A map with a custom class as value type."""
    generated = _avro_name_for_type(dict[str, ClassA])
    assert generated == "TestAvroNameForTypeClassAMap"


def test_dict_with_union_two_custom_classes():
    """A map whose value type is a union of two custom classes."""
    generated = _avro_name_for_type(dict[str, ClassA | ClassB])
    assert generated == "TestAvroNameForTypeClassAOrTestAvroNameForTypeClassBMap"


def test_dict_with_optional_custom_class():
    """A map whose value type is an optional custom class."""
    generated = _avro_name_for_type(dict[str, ClassA | None])
    assert generated == "NullOrTestAvroNameForTypeClassAMap"


def test_dict_with_list_of_custom_class():
    """A map of lists combines the ``List`` and ``Map`` suffixes."""
    generated = _avro_name_for_type(dict[str, list[ClassA]])
    assert generated == "TestAvroNameForTypeClassAListMap"


def test_dict_none_value():
    """A map with ``None`` as value type."""
    generated = _avro_name_for_type(dict[str, None])
    assert generated == "NullMap"


# Unions


def test_union_str_int():
    """A bare union using the ``X | Y`` syntax."""
    generated = _avro_name_for_type(str | int)
    assert generated == "IntOrStr"


def test_union_str_int_legacy_syntax():
    """A bare union using ``typing.Union``."""
    generated = _avro_name_for_type(typing.Union[str, int])
    assert generated == "IntOrStr"


def test_union_two_custom_classes_order_independent():
    """The generated name does not depend on the order of the union members."""
    expected = "TestAvroNameForTypeClassAOrTestAvroNameForTypeClassB"
    for union in (ClassA | ClassB, ClassB | ClassA):
        assert _avro_name_for_type(union) == expected


def test_optional_custom_class():
    """An optional custom class outside any container."""
    generated = _avro_name_for_type(ClassA | None)
    assert generated == "NullOrTestAvroNameForTypeClassA"


# Special cases


def test_same_class_name_different_modules():
    """Two classes with the same name from different modules get different record names."""
    class_from_a = type("MyClass", (), {"__module__": "pkg.mod_a"})
    class_from_b = type("MyClass", (), {"__module__": "pkg.mod_b"})

    generated_a = _avro_name_for_type(list[class_from_a])  # noqa
    generated_b = _avro_name_for_type(list[class_from_b])  # noqa

    assert generated_a == "PkgModAMyClassList"
    assert generated_b == "PkgModBMyClassList"


# Error cases


def test_unknown_type_raises():
    """A type variable is neither a class nor a supported generic alias, so no name can be derived."""
    type_var = typing.TypeVar("T")
    with pytest.raises(TypeNotSupportedError, match="Cannot generate a wrapper record name"):
        _avro_name_for_type(type_var)

0 commit comments

Comments
 (0)