@@ -86,6 +86,62 @@ def generate(num: int):
8686 | 'Count Elements' >> beam .Map (self ._check_key_type_and_count ))
8787 assert_that (result , equal_to ([10 ] * 100 ))
8888
89+ def test_group_by_key_namedtuple_union (self ):
90+ Tuple1 = typing .NamedTuple ("Tuple1" , [("id" , int )])
91+
92+ Tuple2 = typing .NamedTuple ("Tuple2" , [("id" , int ), ("name" , str )])
93+
94+ def generate (num : int ):
95+ for i in range (2 ):
96+ yield (Tuple1 (i ), num )
97+ yield (Tuple2 (i , 'a' ), num )
98+
99+ pipeline = TestPipeline (is_integration_test = False )
100+
101+ with pipeline as p :
102+ result = (
103+ p
104+ | 'Create' >> beam .Create ([i for i in range (2 )])
105+ | 'Generate' >> beam .ParDo (generate ).with_output_types (
106+ tuple [(Tuple1 | Tuple2 ), int ])
107+ | 'GBK' >> beam .GroupByKey ()
108+ | 'Count' >> beam .Map (lambda x : len (x [1 ])))
109+ assert_that (result , equal_to ([2 ] * 4 ))
110+
111+ # Union of dataclasses as type hint currently result in FastPrimitiveCoder
112+ # fails at GBK
113+ @unittest .skip ("https://github.com/apache/beam/issues/22085" )
114+ def test_group_by_key_inherited_dataclass (self ):
115+ @dataclass
116+ class DataClassInt :
117+ id : int
118+
119+ @dataclass
120+ class DataClassStr (DataClassInt ):
121+ name : str
122+
123+ beam .coders .typecoders .registry .register_coder (
124+ DataClassInt , beam .coders .RowCoder )
125+ beam .coders .typecoders .registry .register_coder (
126+ DataClassStr , beam .coders .RowCoder )
127+
128+ def generate (num : int ):
129+ for i in range (10 ):
130+ yield (DataClassInt (i ), num )
131+ yield (DataClassStr (i , 'a' ), num )
132+
133+ pipeline = TestPipeline (is_integration_test = False )
134+
135+ with pipeline as p :
136+ result = (
137+ p
138+ | 'Create' >> beam .Create ([i for i in range (2 )])
139+ | 'Generate' >> beam .ParDo (generate ).with_output_types (
140+ tuple [(DataClassInt | DataClassStr ), int ])
141+ | 'GBK' >> beam .GroupByKey ()
142+ | 'Count Elements' >> beam .Map (self ._check_key_type_and_count ))
143+ assert_that (result , equal_to ([2 ] * 4 ))
144+
89145 def test_derived_dataclass_schema_id (self ):
90146 @dataclass
91147 class BaseDataClass :
0 commit comments