Skip to content

Commit 27424ba

Browse files
committed
Migrate unittest.TestCase to plain pytest functions
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent adfad47 commit 27424ba

38 files changed

Lines changed: 4370 additions & 4843 deletions

packages/pynumaflow/tests/accumulator/test_async_accumulator.py

Lines changed: 281 additions & 300 deletions
Large diffs are not rendered by default.
Lines changed: 61 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import asyncio
22
import logging
33
import threading
4-
import unittest
54
from collections.abc import AsyncIterable
6-
from unittest.mock import patch
75

86
import grpc
9-
from grpc.aio._server import Server
7+
import pytest
108

119
from pynumaflow import setup_logging
1210
from pynumaflow.accumulator import (
@@ -20,11 +18,12 @@
2018
from tests.testing_utils import (
2119
mock_message,
2220
get_time_args,
23-
mock_terminate_on_stop,
2421
)
2522

2623
LOGGER = setup_logging(__name__)
2724

25+
SOCK_PATH = "unix:///tmp/accumulator_err.sock"
26+
2827

2928
def request_generator(count, request):
3029
for i in range(count):
@@ -58,11 +57,6 @@ def start_request() -> accumulator_pb2.AccumulatorRequest:
5857
return request
5958

6059

61-
_s: Server = None
62-
_channel = grpc.insecure_channel("unix:///tmp/accumulator_err.sock")
63-
_loop = None
64-
65-
6660
def startup_callable(loop):
6761
asyncio.set_event_loop(loop)
6862
loop.run_forever()
@@ -99,77 +93,68 @@ def NewAsyncAccumulatorError():
9993
return udfs
10094

10195

102-
@patch("psutil.Process.kill", mock_terminate_on_stop)
103-
async def start_server(udfs):
96+
async def _start_server(udfs):
10497
server = grpc.aio.server()
10598
accumulator_pb2_grpc.add_AccumulatorServicer_to_server(udfs, server)
106-
listen_addr = "unix:///tmp/accumulator_err.sock"
107-
server.add_insecure_port(listen_addr)
108-
logging.info("Starting server on %s", listen_addr)
109-
global _s
110-
_s = server
99+
server.add_insecure_port(SOCK_PATH)
100+
logging.info("Starting server on %s", SOCK_PATH)
111101
await server.start()
112-
await server.wait_for_termination()
113-
114-
115-
@patch("psutil.Process.kill", mock_terminate_on_stop)
116-
class TestAsyncAccumulatorError(unittest.TestCase):
117-
@classmethod
118-
def setUpClass(cls) -> None:
119-
global _loop
120-
loop = asyncio.new_event_loop()
121-
_loop = loop
122-
_thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True)
123-
_thread.start()
124-
udfs = NewAsyncAccumulatorError()
125-
asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop)
126-
while True:
127-
try:
128-
with grpc.insecure_channel("unix:///tmp/accumulator_err.sock") as channel:
129-
f = grpc.channel_ready_future(channel)
130-
f.result(timeout=10)
131-
if f.done():
132-
break
133-
except grpc.FutureTimeoutError as e:
134-
LOGGER.error("error trying to connect to grpc server")
135-
LOGGER.error(e)
136-
137-
@classmethod
138-
def tearDownClass(cls) -> None:
102+
return server
103+
104+
105+
@pytest.fixture(scope="module")
106+
def async_accumulator_err_server():
107+
"""Module-scoped fixture: starts an async gRPC accumulator error server."""
108+
loop = asyncio.new_event_loop()
109+
thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True)
110+
thread.start()
111+
112+
udfs = NewAsyncAccumulatorError()
113+
future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop)
114+
future.result(timeout=10)
115+
116+
# Wait for the server to be ready
117+
while True:
139118
try:
140-
_loop.stop()
141-
LOGGER.info("stopped the event loop")
142-
except Exception as e:
119+
with grpc.insecure_channel(SOCK_PATH) as channel:
120+
f = grpc.channel_ready_future(channel)
121+
f.result(timeout=10)
122+
if f.done():
123+
break
124+
except grpc.FutureTimeoutError as e:
125+
LOGGER.error("error trying to connect to grpc server")
143126
LOGGER.error(e)
144127

145-
@patch("psutil.Process.kill", mock_terminate_on_stop)
146-
def test_accumulate_partial_success(self) -> None:
147-
"""Test that the first datum is processed before error occurs"""
148-
stub = self.__stub()
149-
request = start_request()
128+
yield loop
150129

151-
try:
152-
generator_response = stub.AccumulateFn(
153-
request_iterator=request_generator(count=5, request=request)
154-
)
155-
156-
# Try to consume the generator
157-
counter = 0
158-
for response in generator_response:
159-
self.assertIsInstance(response, accumulator_pb2.AccumulatorResponse)
160-
self.assertTrue(response.payload.value.startswith(b"counter:"))
161-
counter += 1
162-
163-
self.assertEqual(counter, 1, "Expected only one successful response before error")
164-
except BaseException as err:
165-
self.assertTrue("Simulated error in accumulator handler" in str(err))
166-
return
167-
self.fail("Expected an exception.")
168-
169-
def __stub(self):
170-
return accumulator_pb2_grpc.AccumulatorStub(_channel)
171-
172-
173-
if __name__ == "__main__":
174-
logging.basicConfig(level=logging.DEBUG)
175-
unittest.main()
130+
loop.stop()
131+
LOGGER.info("stopped the event loop")
132+
133+
134+
@pytest.fixture()
135+
def accumulator_err_stub(async_accumulator_err_server):
136+
"""Returns an AccumulatorStub connected to the running async error server."""
137+
return accumulator_pb2_grpc.AccumulatorStub(grpc.insecure_channel(SOCK_PATH))
138+
139+
140+
def test_accumulate_partial_success(accumulator_err_stub) -> None:
141+
"""Test that the first datum is processed before error occurs"""
142+
request = start_request()
143+
144+
try:
145+
generator_response = accumulator_err_stub.AccumulateFn(
146+
request_iterator=request_generator(count=5, request=request)
147+
)
148+
149+
# Try to consume the generator
150+
counter = 0
151+
for response in generator_response:
152+
assert isinstance(response, accumulator_pb2.AccumulatorResponse)
153+
assert response.payload.value.startswith(b"counter:")
154+
counter += 1
155+
156+
assert counter == 1, "Expected only one successful response before error"
157+
except BaseException as err:
158+
assert "Simulated error in accumulator handler" in str(err)
159+
return
160+
pytest.fail("Expected an exception.")

0 commit comments

Comments
 (0)