|
| 1 | +""" |
| 2 | +Shutdown-event tests for the async SourceTransform servicer. |
| 3 | +
|
| 4 | +Covers the CancelledError and BaseException handlers in SourceTransformFn. |
| 5 | +""" |
| 6 | + |
| 7 | +import asyncio |
| 8 | +from unittest import mock |
| 9 | + |
| 10 | +from pynumaflow.sourcetransformer.servicer._async_servicer import SourceTransformAsyncServicer |
| 11 | +from pynumaflow.sourcetransformer import Datum, Messages, Message |
| 12 | +from tests.testing_utils import mock_new_event_time |
| 13 | + |
| 14 | + |
| 15 | +async def async_transform_handler(keys: list[str], datum: Datum) -> Messages: |
| 16 | + return Messages(Message(datum.value, mock_new_event_time(), keys=keys)) |
| 17 | + |
| 18 | + |
| 19 | +async def _collect(async_gen): |
| 20 | + results = [] |
| 21 | + async for item in async_gen: |
| 22 | + results.append(item) |
| 23 | + return results |
| 24 | + |
| 25 | + |
| 26 | +def test_shutdown_on_cancelled_error(): |
| 27 | + """CancelledError during SourceTransformFn should set shutdown_event, no error stored.""" |
| 28 | + |
| 29 | + async def _run(): |
| 30 | + servicer = SourceTransformAsyncServicer(handler=async_transform_handler) |
| 31 | + shutdown_event = asyncio.Event() |
| 32 | + servicer.set_shutdown_event(shutdown_event) |
| 33 | + |
| 34 | + async def _cancelled_iter(): |
| 35 | + raise asyncio.CancelledError() |
| 36 | + yield |
| 37 | + |
| 38 | + ctx = mock.MagicMock() |
| 39 | + await _collect(servicer.SourceTransformFn(_cancelled_iter(), ctx)) |
| 40 | + |
| 41 | + assert shutdown_event.is_set() |
| 42 | + assert servicer._error is None |
| 43 | + |
| 44 | + asyncio.run(_run()) |
| 45 | + |
| 46 | + |
| 47 | +def test_shutdown_on_handler_error(): |
| 48 | + """BaseException in SourceTransformFn should set shutdown_event and store error.""" |
| 49 | + |
| 50 | + async def _run(): |
| 51 | + servicer = SourceTransformAsyncServicer(handler=async_transform_handler) |
| 52 | + shutdown_event = asyncio.Event() |
| 53 | + servicer.set_shutdown_event(shutdown_event) |
| 54 | + |
| 55 | + async def _error_iter(): |
| 56 | + raise RuntimeError("unexpected error") |
| 57 | + yield |
| 58 | + |
| 59 | + ctx = mock.MagicMock() |
| 60 | + await _collect(servicer.SourceTransformFn(_error_iter(), ctx)) |
| 61 | + |
| 62 | + assert shutdown_event.is_set() |
| 63 | + assert servicer._error is not None |
| 64 | + assert "unexpected error" in repr(servicer._error) |
| 65 | + |
| 66 | + asyncio.run(_run()) |
0 commit comments