|
73 | 73 |
|
74 | 74 | RUNTIME_TYPE_KEY = "_runtime_type" |
75 | 75 | REF_ID_KEY = "__id" |
| 76 | +REF_DATA_KEY = "__data" |
76 | 77 | SYMBOL_REGEX = re.compile(r"[A-Za-z_][A-Za-z0-9_]*") |
77 | 78 |
|
78 | 79 |
|
@@ -151,6 +152,9 @@ class Option(enum.Flag): |
151 | 152 | # factories might be a problem when comparing schemas, as by definition they change every time a schema is generated. |
152 | 153 | DETERMINISTIC_DEFAULTS = enum.auto() |
153 | 154 |
|
| 155 | + #: Wrap lists and maps into an Avro record type with ``__id`` and ``__data`` fields |
| 156 | + WRAP_INTO_RECORDS = enum.auto() |
| 157 | + |
154 | 158 |
|
155 | 159 | JSON_OPTIONS = [opt for opt in Option if opt.name and opt.name.startswith("JSON_")] |
156 | 160 |
|
@@ -298,6 +302,28 @@ def make_default(self, py_default: Any) -> Any: |
298 | 302 | """ |
299 | 303 | return py_default |
300 | 304 |
|
def _wrap_as_record(self, inner_schema: JSONObj, names: NamesType) -> JSONType:
    """
    Build the Avro record that wraps a container schema (array or map).

    The record carries a nullable ``__id`` field (``long``, defaulting to null) and a ``__data`` field holding
    ``inner_schema``. Its name is derived from ``self.py_type`` and is qualified with ``self.namespace`` when one
    is set.

    :param inner_schema: The array or map schema to embed as the ``__data`` field.
    :param names:        Full names emitted so far; the wrapper's full name is appended on first use.
    :return:             The record schema on first use, otherwise only the record's full name.
    """
    wrapper_name = _avro_name_for_type(_type_from_annotated(self.py_type))
    namespace = self.namespace
    qualified_name = wrapper_name if not namespace else f"{namespace}.{wrapper_name}"
    if qualified_name in names:
        # Already emitted for this schema: refer to the record by its full name only
        return qualified_name
    names.append(qualified_name)
    id_field = {"name": REF_ID_KEY, "type": ["null", "long"], "default": None}
    data_field = {"name": REF_DATA_KEY, "type": inner_schema}
    wrapper = {"type": "record", "name": wrapper_name, "fields": [id_field, data_field]}
    if namespace:
        wrapper["namespace"] = namespace
    return wrapper
| 326 | + |
301 | 327 |
|
302 | 328 | @register_schema |
303 | 329 | class PrimitiveSchema(Schema): |
@@ -712,19 +738,22 @@ def __init__( |
712 | 738 | args = get_args(py_type) # TODO: validate if args has exactly 1 item? |
713 | 739 | self.items_schema = _schema_obj(args[0], namespace=namespace, options=options) |
714 | 740 |
|
715 | | - def data(self, names: NamesType) -> JSONObj: |
| 741 | + def data(self, names: NamesType) -> JSONType: |
716 | 742 | """Return the schema data""" |
717 | | - return { |
718 | | - "type": "array", |
719 | | - "items": self.items_schema.data(names=names), |
720 | | - } |
| 743 | + array_schema = {"type": "array", "items": self.items_schema.data(names=names)} |
| 744 | + if Option.WRAP_INTO_RECORDS not in self.options: |
| 745 | + return array_schema |
| 746 | + return self._wrap_as_record(array_schema, names) |
721 | 747 |
|
722 | | - def make_default(self, py_default: collections.abc.Sequence) -> JSONArray: |
| 748 | + def make_default(self, py_default: collections.abc.Sequence) -> JSONType: |
723 | 749 | """Return an Avro schema compliant default value for a given Python Sequence |
724 | 750 |
|
725 | 751 | :param py_default: The Python sequence to generate a default value for. |
726 | 752 | """ |
727 | | - return [self.items_schema.make_default(item) for item in py_default] |
| 753 | + list_default = [self.items_schema.make_default(item) for item in py_default] |
| 754 | + if Option.WRAP_INTO_RECORDS in self.options: |
| 755 | + return {REF_ID_KEY: None, REF_DATA_KEY: list_default} |
| 756 | + return list_default |
728 | 757 |
|
729 | 758 |
|
730 | 759 | @register_schema |
@@ -787,12 +816,18 @@ def __init__( |
787 | 816 | raise TypeError(f"Cannot generate Avro mapping schema for Python dictionary {py_type} with non-string keys") |
788 | 817 | self.values_schema = _schema_obj(args[1], namespace=namespace, options=options) |
789 | 818 |
|
790 | | - def data(self, names: NamesType) -> JSONObj: |
| 819 | + def data(self, names: NamesType) -> JSONType: |
791 | 820 | """Return the schema data""" |
792 | | - return { |
793 | | - "type": "map", |
794 | | - "values": self.values_schema.data(names=names), |
795 | | - } |
| 821 | + map_schema = {"type": "map", "values": self.values_schema.data(names=names)} |
| 822 | + if Option.WRAP_INTO_RECORDS not in self.options: |
| 823 | + return map_schema |
| 824 | + return self._wrap_as_record(map_schema, names) |
| 825 | + |
def make_default(self, py_default: Any) -> JSONType:
    """Return an Avro schema compliant default value for a given Python mapping

    Each value is converted with the values schema, in the same way the sequence schema converts its items.
    When ``Option.WRAP_INTO_RECORDS`` is set, the result is placed in a record-shaped default with a null
    ``__id``.

    :param py_default: The Python mapping to generate a default value for. Anything that is not a mapping is
                       passed through unchanged.
    """
    if isinstance(py_default, collections.abc.Mapping):
        # Values may need their own conversion (e.g. nested containers that are wrapped into records as well)
        map_default: Any = {key: self.values_schema.make_default(value) for key, value in py_default.items()}
    else:
        map_default = py_default
    if Option.WRAP_INTO_RECORDS in self.options:
        return {REF_ID_KEY: None, REF_DATA_KEY: map_default}
    return map_default
796 | 831 |
|
797 | 832 |
|
798 | 833 | @register_schema |
@@ -1429,3 +1464,43 @@ def _type_from_annotated(py_type: Type) -> Type: |
1429 | 1464 | return args[0] |
1430 | 1465 | else: |
1431 | 1466 | return py_type |
| 1467 | + |
| 1468 | + |
def _avro_name_for_type(py_type: Type) -> str:
    """
    Generate an Avro-compatible name for a given Python type. It is used when wrapping container types (mostly lists
    and maps) into Avro records.

    Naming rules:

    - ``None`` / ``NoneType`` become ``Null``.
    - A plain class becomes its capitalized name, prefixed with the CamelCased module path unless the module is
      ``builtins``.
    - A union becomes the sorted member names joined with ``Or``.
    - Parameterized sets, sequences and mappings become the element (or map value) name followed by ``Set``,
      ``List`` or ``Map``.

    The module name is part of the record name. Initially, we thought about hashing all the fully qualified names
    to distinguish `list[ClassA]` from another `list[ClassA]` where the two `ClassA` are separate classes from
    different modules. As Avro does not seem to impose a maximum length on record names, spelling out the module
    name is more readable.
    See `test_avro_name_for_type` test suite.

    :param py_type: The Python type (optionally wrapped in ``Annotated``) to derive a name for.
    :return:        The generated name.
    :raises TypeNotSupportedError: If the class name is empty or the type matches none of the rules above.
    """
    py_type = _type_from_annotated(py_type)
    if py_type is None or py_type is type(None):
        return "Null"
    origin = get_origin(py_type)
    args = get_args(py_type)
    # On Python 3.9 and 3.10 ``inspect.isclass(list[int])`` is true, which would name every ``list[...]`` just
    # ``List``. Only treat the type as a plain class when it is not a parameterized generic.
    if origin is None and inspect.isclass(py_type):
        if not (name := py_type.__name__):
            raise TypeNotSupportedError(
                f"Cannot generate a wrapper record name for Python type {py_type}: empty class name"
            )
        name = name[0].upper() + name[1:]
        module = py_type.__module__
        if module and module != "builtins":
            # "my_pkg.some_mod" -> "MyPkgSomeMod"
            mod_prefix = "".join(
                word[0].upper() + word[1:] for part in module.split(".") for word in part.split("_") if word
            )
            return mod_prefix + name
        return name
    if origin is not None and args:
        # ``types.UnionType`` (``X | Y``) only exists on Python 3.10+
        union_type = getattr(types, "UnionType", None)
        if origin is Union or (union_type and origin is union_type):
            # Sorted so that the member order of the union does not change the name
            return "Or".join(sorted(_avro_name_for_type(arg) for arg in args))
        if _is_class(origin, collections.abc.MutableSet):
            return _avro_name_for_type(args[0]) + "Set"
        if _is_class(origin, collections.abc.Sequence):
            return _avro_name_for_type(args[0]) + "List"
        if _is_class(origin, collections.abc.Mapping):
            # Only the value type is used for the name
            return _avro_name_for_type(args[1]) + "Map"
    raise TypeNotSupportedError(f"Cannot generate a wrapper record name for Python type {py_type}")
0 commit comments