Skip to content

Commit 0aa835d

Browse files
committed
Support inferring schemas from Python dataclasses
1 parent 5a6f763 commit 0aa835d

5 files changed

Lines changed: 97 additions & 12 deletions

File tree

sdks/python/apache_beam/coders/coder_impl.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"""
3131
# pytype: skip-file
3232

33+
import dataclasses
3334
import decimal
3435
import enum
3536
import itertools
@@ -67,11 +68,6 @@
6768
from apache_beam.utils.timestamp import MIN_TIMESTAMP
6869
from apache_beam.utils.timestamp import Timestamp
6970

70-
try:
71-
import dataclasses
72-
except ImportError:
73-
dataclasses = None # type: ignore
74-
7571
try:
7672
import dill
7773
except ImportError:

sdks/python/apache_beam/typehints/native_type_compatibility.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import collections
2323
import collections.abc
24+
import dataclasses
2425
import logging
2526
import sys
2627
import types
@@ -175,6 +176,10 @@ def match_is_named_tuple(user_type):
175176
hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))
176177

177178

179+
def match_is_dataclass(user_type):
180+
return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
181+
182+
178183
def _match_is_optional(user_type):
179184
return _match_is_union(user_type) and sum(
180185
tp is type(None) for tp in _get_args(user_type)) == 1
@@ -418,6 +423,7 @@ def convert_to_beam_type(typ):
418423
# This MUST appear before the entry for the normal Tuple.
419424
_TypeMapEntry(
420425
match=match_is_named_tuple, arity=0, beam_type=typehints.Any),
426+
_TypeMapEntry(match=match_is_dataclass, arity=0, beam_type=typehints.Any),
421427
_TypeMapEntry(
422428
match=_match_is_primitive(tuple), arity=-1,
423429
beam_type=typehints.Tuple),

sdks/python/apache_beam/typehints/row_type.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919

2020
from __future__ import annotations
2121

22+
import dataclasses
2223
from typing import Any
2324
from typing import Dict
2425
from typing import Optional
2526
from typing import Sequence
2627
from typing import Tuple
2728

2829
from apache_beam.typehints import typehints
30+
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
2931
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
3032
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
3133

@@ -56,18 +58,14 @@ def __init__(
5658
for guidance on creating PCollections with inferred schemas.
5759
5860
Note RowTypeConstraint does not currently store arbitrary functions for
59-
converting to/from the user type. Instead, we only support ``NamedTuple``
60-
user types and make the follow assumptions:
61+
converting to/from the user type. Instead, we support ``NamedTuple`` and
62+
``dataclasses`` user types and make the following assumptions:
6163
6264
- The user type can be constructed with field values as arguments in order
6365
(i.e. ``constructor(*field_values)``).
6466
- Field values can be accessed from instances of the user type by attribute
6567
(i.e. with ``getattr(obj, field_name)``).
6668
67-
In the future we will add support for dataclasses
68-
([#22085](https://github.com/apache/beam/issues/22085)) which also satisfy
69-
these assumptions.
70-
7169
The RowTypeConstraint constructor should not be called directly (even
7270
internally to Beam). Prefer static methods ``from_user_type`` or
7371
``from_fields``.
@@ -127,6 +125,29 @@ def from_user_type(
127125
field_options=field_options,
128126
field_descriptions=field_descriptions)
129127

128+
if match_is_dataclass(user_type):
129+
fields = [(field.name, field.type)
130+
for field in dataclasses.fields(user_type)]
131+
132+
field_descriptions = getattr(user_type, '_field_descriptions', None)
133+
134+
if _user_type_is_generated(user_type):
135+
return RowTypeConstraint.from_fields(
136+
fields,
137+
schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
138+
schema_options=schema_options,
139+
field_options=field_options,
140+
field_descriptions=field_descriptions)
141+
142+
# TODO(https://github.com/apache/beam/issues/22125): Add user API for
143+
# specifying schema/field options
144+
return RowTypeConstraint(
145+
fields=fields,
146+
user_type=user_type,
147+
schema_options=schema_options,
148+
field_options=field_options,
149+
field_descriptions=field_descriptions)
150+
130151
return None
131152

132153
@staticmethod

sdks/python/apache_beam/typehints/schemas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
9797
from apache_beam.typehints.native_type_compatibility import convert_to_python_type
9898
from apache_beam.typehints.native_type_compatibility import extract_optional_type
99+
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
99100
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
100101
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
101102
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -629,7 +630,7 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
629630
Returns the schema as a ``schema_pb2.Schema`` proto."""
630631
if isinstance(element_type, row_type.RowTypeConstraint):
631632
return named_fields_to_schema(element_type._fields)
632-
elif match_is_named_tuple(element_type):
633+
elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
633634
if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
634635
# if the named tuple's or dataclass's schema is in registry, we just use it instead of
635636
# regenerating one.

sdks/python/apache_beam/typehints/schemas_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
# pytype: skip-file
2121

22+
import dataclasses
2223
import itertools
2324
import pickle
2425
import unittest
@@ -388,6 +389,24 @@ def test_namedtuple_roundtrip(self, user_type):
388389
self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
389390
self.assert_namedtuple_equivalent(roundtripped.user_type, user_type)
390391

392+
def test_dataclass_roundtrip(self):
393+
@dataclasses.dataclass
394+
class SimpleDataclass:
395+
id: np.int64
396+
name: str
397+
398+
roundtripped = typing_from_runner_api(
399+
typing_to_runner_api(
400+
SimpleDataclass, schema_registry=SchemaTypeRegistry()),
401+
schema_registry=SchemaTypeRegistry())
402+
403+
self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
404+
# The roundtripped user_type is generated as a NamedTuple, so we can't test
405+
# equivalence directly with the dataclass.
406+
# Instead, let's verify annotations.
407+
self.assertEqual(
408+
roundtripped.user_type.__annotations__, SimpleDataclass.__annotations__)
409+
391410
def test_row_type_constraint_to_schema(self):
392411
result_type = typing_to_runner_api(
393412
row_type.RowTypeConstraint.from_fields([
@@ -646,6 +665,48 @@ def test_trivial_example(self):
646665
expected.row_type.schema.fields,
647666
typing_to_runner_api(MyCuteClass).row_type.schema.fields)
648667

668+
def test_trivial_example_dataclass(self):
669+
@dataclasses.dataclass
670+
class MyCuteDataclass:
671+
name: str
672+
age: Optional[int]
673+
interests: List[str]
674+
height: float
675+
blob: ByteString
676+
677+
expected = schema_pb2.FieldType(
678+
row_type=schema_pb2.RowType(
679+
schema=schema_pb2.Schema(
680+
fields=[
681+
schema_pb2.Field(
682+
name='name',
683+
type=schema_pb2.FieldType(
684+
atomic_type=schema_pb2.STRING),
685+
),
686+
schema_pb2.Field(
687+
name='age',
688+
type=schema_pb2.FieldType(
689+
nullable=True, atomic_type=schema_pb2.INT64)),
690+
schema_pb2.Field(
691+
name='interests',
692+
type=schema_pb2.FieldType(
693+
array_type=schema_pb2.ArrayType(
694+
element_type=schema_pb2.FieldType(
695+
atomic_type=schema_pb2.STRING)))),
696+
schema_pb2.Field(
697+
name='height',
698+
type=schema_pb2.FieldType(
699+
atomic_type=schema_pb2.DOUBLE)),
700+
schema_pb2.Field(
701+
name='blob',
702+
type=schema_pb2.FieldType(
703+
atomic_type=schema_pb2.BYTES)),
704+
])))
705+
706+
self.assertEqual(
707+
expected.row_type.schema.fields,
708+
typing_to_runner_api(MyCuteDataclass).row_type.schema.fields)
709+
649710
def test_user_type_annotated_with_id_after_conversion(self):
650711
MyCuteClass = NamedTuple('MyCuteClass', [
651712
('name', str),

0 commit comments

Comments
 (0)