Skip to content

Commit 63fce82

Browse files
committed
Fix #37862: fixed named tuple and effectively fails dataclass inside union typehint
1 parent f54be68 commit 63fce82

2 files changed

Lines changed: 64 additions & 1 deletion

File tree

sdks/python/apache_beam/typehints/row_type_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,62 @@ def generate(num: int):
8686
| 'Count Elements' >> beam.Map(self._check_key_type_and_count))
8787
assert_that(result, equal_to([10] * 100))
8888

89+
def test_group_by_key_namedtuple_union(self):
90+
Tuple1 = typing.NamedTuple("Tuple1", [("id", int)])
91+
92+
Tuple2 = typing.NamedTuple("Tuple2", [("id", int), ("name", str)])
93+
94+
def generate(num: int):
95+
for i in range(2):
96+
yield (Tuple1(i), num)
97+
yield (Tuple2(i, 'a'), num)
98+
99+
pipeline = TestPipeline(is_integration_test=False)
100+
101+
with pipeline as p:
102+
result = (
103+
p
104+
| 'Create' >> beam.Create([i for i in range(2)])
105+
| 'Generate' >> beam.ParDo(generate).with_output_types(
106+
tuple[(Tuple1 | Tuple2), int])
107+
| 'GBK' >> beam.GroupByKey()
108+
| 'Count' >> beam.Map(lambda x: len(x[1])))
109+
assert_that(result, equal_to([2] * 4))
110+
111+
# Union of dataclasses as type hint currently result in FastPrimitiveCoder
112+
# fails at GBK
113+
@unittest.skip("https://github.com/apache/beam/issues/22085")
114+
def test_group_by_key_inherited_dataclass(self):
115+
@dataclass
116+
class DataClassInt:
117+
id: int
118+
119+
@dataclass
120+
class DataClassStr(DataClassInt):
121+
name: str
122+
123+
beam.coders.typecoders.registry.register_coder(
124+
DataClassInt, beam.coders.RowCoder)
125+
beam.coders.typecoders.registry.register_coder(
126+
DataClassStr, beam.coders.RowCoder)
127+
128+
def generate(num: int):
129+
for i in range(10):
130+
yield (DataClassInt(i), num)
131+
yield (DataClassStr(i, 'a'), num)
132+
133+
pipeline = TestPipeline(is_integration_test=False)
134+
135+
with pipeline as p:
136+
result = (
137+
p
138+
| 'Create' >> beam.Create([i for i in range(2)])
139+
| 'Generate' >> beam.ParDo(generate).with_output_types(
140+
tuple[(DataClassInt | DataClassStr), int])
141+
| 'GBK' >> beam.GroupByKey()
142+
| 'Count Elements' >> beam.Map(self._check_key_type_and_count))
143+
assert_that(result, equal_to([2] * 4))
144+
89145
def test_derived_dataclass_schema_id(self):
90146
@dataclass
91147
class BaseDataClass:

sdks/python/apache_beam/typehints/schemas.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,8 +663,15 @@ def union_schema_type(element_types):
663663
element_types must be a set of schema-aware types whose fields have the
664664
same naming and ordering.
665665
"""
666+
named_fields_and_types = []
667+
for t in element_types:
668+
n = named_fields_from_element_type(t)
669+
if named_fields_and_types and len(named_fields_and_types[-1]) != len(n):
670+
raise TypeError("element types has different number of fields")
671+
named_fields_and_types.append(n)
672+
666673
union_fields_and_types = []
667-
for field in zip(*[named_fields_from_element_type(t) for t in element_types]):
674+
for field in zip(*named_fields_and_types):
668675
names, types = zip(*field)
669676
name_set = set(names)
670677
if len(name_set) != 1:

0 commit comments

Comments
 (0)