|
3 | 3 | import threading |
4 | 4 | import unittest |
5 | 5 | from collections.abc import AsyncIterable |
| 6 | +from unittest.mock import MagicMock |
6 | 7 | import grpc |
7 | 8 | from grpc.aio._server import Server |
8 | 9 |
|
|
16 | 17 | Metadata, |
17 | 18 | ) |
18 | 19 | from pynumaflow.proto.reducer import reduce_pb2, reduce_pb2_grpc |
| 20 | +from pynumaflow.reducestreamer.servicer.async_servicer import AsyncReduceStreamServicer |
19 | 21 | from pynumaflow.shared.asynciter import NonBlockingIterator |
20 | 22 | from tests.testing_utils import ( |
21 | 23 | mock_message, |
@@ -217,6 +219,66 @@ def __stub(self): |
217 | 219 | return reduce_pb2_grpc.ReduceStub(_channel) |
218 | 220 |
|
219 | 221 |
|
| 222 | +async def _emit_one_handler(keys, datums, output, md): |
| 223 | + """Handler that emits one message eagerly, then blocks reading remaining datums.""" |
| 224 | + await output.put(Message(b"result", keys=keys)) |
| 225 | + async for _ in datums: |
| 226 | + pass |
| 227 | + |
| 228 | + |
| 229 | +def test_cancelled_error_in_consumer_loop(): |
| 230 | + """athrow(CancelledError) at the yield point exercises the except CancelledError branch.""" |
| 231 | + servicer = AsyncReduceStreamServicer(_emit_one_handler) |
| 232 | + shutdown_event = asyncio.Event() |
| 233 | + servicer.set_shutdown_event(shutdown_event) |
| 234 | + request, _ = start_request(multiple_window=False) |
| 235 | + |
| 236 | + async def _run(): |
| 237 | + async def requests(): |
| 238 | + yield request |
| 239 | + await asyncio.sleep(999) |
| 240 | + |
| 241 | + gen = servicer.ReduceFn(requests(), MagicMock()) |
| 242 | + # Drive the pipeline until the handler's message is yielded. |
| 243 | + await gen.__anext__() |
| 244 | + # Simulate task cancellation (e.g. SIGTERM) at the yield point. |
| 245 | + try: |
| 246 | + await gen.athrow(asyncio.CancelledError()) |
| 247 | + except StopAsyncIteration: |
| 248 | + pass |
| 249 | + |
| 250 | + asyncio.run(_run()) |
| 251 | + assert shutdown_event.is_set() |
| 252 | + assert servicer._error is None |
| 253 | + |
| 254 | + |
| 255 | +def test_base_exception_in_consumer_loop(): |
| 256 | + """athrow(ValueError) at the yield point exercises the except BaseException branch.""" |
| 257 | + servicer = AsyncReduceStreamServicer(_emit_one_handler) |
| 258 | + shutdown_event = asyncio.Event() |
| 259 | + servicer.set_shutdown_event(shutdown_event) |
| 260 | + request, _ = start_request(multiple_window=False) |
| 261 | + |
| 262 | + async def _run(): |
| 263 | + async def requests(): |
| 264 | + yield request |
| 265 | + await asyncio.sleep(999) |
| 266 | + |
| 267 | + ctx = MagicMock() |
| 268 | + gen = servicer.ReduceFn(requests(), ctx) |
| 269 | + await gen.__anext__() |
| 270 | + try: |
| 271 | + await gen.athrow(ValueError("boom")) |
| 272 | + except StopAsyncIteration: |
| 273 | + pass |
| 274 | + return ctx |
| 275 | + |
| 276 | + ctx = asyncio.run(_run()) |
| 277 | + assert shutdown_event.is_set() |
| 278 | + assert isinstance(servicer._error, ValueError) |
| 279 | + ctx.set_code.assert_called_once_with(grpc.StatusCode.INTERNAL) |
| 280 | + |
| 281 | + |
220 | 282 | if __name__ == "__main__": |
221 | 283 | logging.basicConfig(level=logging.DEBUG) |
222 | 284 | unittest.main() |
0 commit comments