Skip to content

Commit 3e5a085

Browse files
authored
Support infer types involving dataclass fields (#38548)
1 parent bdef23f commit 3e5a085

3 files changed

Lines changed: 39 additions & 0 deletions

File tree

sdks/python/apache_beam/typehints/opcodes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"""
3030
# pytype: skip-file
3131

32+
import dataclasses
3233
import inspect
3334
import logging
3435
import sys
@@ -447,6 +448,11 @@ def _getattr(o, name):
447448
return Const(BoundMethod(func, o))
448449
elif isinstance(o, row_type.RowTypeConstraint):
449450
return o.get_type_for(name)
451+
elif inspect.isclass(o) and dataclasses.is_dataclass(o):
452+
field = o.__dataclass_fields__.get(name)
453+
if field is not None:
454+
return field.type
455+
return Any
450456
else:
451457
return Any
452458

sdks/python/apache_beam/typehints/row_type_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,27 @@ class DerivedDataClass(BaseDataClass):
172172
getattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
173173
self.assertNotEqual(schema_for_derived.id, schema_for_base.id)
174174

175+
def test_dataclass_map_typehints(self):
176+
@beam.coders.typecoders.registry.register_row
177+
@dataclass(frozen=True)
178+
class MyDataClass:
179+
id: int
180+
name: str
181+
182+
p = beam.Pipeline()
183+
pa = (p | beam.Create([MyDataClass(1, "a"), MyDataClass(2, "b")]))
184+
self.assertEqual(pa.element_type, MyDataClass)
185+
186+
pb = (
187+
pa | beam.Map(
188+
lambda x: beam.Row(id=x.id, name=x.name, name_hash=hash(x.name))))
189+
self.assertTrue(
190+
isinstance(pb.element_type, row_type.GeneratedClassRowTypeConstraint))
191+
self.assertEqual(
192+
pb.element_type,
193+
row_type.GeneratedClassRowTypeConstraint(
194+
fields=[('id', int), ('name', str), ('name_hash', int)]))
195+
175196

176197
if __name__ == '__main__':
177198
unittest.main()

sdks/python/apache_beam/typehints/trivial_inference_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
# pytype: skip-file
2121

22+
import dataclasses
2223
import types
2324
import unittest
2425

@@ -487,6 +488,17 @@ def testPyCallable(self):
487488
python_callable.PythonCallableWithSource("lambda x: (x, str(x))"),
488489
[int])
489490

491+
def testDataClassFields(self):
492+
@dataclasses.dataclass
493+
class MyDataClass:
494+
id: int
495+
name: str
496+
497+
self.assertReturnType(
498+
typehints.Tuple[int, str],
499+
python_callable.PythonCallableWithSource("lambda x: (x.id, x.name)"),
500+
[MyDataClass])
501+
490502

491503
if __name__ == '__main__':
492504
unittest.main()

0 commit comments

Comments
 (0)