Skip to content

Commit bb4cac3

Browse files
committed
Handle some edge cases when inferring a schema from a dataclass
* For backward compatibility, only infer a schema for a frozen dataclass when it is explicitly registered with the row coder
* Make sure the Beam schema ID is not inherited from a base class
* Fix the IndexError (index out of range) raised when inferring the element type of a custom Iterable that has no type hint
1 parent ab56619 commit bb4cac3

5 files changed

Lines changed: 76 additions & 14 deletions

File tree

sdks/python/apache_beam/coders/coder_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def encode_special_deterministic(self, value, stream):
493493
stream.write_byte(PROTO_TYPE)
494494
self.encode_type(type(value), stream)
495495
stream.write(value.SerializePartialToString(deterministic=True), True)
496-
elif dataclasses and dataclasses.is_dataclass(value):
496+
elif dataclasses.is_dataclass(value):
497497
if not type(value).__dataclass_params__.frozen:
498498
raise TypeError(
499499
"Unable to deterministically encode non-frozen '%s' of type '%s' "

sdks/python/apache_beam/typehints/native_type_compatibility.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,28 @@ def match_is_named_tuple(user_type):
176176
hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))
177177

178178

179-
def match_is_dataclass(user_type):
180-
return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
179+
def match_dataclass_for_row(user_type):
  """Match whether the type is a dataclass handled by row coder.

  Non-frozen dataclasses always match. For frozen dataclasses, only true when
  explicitly registered with row coder:

    beam.coders.typecoders.registry.register_coder(
        MyDataClass, beam.coders.RowCoder)

  Args:
    user_type: the candidate object to inspect.

  Returns:
    True if ``user_type`` is a dataclass *class* that is either non-frozen, or
    frozen and mapped to ``RowCoder`` in the coder registry. False otherwise,
    including for dataclass instances and for non-dataclass values.
  """
  # dataclasses.is_dataclass() is also true for *instances* of a dataclass.
  # Only classes describe a row type, so reject everything that is not a type
  # (this is the guard the previous match_is_dataclass() helper applied).
  if not (isinstance(user_type, type) and dataclasses.is_dataclass(user_type)):
    return False

  if not user_type.__dataclass_params__.frozen:
    return True

  # avoid circular import
  # pylint: disable=wrong-import-position
  from apache_beam.coders.typecoders import registry as coders_registry
  from apache_beam.coders import RowCoder

  # NOTE(review): this reads the registry's private `_coder` mapping; confirm
  # the attribute name against the CoderRegistry implementation.
  return (
      user_type in coders_registry._coder and
      coders_registry._coder[user_type] == RowCoder)
181201

182202

183203
def _match_is_optional(user_type):

sdks/python/apache_beam/typehints/row_type.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from typing import Tuple
2828

2929
from apache_beam.typehints import typehints
30-
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
30+
from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row
3131
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
3232
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
3333

@@ -91,6 +91,9 @@ def __init__(
9191
# Currently registration happens when converting to schema protos, in
9292
# apache_beam.typehints.schemas
9393
self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)
94+
if self._schema_id and _BEAM_SCHEMA_ID not in self._user_type.__dict__:
95+
# schema id does not inherit. Unset if schema id is from base class
96+
self._schema_id = None
9497

9598
self._schema_options = schema_options or []
9699
self._field_options = field_options or {}
@@ -105,7 +108,7 @@ def from_user_type(
105108
if match_is_named_tuple(user_type):
106109
fields = [(name, user_type.__annotations__[name])
107110
for name in user_type._fields]
108-
elif match_is_dataclass(user_type):
111+
elif match_dataclass_for_row(user_type):
109112
fields = [(field.name, field.type)
110113
for field in dataclasses.fields(user_type)]
111114
else:

sdks/python/apache_beam/typehints/row_type_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from apache_beam.testing.util import assert_that
2727
from apache_beam.testing.util import equal_to
2828
from apache_beam.typehints import row_type
29+
from apache_beam.typehints import schemas
2930

3031

3132
class RowTypeTest(unittest.TestCase):
@@ -85,6 +86,38 @@ def generate(num: int):
8586
| 'Count Elements' >> beam.Map(self._check_key_type_and_count))
8687
assert_that(result, equal_to([10] * 100))
8788

89+
def test_derived_dataclass_schema_id(self):
  """A derived dataclass must get its own schema id, not its base's."""
  schema_id_attr = row_type._BEAM_SCHEMA_ID

  @dataclass
  class BaseDataClass:
    id: int

  @dataclass
  class DerivedDataClass(BaseDataClass):
    name: str

  # The base class carries no schema id until a schema is requested for it.
  self.assertFalse(hasattr(BaseDataClass, schema_id_attr))
  base_schema = schemas.schema_from_element_type(BaseDataClass)
  self.assertTrue(hasattr(BaseDataClass, schema_id_attr))
  self.assertEqual(base_schema.id, getattr(BaseDataClass, schema_id_attr))

  # Also convert the base class through a fresh schema registry, so that the
  # base class has its _beam_schema_id set before the derived class is used.
  fresh_registry = schemas.SchemaTypeRegistry()
  schemas.typing_to_runner_api(BaseDataClass, schema_registry=fresh_registry)

  # A RowTypeConstraint built from the derived class must not pick up the
  # _beam_schema_id that now lives on the base class.
  derived_constraint = row_type.RowTypeConstraint.from_user_type(
      DerivedDataClass)
  self.assertIsNone(derived_constraint._schema_id)

  # Requesting the derived schema yields an id distinct from the base's.
  derived_schema = schemas.schema_from_element_type(DerivedDataClass)
  self.assertTrue(hasattr(DerivedDataClass, schema_id_attr))
  self.assertEqual(
      derived_schema.id, getattr(DerivedDataClass, schema_id_attr))
  self.assertNotEqual(derived_schema.id, base_schema.id)
120+
88121

89122
if __name__ == '__main__':
90123
unittest.main()

sdks/python/apache_beam/typehints/schemas.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
9797
from apache_beam.typehints.native_type_compatibility import convert_to_python_type
9898
from apache_beam.typehints.native_type_compatibility import extract_optional_type
99-
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
99+
from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row
100100
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
101101
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
102102
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -335,19 +335,23 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType:
335335
atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[int])))
336336

337337
elif _safe_issubclass(type_, Sequence) and not _safe_issubclass(type_, str):
338-
element_type = self.typing_to_runner_api(_get_args(type_)[0])
339-
return schema_pb2.FieldType(
340-
array_type=schema_pb2.ArrayType(element_type=element_type))
338+
arg_types = _get_args(type_)
339+
if len(arg_types) > 0:
340+
element_type = self.typing_to_runner_api(arg_types[0])
341+
return schema_pb2.FieldType(
342+
array_type=schema_pb2.ArrayType(element_type=element_type))
341343

342344
elif _safe_issubclass(type_, Mapping):
343345
key_type, value_type = map(self.typing_to_runner_api, _get_args(type_))
344346
return schema_pb2.FieldType(
345347
map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))
346348

347349
elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str):
348-
element_type = self.typing_to_runner_api(_get_args(type_)[0])
349-
return schema_pb2.FieldType(
350-
array_type=schema_pb2.ArrayType(element_type=element_type))
350+
arg_types = _get_args(type_)
351+
if len(arg_types) > 0:
352+
element_type = self.typing_to_runner_api(arg_types[0])
353+
return schema_pb2.FieldType(
354+
array_type=schema_pb2.ArrayType(element_type=element_type))
351355

352356
try:
353357
if LogicalType.is_known_logical_type(type_):
@@ -630,8 +634,10 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
630634
Returns schema as a list of (name, python_type) tuples"""
631635
if isinstance(element_type, row_type.RowTypeConstraint):
632636
return named_fields_to_schema(element_type._fields)
633-
elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
634-
if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
637+
elif match_is_named_tuple(element_type) or match_dataclass_for_row(
638+
element_type):
639+
# schema id does not inherit from base classes
640+
if row_type._BEAM_SCHEMA_ID in element_type.__dict__:
635641
# if the named tuple's schema is in registry, we just use it instead of
636642
# regenerating one.
637643
schema_id = getattr(element_type, row_type._BEAM_SCHEMA_ID)

0 commit comments

Comments
 (0)