|
1 | 1 | import asyncio |
2 | 2 | import logging |
3 | 3 | import threading |
4 | | -import unittest |
5 | 4 | from collections.abc import AsyncIterable |
6 | | -from unittest.mock import patch |
7 | 5 |
|
8 | 6 | import grpc |
9 | | -from grpc.aio._server import Server |
| 7 | +import pytest |
10 | 8 |
|
11 | 9 | from pynumaflow import setup_logging |
12 | 10 | from pynumaflow.accumulator import ( |
|
20 | 18 | from tests.testing_utils import ( |
21 | 19 | mock_message, |
22 | 20 | get_time_args, |
23 | | - mock_terminate_on_stop, |
24 | 21 | ) |
25 | 22 |
|
26 | 23 | LOGGER = setup_logging(__name__) |
27 | 24 |
|
| 25 | +SOCK_PATH = "unix:///tmp/accumulator_err.sock" |
| 26 | + |
28 | 27 |
|
29 | 28 | def request_generator(count, request): |
30 | 29 | for i in range(count): |
@@ -58,11 +57,6 @@ def start_request() -> accumulator_pb2.AccumulatorRequest: |
58 | 57 | return request |
59 | 58 |
|
60 | 59 |
|
61 | | -_s: Server = None |
62 | | -_channel = grpc.insecure_channel("unix:///tmp/accumulator_err.sock") |
63 | | -_loop = None |
64 | | - |
65 | | - |
66 | 60 | def startup_callable(loop): |
67 | 61 | asyncio.set_event_loop(loop) |
68 | 62 | loop.run_forever() |
@@ -99,77 +93,68 @@ def NewAsyncAccumulatorError(): |
99 | 93 | return udfs |
100 | 94 |
|
101 | 95 |
|
102 | | -@patch("psutil.Process.kill", mock_terminate_on_stop) |
103 | | -async def start_server(udfs): |
| 96 | +async def _start_server(udfs): |
104 | 97 | server = grpc.aio.server() |
105 | 98 | accumulator_pb2_grpc.add_AccumulatorServicer_to_server(udfs, server) |
106 | | - listen_addr = "unix:///tmp/accumulator_err.sock" |
107 | | - server.add_insecure_port(listen_addr) |
108 | | - logging.info("Starting server on %s", listen_addr) |
109 | | - global _s |
110 | | - _s = server |
| 99 | + server.add_insecure_port(SOCK_PATH) |
| 100 | + logging.info("Starting server on %s", SOCK_PATH) |
111 | 101 | await server.start() |
112 | | - await server.wait_for_termination() |
113 | | - |
114 | | - |
115 | | -@patch("psutil.Process.kill", mock_terminate_on_stop) |
116 | | -class TestAsyncAccumulatorError(unittest.TestCase): |
117 | | - @classmethod |
118 | | - def setUpClass(cls) -> None: |
119 | | - global _loop |
120 | | - loop = asyncio.new_event_loop() |
121 | | - _loop = loop |
122 | | - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) |
123 | | - _thread.start() |
124 | | - udfs = NewAsyncAccumulatorError() |
125 | | - asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) |
126 | | - while True: |
127 | | - try: |
128 | | - with grpc.insecure_channel("unix:///tmp/accumulator_err.sock") as channel: |
129 | | - f = grpc.channel_ready_future(channel) |
130 | | - f.result(timeout=10) |
131 | | - if f.done(): |
132 | | - break |
133 | | - except grpc.FutureTimeoutError as e: |
134 | | - LOGGER.error("error trying to connect to grpc server") |
135 | | - LOGGER.error(e) |
136 | | - |
137 | | - @classmethod |
138 | | - def tearDownClass(cls) -> None: |
| 102 | + return server |
| 103 | + |
| 104 | + |
| 105 | +@pytest.fixture(scope="module") |
| 106 | +def async_accumulator_err_server(): |
| 107 | + """Module-scoped fixture: starts an async gRPC accumulator error server.""" |
| 108 | + loop = asyncio.new_event_loop() |
| 109 | + thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) |
| 110 | + thread.start() |
| 111 | + |
| 112 | + udfs = NewAsyncAccumulatorError() |
| 113 | + future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) |
| 114 | + future.result(timeout=10) |
| 115 | + |
| 116 | + # Wait for the server to be ready |
| 117 | + while True: |
139 | 118 | try: |
140 | | - _loop.stop() |
141 | | - LOGGER.info("stopped the event loop") |
142 | | - except Exception as e: |
| 119 | + with grpc.insecure_channel(SOCK_PATH) as channel: |
| 120 | + f = grpc.channel_ready_future(channel) |
| 121 | + f.result(timeout=10) |
| 122 | + if f.done(): |
| 123 | + break |
| 124 | + except grpc.FutureTimeoutError as e: |
| 125 | + LOGGER.error("error trying to connect to grpc server") |
143 | 126 | LOGGER.error(e) |
144 | 127 |
|
145 | | - @patch("psutil.Process.kill", mock_terminate_on_stop) |
146 | | - def test_accumulate_partial_success(self) -> None: |
147 | | - """Test that the first datum is processed before error occurs""" |
148 | | - stub = self.__stub() |
149 | | - request = start_request() |
| 128 | + yield loop |
150 | 129 |
|
151 | | - try: |
152 | | - generator_response = stub.AccumulateFn( |
153 | | - request_iterator=request_generator(count=5, request=request) |
154 | | - ) |
155 | | - |
156 | | - # Try to consume the generator |
157 | | - counter = 0 |
158 | | - for response in generator_response: |
159 | | - self.assertIsInstance(response, accumulator_pb2.AccumulatorResponse) |
160 | | - self.assertTrue(response.payload.value.startswith(b"counter:")) |
161 | | - counter += 1 |
162 | | - |
163 | | - self.assertEqual(counter, 1, "Expected only one successful response before error") |
164 | | - except BaseException as err: |
165 | | - self.assertTrue("Simulated error in accumulator handler" in str(err)) |
166 | | - return |
167 | | - self.fail("Expected an exception.") |
168 | | - |
169 | | - def __stub(self): |
170 | | - return accumulator_pb2_grpc.AccumulatorStub(_channel) |
171 | | - |
172 | | - |
173 | | -if __name__ == "__main__": |
174 | | - logging.basicConfig(level=logging.DEBUG) |
175 | | - unittest.main() |
| 130 | + loop.stop() |
| 131 | + LOGGER.info("stopped the event loop") |
| 132 | + |
| 133 | + |
| 134 | +@pytest.fixture() |
| 135 | +def accumulator_err_stub(async_accumulator_err_server): |
| 136 | + """Returns an AccumulatorStub connected to the running async error server.""" |
| 137 | + return accumulator_pb2_grpc.AccumulatorStub(grpc.insecure_channel(SOCK_PATH)) |
| 138 | + |
| 139 | + |
| 140 | +def test_accumulate_partial_success(accumulator_err_stub) -> None: |
| 141 | + """Test that the first datum is processed before error occurs""" |
| 142 | + request = start_request() |
| 143 | + |
| 144 | + try: |
| 145 | + generator_response = accumulator_err_stub.AccumulateFn( |
| 146 | + request_iterator=request_generator(count=5, request=request) |
| 147 | + ) |
| 148 | + |
| 149 | + # Try to consume the generator |
| 150 | + counter = 0 |
| 151 | + for response in generator_response: |
| 152 | + assert isinstance(response, accumulator_pb2.AccumulatorResponse) |
| 153 | + assert response.payload.value.startswith(b"counter:") |
| 154 | + counter += 1 |
| 155 | + |
| 156 | + assert counter == 1, "Expected only one successful response before error" |
| 157 | + except BaseException as err: |
| 158 | + assert "Simulated error in accumulator handler" in str(err) |
| 159 | + return |
| 160 | + pytest.fail("Expected an exception.") |
0 commit comments