|
1 | 1 | import asyncio |
2 | 2 | import logging |
3 | | -import threading |
4 | 3 | import unittest |
5 | 4 | from unittest.mock import patch |
6 | 5 |
|
|
23 | 22 |
|
24 | 23 | LOGGER = setup_logging(__name__) |
25 | 24 |
|
26 | | -_s: Server = None |
27 | | -server_port = "unix:///tmp/async_err_source.sock" |
28 | | -_channel = grpc.insecure_channel(server_port) |
29 | | -_loop = None |
| 25 | +# Mock the handle_async_error function to prevent process SIGKILL |
| 26 | +async def mock_handle_async_error(context, exception, exception_type): |
| 27 | + """Mock handle_async_error to prevent process termination during tests.""" |
| 28 | + from pynumaflow.shared.server import update_context_err |
30 | 29 |
|
| 30 | + err_msg = f"{exception_type}: {repr(exception)}" |
| 31 | + update_context_err(context, exception, err_msg) |
31 | 32 |
|
32 | | -def startup_callable(loop): |
33 | | - asyncio.set_event_loop(loop) |
34 | | - loop.run_forever() |
35 | 33 |
|
36 | | - |
37 | | -async def start_server(): |
38 | | - server = grpc.aio.server() |
39 | | - class_instance = AsyncSourceError() |
40 | | - server_instance = SourceAsyncServer(sourcer_instance=class_instance) |
41 | | - udfs = server_instance.servicer |
42 | | - source_pb2_grpc.add_SourceServicer_to_server(udfs, server) |
| 34 | +# We are mocking the error handler to not exit the program during testing |
| 35 | +@patch("pynumaflow.sourcer.servicer.async_servicer.handle_async_error", mock_handle_async_error) |
| 36 | +class TestAsyncServerErrorScenario(unittest.IsolatedAsyncioTestCase): |
43 | 37 | listen_addr = "unix:///tmp/async_err_source.sock" |
44 | | - server.add_insecure_port(listen_addr) |
45 | | - logging.info("Starting server on %s", listen_addr) |
46 | | - global _s |
47 | | - _s = server |
48 | | - await server.start() |
49 | | - await server.wait_for_termination() |
50 | | - |
51 | | - |
52 | | -# We are mocking the terminate function from the psutil to not exit the program during testing |
53 | | -@patch("psutil.Process.kill", mock_terminate_on_stop) |
54 | | -class TestAsyncServerErrorScenario(unittest.TestCase): |
55 | | - @classmethod |
56 | | - def setUpClass(cls) -> None: |
57 | | - global _loop |
58 | | - loop = asyncio.new_event_loop() |
59 | | - _loop = loop |
60 | | - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) |
61 | | - _thread.start() |
62 | | - asyncio.run_coroutine_threadsafe(start_server(), loop=loop) |
63 | | - while True: |
| 38 | + |
| 39 | + async def asyncSetUp(self) -> None: |
| 40 | + # Create a fresh server for each test to avoid event loop issues |
| 41 | + self.server = grpc.aio.server() |
| 42 | + class_instance = AsyncSourceError() |
| 43 | + server_instance = SourceAsyncServer(sourcer_instance=class_instance) |
| 44 | + udfs = server_instance.servicer |
| 45 | + source_pb2_grpc.add_SourceServicer_to_server(udfs, self.server) |
| 46 | + self.server.add_insecure_port(self.listen_addr) |
| 47 | + await self.server.start() |
| 48 | + |
| 49 | + # Wait for server to be ready with timeout |
| 50 | + max_attempts = 50 # 5 seconds total (50 * 0.1) |
| 51 | + attempt = 0 |
| 52 | + while attempt < max_attempts: |
64 | 53 | try: |
65 | | - with grpc.insecure_channel("unix:///tmp/async_err_source.sock") as channel: |
66 | | - f = grpc.channel_ready_future(channel) |
67 | | - f.result(timeout=10) |
68 | | - if f.done(): |
69 | | - break |
70 | | - except grpc.FutureTimeoutError as e: |
71 | | - LOGGER.error("error trying to connect to grpc server") |
72 | | - LOGGER.error(e) |
73 | | - |
74 | | - @classmethod |
75 | | - def tearDownClass(cls) -> None: |
76 | | - try: |
77 | | - _loop.stop() |
78 | | - LOGGER.info("stopped the event loop") |
79 | | - except Exception as e: |
80 | | - LOGGER.error(e) |
81 | | - |
82 | | - def test_read_error(self) -> None: |
83 | | - with grpc.insecure_channel(server_port) as channel: |
| 54 | + async with grpc.aio.insecure_channel(self.listen_addr) as channel: |
| 55 | + await channel.channel_ready() |
| 56 | + break |
| 57 | + except Exception as e: |
| 58 | + LOGGER.debug("Waiting for server to be ready, attempt %d", attempt + 1) |
| 59 | + await asyncio.sleep(0.1) |
| 60 | + attempt += 1 |
| 61 | + |
| 62 | + if attempt >= max_attempts: |
| 63 | + raise RuntimeError("Server failed to start within timeout period") |
| 64 | + |
| 65 | + async def asyncTearDown(self) -> None: |
| 66 | + # Stop the server after each test |
| 67 | + if hasattr(self, 'server') and self.server is not None: |
| 68 | + await self.server.stop(0) |
| 69 | + # Small delay to ensure socket cleanup |
| 70 | + await asyncio.sleep(0.1) |
| 71 | + |
| 72 | + async def test_read_error(self) -> None: |
| 73 | + async with grpc.aio.insecure_channel(self.listen_addr) as channel: |
84 | 74 | stub = source_pb2_grpc.SourceStub(channel) |
85 | 75 | request = read_req_source_fn() |
86 | | - generator_response = None |
87 | | - try: |
| 76 | + with self.assertRaises(grpc.RpcError) as resp_err: |
88 | 77 | generator_response = stub.ReadFn( |
89 | 78 | request_iterator=request_generator(1, request, "read") |
90 | 79 | ) |
91 | | - for _ in generator_response: |
92 | | - pass |
93 | | - except grpc.RpcError as e: |
94 | | - self.assertEqual(grpc.StatusCode.INTERNAL, e.code()) |
95 | | - self.assertTrue("Got a runtime error from read handler." in e.details()) |
96 | | - return |
97 | | - |
98 | | - self.fail("Expected an exception.") |
99 | | - |
100 | | - def test_read_handshake_error(self) -> None: |
101 | | - grpc_exception = None |
102 | | - with grpc.insecure_channel(server_port) as channel: |
| 80 | + # await anext(aiter(generator_response)) should work for Python >=3.10 |
| 81 | + [_ async for _ in generator_response] |
| 82 | + self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code()) |
| 83 | + self.assertTrue("Got a runtime error from read handler." in resp_err.exception.details()) |
| 84 | + |
| 85 | + |
| 86 | + async def test_read_handshake_error(self) -> None: |
| 87 | + async with grpc.aio.insecure_channel(self.listen_addr) as channel: |
103 | 88 | stub = source_pb2_grpc.SourceStub(channel) |
104 | 89 | request = read_req_source_fn() |
105 | | - generator_response = None |
106 | | - try: |
| 90 | + with self.assertRaises(grpc.RpcError) as resp_err: |
107 | 91 | generator_response = stub.ReadFn( |
108 | 92 | request_iterator=request_generator(1, request, "read", False) |
109 | 93 | ) |
110 | | - for _ in generator_response: |
111 | | - pass |
112 | | - except BaseException as e: |
113 | | - self.assertTrue("ReadFn: expected handshake message" in e.__str__()) |
114 | | - return |
115 | | - except grpc.RpcError as e: |
116 | | - grpc_exception = e |
117 | | - self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) |
118 | | - print(e.details()) |
119 | | - |
120 | | - self.assertIsNotNone(grpc_exception) |
121 | | - self.fail("Expected an exception.") |
122 | | - |
123 | | - def test_ack_error(self) -> None: |
124 | | - with grpc.insecure_channel(server_port) as channel: |
| 94 | + # await anext(aiter(generator_response)) should work for Python >=3.10 |
| 95 | + [_ async for _ in generator_response] |
| 96 | + self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code()) |
| 97 | + self.assertTrue("ReadFn: expected handshake message" in resp_err.exception.details()) |
| 98 | + |
| 99 | + |
| 100 | + async def test_ack_error(self) -> None: |
| 101 | + async with grpc.aio.insecure_channel(self.listen_addr) as channel: |
125 | 102 | stub = source_pb2_grpc.SourceStub(channel) |
126 | 103 | request = ack_req_source_fn() |
127 | | - try: |
| 104 | + with self.assertRaises(grpc.RpcError) as resp_err: |
128 | 105 | resp = stub.AckFn(request_iterator=request_generator(1, request, "ack")) |
129 | | - for _ in resp: |
130 | | - pass |
131 | | - except grpc.RpcError as e: |
132 | | - self.assertEqual(grpc.StatusCode.INTERNAL, e.code()) |
133 | | - self.assertTrue("Got a runtime error from ack handler." in e.details()) |
134 | | - return |
135 | | - self.fail("Expected an exception.") |
136 | | - |
137 | | - def test_ack_no_handshake_error(self) -> None: |
138 | | - with grpc.insecure_channel(server_port) as channel: |
| 106 | + [_ async for _ in resp] |
| 107 | + self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code()) |
| 108 | + self.assertTrue("Got a runtime error from ack handler." in resp_err.exception.details()) |
| 109 | + |
| 110 | + async def test_ack_no_handshake_error(self) -> None: |
| 111 | + async with grpc.aio.insecure_channel(self.listen_addr) as channel: |
139 | 112 | stub = source_pb2_grpc.SourceStub(channel) |
140 | 113 | request = ack_req_source_fn() |
141 | | - try: |
| 114 | + with self.assertRaises(grpc.RpcError) as resp_err: |
142 | 115 | resp = stub.AckFn(request_iterator=request_generator(1, request, "ack", False)) |
143 | | - for _ in resp: |
144 | | - pass |
145 | | - except BaseException as e: |
146 | | - self.assertTrue("AckFn: expected handshake message" in e.__str__()) |
147 | | - return |
148 | | - except grpc.RpcError as e: |
149 | | - self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) |
150 | | - print(e.details()) |
151 | | - self.fail("Expected an exception.") |
152 | | - |
153 | | - def test_pending_error(self) -> None: |
154 | | - with grpc.insecure_channel(server_port) as channel: |
| 116 | + [_ async for _ in resp] |
| 117 | + self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code()) |
| 118 | + self.assertTrue("AckFn: expected handshake message" in resp_err.exception.details()) |
| 119 | + |
| 120 | + async def test_pending_error(self) -> None: |
| 121 | + async with grpc.aio.insecure_channel(self.listen_addr) as channel: |
155 | 122 | stub = source_pb2_grpc.SourceStub(channel) |
156 | 123 | request = _empty_pb2.Empty() |
157 | | - try: |
158 | | - stub.PendingFn(request=request) |
159 | | - except Exception as e: |
160 | | - self.assertTrue("Got a runtime error from pending handler." in e.__str__()) |
161 | | - return |
162 | | - self.fail("Expected an exception.") |
| 124 | + with self.assertRaises(grpc.RpcError) as resp_err: |
| 125 | + await stub.PendingFn(request=request) |
| 126 | + self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code()) |
| 127 | + self.assertTrue("Got a runtime error from pending handler." in resp_err.exception.details()) |
163 | 128 |
|
164 | | - def test_partition_error(self) -> None: |
165 | | - with grpc.insecure_channel(server_port) as channel: |
| 129 | + async def test_partition_error(self) -> None: |
| 130 | + async with grpc.aio.insecure_channel(self.listen_addr) as channel: |
166 | 131 | stub = source_pb2_grpc.SourceStub(channel) |
167 | 132 | request = _empty_pb2.Empty() |
168 | | - try: |
169 | | - stub.PartitionsFn(request=request) |
170 | | - except Exception as e: |
171 | | - self.assertTrue("Got a runtime error from partition handler." in e.__str__()) |
172 | | - return |
173 | | - self.fail("Expected an exception.") |
174 | | - |
175 | | - def test_invalid_server_type(self) -> None: |
| 133 | + with self.assertRaises(grpc.RpcError) as resp_err: |
| 134 | + response = await stub.PartitionsFn(request=request) |
| 135 | + # Force evaluation of the response |
| 136 | + _ = response.result |
| 137 | + self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code()) |
| 138 | + self.assertTrue("Got a runtime error from partition handler." in resp_err.exception.details()) |
| 139 | + |
| 140 | + async def test_invalid_server_type(self) -> None: |
176 | 141 | with self.assertRaises(TypeError): |
177 | 142 | SourceAsyncServer() |
178 | 143 |
|
|
0 commit comments