Skip to content

Commit 49fa4dc

Browse files
committed
Allow non-frozen dataclass register with other coders as a backup for backward compatibility; add tests
1 parent ed6e34b commit 49fa4dc

3 files changed

Lines changed: 59 additions & 12 deletions

File tree

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
## New Features / Improvements
7070

7171
* Added support for large pipeline options via a file (Python) ([#37370](https://github.com/apache/beam/issues/37370)).
72+
* Added support for inferring schemas from dataclasses (Python) ([#22085](https://github.com/apache/beam/issues/22085)). The default coder for non-frozen dataclasses used as type hints (or set via `with_output_types`) changed to RowCoder. To preserve the old behavior (fast primitives coder), explicitly register the type with FastPrimitivesCoder.
7273

7374
## Breaking Changes
7475

sdks/python/apache_beam/typehints/native_type_compatibility.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -179,27 +179,38 @@ def match_is_named_tuple(user_type):
179179
def match_dataclass_for_row(user_type):
180180
"""Match whether the type is a dataclass handled by row coder.
181181
182-
for frozen dataclasses, only true when explicitly registered with row coder:
182+
For frozen dataclasses, only true when explicitly registered with row coder:
183183
184184
beam.coders.typecoders.registry.register_coder(
185185
MyDataClass, beam.coders.RowCoder)
186+
187+
(for backward-compatibility reasons).
188+
189+
For non-frozen dataclasses, defaults to true unless explicitly registered
190+
with a coder other than the row coder.
186191
"""
192+
187193
if not dataclasses.is_dataclass(user_type):
188194
return False
189195

190-
if not user_type.__dataclass_params__.frozen:
191-
return True
192-
196+
is_frozen = user_type.__dataclass_params__.frozen
193197
# avoid circular import
194-
# pylint: disable=wrong-import-position
195-
from apache_beam.coders.typecoders import registry as coders_registry
196-
from apache_beam.coders import RowCoder
198+
try:
199+
# pylint: disable=wrong-import-position
200+
from apache_beam.coders.typecoders import registry as coders_registry
201+
from apache_beam.coders import RowCoder
202+
except AttributeError:
203+
# coder registry not yet initialized, so no coder can be registered for this type
204+
return not is_frozen
197205

198-
# check _coders (not get_coder) to get the registered coder directly without
199-
# fallback
200-
return (
201-
user_type in coders_registry._coders and
202-
coders_registry._coders[user_type] == RowCoder)
206+
if is_frozen:
207+
return (
208+
user_type in coders_registry._coders and
209+
coders_registry._coders[user_type] == RowCoder)
210+
else:
211+
return (
212+
user_type not in coders_registry._coders or
213+
coders_registry._coders[user_type] == RowCoder)
203214

204215

205216
def _match_is_optional(user_type):

sdks/python/apache_beam/typehints/native_type_compatibility_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# pytype: skip-file
2121

2222
import collections.abc
23+
import dataclasses
2324
import enum
2425
import re
2526
import typing
@@ -33,6 +34,7 @@
3334
from apache_beam.typehints.native_type_compatibility import convert_to_python_types
3435
from apache_beam.typehints.native_type_compatibility import convert_typing_to_builtin
3536
from apache_beam.typehints.native_type_compatibility import is_any
37+
from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row
3638

3739
_TestNamedTuple = typing.NamedTuple(
3840
'_TestNamedTuple', [('age', int), ('name', bytes)])
@@ -509,6 +511,39 @@ def test_type_alias_type_unwrapped(self):
509511
self.assertEqual(
510512
typehints.Tuple[int, ...], convert_to_beam_type(AliasTuple))
511513

514+
def test_dataclass_default(self):
515+
@dataclasses.dataclass(frozen=True)
516+
class FrozenDC:
517+
foo: int
518+
519+
@dataclasses.dataclass
520+
class NonFrozenDC:
521+
foo: int
522+
523+
self.assertFalse(match_dataclass_for_row(FrozenDC))
524+
self.assertTrue(match_dataclass_for_row(NonFrozenDC))
525+
526+
def test_dataclass_registered(self):
527+
@dataclasses.dataclass(frozen=True)
528+
class FrozenRegisteredDC:
529+
foo: int
530+
531+
@dataclasses.dataclass
532+
class NonFrozenRegisteredDC:
533+
foo: int
534+
535+
# pylint: disable=wrong-import-position
536+
from apache_beam.coders import typecoders
537+
from apache_beam.coders.coders import FastPrimitivesCoder
538+
from apache_beam.coders import RowCoder
539+
540+
typecoders.registry.register_coder(FrozenRegisteredDC, RowCoder)
541+
typecoders.registry.register_coder(
542+
NonFrozenRegisteredDC, FastPrimitivesCoder)
543+
544+
self.assertTrue(match_dataclass_for_row(FrozenRegisteredDC))
545+
self.assertFalse(match_dataclass_for_row(NonFrozenRegisteredDC))
546+
512547

513548
if __name__ == '__main__':
514549
unittest.main()

0 commit comments

Comments
 (0)