@@ -95,39 +95,67 @@ def tearDownClass(cls) -> None:
9595
9696 def test_map_stream (self ) -> None :
9797 stub = self .__stub ()
98- generator_response = None
98+
99+ # Send >1 requests
100+ req_count = 3
99101 try :
100- generator_response = stub .MapFn (request_iterator = request_generator (count = 1 , session = 1 ))
102+ generator_response = stub .MapFn (
103+ request_iterator = request_generator (count = req_count , session = 1 )
104+ )
101105 except grpc .RpcError as e :
102106 logging .error (e )
107+ self .fail (f"RPC failed: { e } " )
103108
109+ # First message must be the handshake
104110 handshake = next (generator_response )
105- # assert that handshake response is received.
106111 self .assertTrue (handshake .handshake .sot )
107- data_resp = []
108- for r in generator_response :
109- data_resp .append (r )
110-
111- self .assertEqual (11 , len (data_resp ))
112112
113- idx = 0
114- while idx < len (data_resp ) - 1 :
113+ # Expected: 10 results per request + 1 EOT per request
114+ expected_result_msgs = req_count * 10
115+ expected_eots = req_count
116+
117+ # Prepare expected payload
118+ expected_payload = bytes (
119+ "payload:test_mock_message "
120+ "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00" ,
121+ encoding = "utf-8" ,
122+ )
123+
124+ from collections import Counter
125+
126+ id_counter = Counter ()
127+ result_msg_count = 0
128+ eot_count = 0
129+
130+ for msg in generator_response :
131+ # Count EOTs wherever they show up
132+ if hasattr (msg , "status" ) and msg .status .eot :
133+ eot_count += 1
134+ continue
135+
136+ # Otherwise, it's a data/result message; validate payload and tally by id
137+ self .assertTrue (msg .results , "Expected results in MapResponse." )
138+ self .assertEqual (expected_payload , msg .results [0 ].value )
139+ id_counter [msg .id ] += 1
140+ result_msg_count += 1
141+
142+ # Validate totals
143+ self .assertEqual (
144+ expected_result_msgs ,
145+ result_msg_count ,
146+ f"Expected { expected_result_msgs } result messages, got { result_msg_count } " ,
147+ )
148+ self .assertEqual (
149+ expected_eots , eot_count , f"Expected { expected_eots } EOT messages, got { eot_count } "
150+ )
151+
152+ # Validate 10 messages per request id: test-id-0..test-id-(req_count-1)
153+ for i in range (req_count ):
115154 self .assertEqual (
116- bytes (
117- "payload:test_mock_message "
118- "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00" ,
119- encoding = "utf-8" ,
120- ),
121- data_resp [idx ].results [0 ].value ,
155+ 10 ,
156+ id_counter [f"test-id-{ i } " ],
157+ f"Expected 10 results for test-id-{ i } , got { id_counter [f'test-id-{ i } ' ]} " ,
122158 )
123- _id = data_resp [idx ].id
124- self .assertEqual (_id , "test-id-0" )
125- # capture the output from the SinkFn generator and assert.
126- idx += 1
127- # EOT Response
128- self .assertEqual (data_resp [len (data_resp ) - 1 ].status .eot , True )
129- # 10 sink responses + 1 EOT response
130- self .assertEqual (11 , len (data_resp ))
131159
132160 def test_is_ready (self ) -> None :
133161 with grpc .insecure_channel ("unix:///tmp/async_map_stream.sock" ) as channel :
0 commit comments