|
77 | 77 | SYMBOL_REGEX = re.compile(r"[A-Za-z_][A-Za-z0-9_]*") |
78 | 78 |
|
79 | 79 |
|
| 80 | +class TDMissingMarker(str): |
| 81 | + """ |
| 82 | + Custom Typed Dict missing marker to indicate values that are in the annotations but not present at runtime. |
| 83 | + We are using a custom subclass string type to be able to differentiate them when creating schemas. |
| 84 | + See `py_avro_schema._schemas.TypedDictSchema._record_field` and `UnionSchema._validate_union` |
| 85 | + """ |
| 86 | + |
| 87 | + ... |
| 88 | + |
| 89 | + |
| 90 | +TD_MISSING_MARKER = TDMissingMarker("__td_missing__") |
| 91 | + |
| 92 | + |
80 | 93 | class TypeNotSupportedError(TypeError): |
81 | 94 | """Error raised when a Avro schema cannot be generated for a given Python type""" |
82 | 95 |
|
@@ -312,20 +325,14 @@ def _wrap_as_record(self, inner_schema: JSONObj, names: NamesType) -> JSONType: |
312 | 325 | if fullname in names: |
313 | 326 | return fullname |
314 | 327 | names.append(fullname) |
315 | | - |
316 | | - fields = [ |
317 | | - {"name": REF_ID_KEY, "type": ["null", "long"], "default": None}, |
318 | | - {"name": REF_DATA_KEY, "type": inner_schema}, |
319 | | - ] |
320 | | - if Option.ADD_RUNTIME_TYPE_FIELD in self.options: |
321 | | - fields.append({"name": RUNTIME_TYPE_KEY, "type": ["null", "string"]}) |
322 | | - |
323 | 328 | record_schema = { |
324 | 329 | "type": "record", |
325 | 330 | "name": record_name, |
326 | | - "fields": fields, |
| 331 | + "fields": [ |
| 332 | + {"name": REF_ID_KEY, "type": ["null", "long"], "default": None}, |
| 333 | + {"name": REF_DATA_KEY, "type": inner_schema}, |
| 334 | + ], |
327 | 335 | } |
328 | | - |
329 | 336 | if self.namespace: |
330 | 337 | record_schema["namespace"] = self.namespace |
331 | 338 | return record_schema |
@@ -864,8 +871,34 @@ def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, o |
864 | 871 | super().__init__(py_type, namespace=namespace, options=options) |
865 | 872 | py_type = _type_from_annotated(py_type) |
866 | 873 | args = get_args(py_type) |
| 874 | + self._validate_union(args) |
867 | 875 | self.item_schemas = [_schema_obj(arg, namespace=namespace, options=options) for arg in args] |
868 | 876 |
|
| 877 | + @staticmethod |
| 878 | + def _validate_union(args: tuple[Any, ...]) -> None: |
| 879 | + """ |
| 880 | + Validate that the arguments of the Union are possible to deal with. At runtime, we cannot get the runtime type |
| 881 | + of TypedDict instances, as they are just regular dicts. |
| 882 | + Same for sequences like List and Set, we would have to scan them to know all the runtime types of the values |
| 883 | + they contain. |
| 884 | + :param args: list of types of the Union |
| 885 | + :return: None |
| 886 | + :raises: TypeError if the Union types are invalid |
| 887 | + """ |
| 888 | + if type(None) not in args and TDMissingMarker not in args: |
| 889 | + if any( |
| 890 | + # Enum is treated as a Sequence |
| 891 | + not EnumSchema.handles_type(arg) |
| 892 | + and ( |
| 893 | + is_typeddict(arg) |
| 894 | + or SequenceSchema.handles_type(arg) |
| 895 | + or DictSchema.handles_type(arg) |
| 896 | + or SetSchema.handles_type(arg) |
| 897 | + ) |
| 898 | + for arg in args |
| 899 | + ): |
| 900 | + raise TypeError(f"Union of types {args} is not supported. Python cannot detect proper type at runtime") |
| 901 | + |
869 | 902 | def data(self, names: NamesType) -> JSONType: |
870 | 903 | """Return the schema data""" |
871 | 904 | # Render the item schemas |
@@ -1302,23 +1335,13 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti |
1302 | 1335 | self.py_fields: list[tuple[str, type]] = [] |
1303 | 1336 | for k, v in type_hints.items(): |
1304 | 1337 | self.py_fields.append((k, v)) |
1305 | | - # We store __init__ parameters with default values. They can be used as defaults for the record. |
1306 | | - self.signature_fields = { |
1307 | | - param.name: (param.annotation, param.default) |
1308 | | - for param in list(inspect.signature(py_type.__init__).parameters.values())[1:] |
1309 | | - if param.default is not inspect._empty |
1310 | | - } |
1311 | 1338 | self.record_fields = [self._record_field(field) for field in self.py_fields] |
1312 | 1339 |
|
1313 | 1340 | def _record_field(self, py_field: tuple[str, Type]) -> RecordField: |
1314 | 1341 | """Return an Avro record field object for a given Python instance attribute""" |
1315 | 1342 | aliases, actual_type = get_field_aliases_and_actual_type(py_field[1]) |
1316 | 1343 | name = py_field[0] |
1317 | 1344 | default = dataclasses.MISSING |
1318 | | - if field := self.signature_fields.get(name): |
1319 | | - _annotation, _default = field |
1320 | | - if actual_type is _annotation: |
1321 | | - default = _default or dataclasses.MISSING |
1322 | 1345 | field_obj = RecordField( |
1323 | 1346 | py_type=actual_type, |
1324 | 1347 | name=name, |
@@ -1370,15 +1393,15 @@ def _record_field(self, py_field: tuple[str, Type]) -> RecordField: |
1370 | 1393 | # be able to distinguish between the fields that are missing from the ones that are present but set to None. |
1371 | 1394 | # To do that, we extend the original type with str. We will later add a special string |
1372 | 1395 | # (e.g., __td_missing__) as a marker at deserialization time. |
1373 | | - actual_type = Union[actual_type, str] # type: ignore |
| 1396 | + actual_type = Union[actual_type, TDMissingMarker] # type: ignore |
1374 | 1397 | if _is_optional(actual_type): |
1375 | 1398 | # Note: this works since this schema does not implement `make_default` and the base implementation |
1376 | 1399 | # simply return the provided type (None in this case). |
1377 | | - default = "__td_missing__" # type: ignore |
| 1400 | + default = TD_MISSING_MARKER # type: ignore |
1378 | 1401 | elif _is_not_required(actual_type): |
1379 | 1402 | # A field can be marked with typing.NotRequired even in a TypedDict with is not marked with total=False. |
1380 | 1403 | # Similarly as above, we extend the wrapped type with string. |
1381 | | - actual_type = Union[_unwrap_not_required(actual_type), str] # type: ignore |
| 1404 | + actual_type = Union[_unwrap_not_required(actual_type), TDMissingMarker] # type: ignore |
1382 | 1405 |
|
1383 | 1406 | field_obj = RecordField( |
1384 | 1407 | py_type=actual_type, |
|
0 commit comments