Skip to content

Commit 79f886b

Browse files
committed
update tests
Signed-off-by: kohlisid <sidhant.kohli@gmail.com>
1 parent 9949d58 commit 79f886b

1 file changed

Lines changed: 52 additions & 24 deletions

File tree

tests/mapstream/test_async_map_stream.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -95,39 +95,67 @@ def tearDownClass(cls) -> None:
9595

9696
def test_map_stream(self) -> None:
9797
stub = self.__stub()
98-
generator_response = None
98+
99+
# Send >1 requests
100+
req_count = 3
99101
try:
100-
generator_response = stub.MapFn(request_iterator=request_generator(count=1, session=1))
102+
generator_response = stub.MapFn(
103+
request_iterator=request_generator(count=req_count, session=1)
104+
)
101105
except grpc.RpcError as e:
102106
logging.error(e)
107+
self.fail(f"RPC failed: {e}")
103108

109+
# First message must be the handshake
104110
handshake = next(generator_response)
105-
# assert that handshake response is received.
106111
self.assertTrue(handshake.handshake.sot)
107-
data_resp = []
108-
for r in generator_response:
109-
data_resp.append(r)
110-
111-
self.assertEqual(11, len(data_resp))
112112

113-
idx = 0
114-
while idx < len(data_resp) - 1:
113+
# Expected: 10 results per request + 1 EOT per request
114+
expected_result_msgs = req_count * 10
115+
expected_eots = req_count
116+
117+
# Prepare expected payload
118+
expected_payload = bytes(
119+
"payload:test_mock_message "
120+
"event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00",
121+
encoding="utf-8",
122+
)
123+
124+
from collections import Counter
125+
126+
id_counter = Counter()
127+
result_msg_count = 0
128+
eot_count = 0
129+
130+
for msg in generator_response:
131+
# Count EOTs wherever they show up
132+
if hasattr(msg, "status") and msg.status.eot:
133+
eot_count += 1
134+
continue
135+
136+
# Otherwise, it's a data/result message; validate payload and tally by id
137+
self.assertTrue(msg.results, "Expected results in MapResponse.")
138+
self.assertEqual(expected_payload, msg.results[0].value)
139+
id_counter[msg.id] += 1
140+
result_msg_count += 1
141+
142+
# Validate totals
143+
self.assertEqual(
144+
expected_result_msgs,
145+
result_msg_count,
146+
f"Expected {expected_result_msgs} result messages, got {result_msg_count}",
147+
)
148+
self.assertEqual(
149+
expected_eots, eot_count, f"Expected {expected_eots} EOT messages, got {eot_count}"
150+
)
151+
152+
# Validate 10 messages per request id: test-id-0..test-id-(req_count-1)
153+
for i in range(req_count):
115154
self.assertEqual(
116-
bytes(
117-
"payload:test_mock_message "
118-
"event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00",
119-
encoding="utf-8",
120-
),
121-
data_resp[idx].results[0].value,
155+
10,
156+
id_counter[f"test-id-{i}"],
157+
f"Expected 10 results for test-id-{i}, got {id_counter[f'test-id-{i}']}",
122158
)
123-
_id = data_resp[idx].id
124-
self.assertEqual(_id, "test-id-0")
125-
# capture the output from the SinkFn generator and assert.
126-
idx += 1
127-
# EOT Response
128-
self.assertEqual(data_resp[len(data_resp) - 1].status.eot, True)
129-
# 10 sink responses + 1 EOT response
130-
self.assertEqual(11, len(data_resp))
131159

132160
def test_is_ready(self) -> None:
133161
with grpc.insecure_channel("unix:///tmp/async_map_stream.sock") as channel:

0 commit comments

Comments
 (0)