Skip to content

Commit 4099aa8

Browse files
committed
Support inferring schemas from Python dataclasses
1 parent 5a6f763 commit 4099aa8

4 files changed

Lines changed: 93 additions & 1 deletion

File tree

sdks/python/apache_beam/typehints/native_type_compatibility.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import sys
2626
import types
2727
import typing
28+
import dataclasses
2829
from typing import Generic
2930
from typing import TypeVar
3031

@@ -175,6 +176,10 @@ def match_is_named_tuple(user_type):
175176
hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))
176177

177178

179+
def match_is_dataclass(user_type):
180+
return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
181+
182+
178183
def _match_is_optional(user_type):
179184
return _match_is_union(user_type) and sum(
180185
tp is type(None) for tp in _get_args(user_type)) == 1
@@ -418,6 +423,7 @@ def convert_to_beam_type(typ):
418423
# This MUST appear before the entry for the normal Tuple.
419424
_TypeMapEntry(
420425
match=match_is_named_tuple, arity=0, beam_type=typehints.Any),
426+
_TypeMapEntry(match=match_is_dataclass, arity=0, beam_type=typehints.Any),
421427
_TypeMapEntry(
422428
match=_match_is_primitive(tuple), arity=-1,
423429
beam_type=typehints.Tuple),

sdks/python/apache_beam/typehints/row_type.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing import Tuple
2727

2828
from apache_beam.typehints import typehints
29+
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
2930
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
3031
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
3132

@@ -127,6 +128,30 @@ def from_user_type(
127128
field_options=field_options,
128129
field_descriptions=field_descriptions)
129130

131+
if match_is_dataclass(user_type):
132+
import dataclasses
133+
fields = [(field.name, field.type)
134+
for field in dataclasses.fields(user_type)]
135+
136+
field_descriptions = getattr(user_type, '_field_descriptions', None)
137+
138+
if _user_type_is_generated(user_type):
139+
return RowTypeConstraint.from_fields(
140+
fields,
141+
schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
142+
schema_options=schema_options,
143+
field_options=field_options,
144+
field_descriptions=field_descriptions)
145+
146+
# TODO(https://github.com/apache/beam/issues/22125): Add user API for
147+
# specifying schema/field options
148+
return RowTypeConstraint(
149+
fields=fields,
150+
user_type=user_type,
151+
schema_options=schema_options,
152+
field_options=field_options,
153+
field_descriptions=field_descriptions)
154+
130155
return None
131156

132157
@staticmethod

sdks/python/apache_beam/typehints/schemas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
9797
from apache_beam.typehints.native_type_compatibility import convert_to_python_type
9898
from apache_beam.typehints.native_type_compatibility import extract_optional_type
99+
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
99100
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
100101
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
101102
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -629,7 +630,7 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
629630
Returns schema as a list of (name, python_type) tuples"""
630631
if isinstance(element_type, row_type.RowTypeConstraint):
631632
return named_fields_to_schema(element_type._fields)
632-
elif match_is_named_tuple(element_type):
633+
elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
633634
if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
634635
# if the named tuple's schema is in registry, we just use it instead of
635636
# regenerating one.

sdks/python/apache_beam/typehints/schemas_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
# pytype: skip-file
2121

22+
import dataclasses
2223
import itertools
2324
import pickle
2425
import unittest
@@ -388,6 +389,23 @@ def test_namedtuple_roundtrip(self, user_type):
388389
self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
389390
self.assert_namedtuple_equivalent(roundtripped.user_type, user_type)
390391

392+
def test_dataclass_roundtrip(self):
393+
@dataclasses.dataclass
394+
class SimpleDataclass:
395+
id: np.int64
396+
name: str
397+
398+
roundtripped = typing_from_runner_api(
399+
typing_to_runner_api(
400+
SimpleDataclass, schema_registry=SchemaTypeRegistry()),
401+
schema_registry=SchemaTypeRegistry())
402+
403+
self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
404+
# The roundtripped user_type is generated as a NamedTuple, so we can't test equivalence directly with the dataclass.
405+
# Instead, let's verify annotations.
406+
self.assertEqual(
407+
roundtripped.user_type.__annotations__, SimpleDataclass.__annotations__)
408+
391409
def test_row_type_constraint_to_schema(self):
392410
result_type = typing_to_runner_api(
393411
row_type.RowTypeConstraint.from_fields([
@@ -646,6 +664,48 @@ def test_trivial_example(self):
646664
expected.row_type.schema.fields,
647665
typing_to_runner_api(MyCuteClass).row_type.schema.fields)
648666

667+
def test_trivial_example_dataclass(self):
668+
@dataclasses.dataclass
669+
class MyCuteDataclass:
670+
name: str
671+
age: Optional[int]
672+
interests: List[str]
673+
height: float
674+
blob: ByteString
675+
676+
expected = schema_pb2.FieldType(
677+
row_type=schema_pb2.RowType(
678+
schema=schema_pb2.Schema(
679+
fields=[
680+
schema_pb2.Field(
681+
name='name',
682+
type=schema_pb2.FieldType(
683+
atomic_type=schema_pb2.STRING),
684+
),
685+
schema_pb2.Field(
686+
name='age',
687+
type=schema_pb2.FieldType(
688+
nullable=True, atomic_type=schema_pb2.INT64)),
689+
schema_pb2.Field(
690+
name='interests',
691+
type=schema_pb2.FieldType(
692+
array_type=schema_pb2.ArrayType(
693+
element_type=schema_pb2.FieldType(
694+
atomic_type=schema_pb2.STRING)))),
695+
schema_pb2.Field(
696+
name='height',
697+
type=schema_pb2.FieldType(
698+
atomic_type=schema_pb2.DOUBLE)),
699+
schema_pb2.Field(
700+
name='blob',
701+
type=schema_pb2.FieldType(
702+
atomic_type=schema_pb2.BYTES)),
703+
])))
704+
705+
self.assertEqual(
706+
expected.row_type.schema.fields,
707+
typing_to_runner_api(MyCuteDataclass).row_type.schema.fields)
708+
649709
def test_user_type_annotated_with_id_after_conversion(self):
650710
MyCuteClass = NamedTuple('MyCuteClass', [
651711
('name', str),

0 commit comments

Comments
 (0)