Skip to content

Commit bb27dc8

Browse files
committed
Use async methods for tests, avoid sigkill when running single module tests
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent 2f022e4 commit bb27dc8

1 file changed

Lines changed: 90 additions & 125 deletions

File tree

tests/source/test_async_source_err.py

Lines changed: 90 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import logging
3-
import threading
43
import unittest
54
from unittest.mock import patch
65

@@ -23,156 +22,122 @@
2322

2423
LOGGER = setup_logging(__name__)
2524

26-
_s: Server = None
27-
server_port = "unix:///tmp/async_err_source.sock"
28-
_channel = grpc.insecure_channel(server_port)
29-
_loop = None
25+
# Mock the handle_async_error function to prevent process SIGKILL
26+
async def mock_handle_async_error(context, exception, exception_type):
27+
"""Mock handle_async_error to prevent process termination during tests."""
28+
from pynumaflow.shared.server import update_context_err
3029

30+
err_msg = f"{exception_type}: {repr(exception)}"
31+
update_context_err(context, exception, err_msg)
3132

32-
def startup_callable(loop):
33-
asyncio.set_event_loop(loop)
34-
loop.run_forever()
3533

36-
37-
async def start_server():
38-
server = grpc.aio.server()
39-
class_instance = AsyncSourceError()
40-
server_instance = SourceAsyncServer(sourcer_instance=class_instance)
41-
udfs = server_instance.servicer
42-
source_pb2_grpc.add_SourceServicer_to_server(udfs, server)
34+
# We are mocking the error handler to not exit the program during testing
35+
@patch("pynumaflow.sourcer.servicer.async_servicer.handle_async_error", mock_handle_async_error)
36+
class TestAsyncServerErrorScenario(unittest.IsolatedAsyncioTestCase):
4337
listen_addr = "unix:///tmp/async_err_source.sock"
44-
server.add_insecure_port(listen_addr)
45-
logging.info("Starting server on %s", listen_addr)
46-
global _s
47-
_s = server
48-
await server.start()
49-
await server.wait_for_termination()
50-
51-
52-
# We are mocking the terminate function from the psutil to not exit the program during testing
53-
@patch("psutil.Process.kill", mock_terminate_on_stop)
54-
class TestAsyncServerErrorScenario(unittest.TestCase):
55-
@classmethod
56-
def setUpClass(cls) -> None:
57-
global _loop
58-
loop = asyncio.new_event_loop()
59-
_loop = loop
60-
_thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True)
61-
_thread.start()
62-
asyncio.run_coroutine_threadsafe(start_server(), loop=loop)
63-
while True:
38+
39+
async def asyncSetUp(self) -> None:
40+
# Create a fresh server for each test to avoid event loop issues
41+
self.server = grpc.aio.server()
42+
class_instance = AsyncSourceError()
43+
server_instance = SourceAsyncServer(sourcer_instance=class_instance)
44+
udfs = server_instance.servicer
45+
source_pb2_grpc.add_SourceServicer_to_server(udfs, self.server)
46+
self.server.add_insecure_port(self.listen_addr)
47+
await self.server.start()
48+
49+
# Wait for server to be ready with timeout
50+
max_attempts = 50 # 5 seconds total (50 * 0.1)
51+
attempt = 0
52+
while attempt < max_attempts:
6453
try:
65-
with grpc.insecure_channel("unix:///tmp/async_err_source.sock") as channel:
66-
f = grpc.channel_ready_future(channel)
67-
f.result(timeout=10)
68-
if f.done():
69-
break
70-
except grpc.FutureTimeoutError as e:
71-
LOGGER.error("error trying to connect to grpc server")
72-
LOGGER.error(e)
73-
74-
@classmethod
75-
def tearDownClass(cls) -> None:
76-
try:
77-
_loop.stop()
78-
LOGGER.info("stopped the event loop")
79-
except Exception as e:
80-
LOGGER.error(e)
81-
82-
def test_read_error(self) -> None:
83-
with grpc.insecure_channel(server_port) as channel:
54+
async with grpc.aio.insecure_channel(self.listen_addr) as channel:
55+
await channel.channel_ready()
56+
break
57+
except Exception as e:
58+
LOGGER.debug("Waiting for server to be ready, attempt %d", attempt + 1)
59+
await asyncio.sleep(0.1)
60+
attempt += 1
61+
62+
if attempt >= max_attempts:
63+
raise RuntimeError("Server failed to start within timeout period")
64+
65+
async def asyncTearDown(self) -> None:
66+
# Stop the server after each test
67+
if hasattr(self, 'server') and self.server is not None:
68+
await self.server.stop(0)
69+
# Small delay to ensure socket cleanup
70+
await asyncio.sleep(0.1)
71+
72+
async def test_read_error(self) -> None:
73+
async with grpc.aio.insecure_channel(self.listen_addr) as channel:
8474
stub = source_pb2_grpc.SourceStub(channel)
8575
request = read_req_source_fn()
86-
generator_response = None
87-
try:
76+
with self.assertRaises(grpc.RpcError) as resp_err:
8877
generator_response = stub.ReadFn(
8978
request_iterator=request_generator(1, request, "read")
9079
)
91-
for _ in generator_response:
92-
pass
93-
except grpc.RpcError as e:
94-
self.assertEqual(grpc.StatusCode.INTERNAL, e.code())
95-
self.assertTrue("Got a runtime error from read handler." in e.details())
96-
return
97-
98-
self.fail("Expected an exception.")
99-
100-
def test_read_handshake_error(self) -> None:
101-
grpc_exception = None
102-
with grpc.insecure_channel(server_port) as channel:
80+
# await anext(aiter(generator_response)) should work for Python >=3.10
81+
[_ async for _ in generator_response]
82+
self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code())
83+
self.assertTrue("Got a runtime error from read handler." in resp_err.exception.details())
84+
85+
86+
async def test_read_handshake_error(self) -> None:
87+
async with grpc.aio.insecure_channel(self.listen_addr) as channel:
10388
stub = source_pb2_grpc.SourceStub(channel)
10489
request = read_req_source_fn()
105-
generator_response = None
106-
try:
90+
with self.assertRaises(grpc.RpcError) as resp_err:
10791
generator_response = stub.ReadFn(
10892
request_iterator=request_generator(1, request, "read", False)
10993
)
110-
for _ in generator_response:
111-
pass
112-
except BaseException as e:
113-
self.assertTrue("ReadFn: expected handshake message" in e.__str__())
114-
return
115-
except grpc.RpcError as e:
116-
grpc_exception = e
117-
self.assertEqual(grpc.StatusCode.UNKNOWN, e.code())
118-
print(e.details())
119-
120-
self.assertIsNotNone(grpc_exception)
121-
self.fail("Expected an exception.")
122-
123-
def test_ack_error(self) -> None:
124-
with grpc.insecure_channel(server_port) as channel:
94+
# await anext(aiter(generator_response)) should work for Python >=3.10
95+
[_ async for _ in generator_response]
96+
self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code())
97+
self.assertTrue("ReadFn: expected handshake message" in resp_err.exception.details())
98+
99+
100+
async def test_ack_error(self) -> None:
101+
async with grpc.aio.insecure_channel(self.listen_addr) as channel:
125102
stub = source_pb2_grpc.SourceStub(channel)
126103
request = ack_req_source_fn()
127-
try:
104+
with self.assertRaises(grpc.RpcError) as resp_err:
128105
resp = stub.AckFn(request_iterator=request_generator(1, request, "ack"))
129-
for _ in resp:
130-
pass
131-
except grpc.RpcError as e:
132-
self.assertEqual(grpc.StatusCode.INTERNAL, e.code())
133-
self.assertTrue("Got a runtime error from ack handler." in e.details())
134-
return
135-
self.fail("Expected an exception.")
136-
137-
def test_ack_no_handshake_error(self) -> None:
138-
with grpc.insecure_channel(server_port) as channel:
106+
[_ async for _ in resp]
107+
self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code())
108+
self.assertTrue("Got a runtime error from ack handler." in resp_err.exception.details())
109+
110+
async def test_ack_no_handshake_error(self) -> None:
111+
async with grpc.aio.insecure_channel(self.listen_addr) as channel:
139112
stub = source_pb2_grpc.SourceStub(channel)
140113
request = ack_req_source_fn()
141-
try:
114+
with self.assertRaises(grpc.RpcError) as resp_err:
142115
resp = stub.AckFn(request_iterator=request_generator(1, request, "ack", False))
143-
for _ in resp:
144-
pass
145-
except BaseException as e:
146-
self.assertTrue("AckFn: expected handshake message" in e.__str__())
147-
return
148-
except grpc.RpcError as e:
149-
self.assertEqual(grpc.StatusCode.UNKNOWN, e.code())
150-
print(e.details())
151-
self.fail("Expected an exception.")
152-
153-
def test_pending_error(self) -> None:
154-
with grpc.insecure_channel(server_port) as channel:
116+
[_ async for _ in resp]
117+
self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code())
118+
self.assertTrue("AckFn: expected handshake message" in resp_err.exception.details())
119+
120+
async def test_pending_error(self) -> None:
121+
async with grpc.aio.insecure_channel(self.listen_addr) as channel:
155122
stub = source_pb2_grpc.SourceStub(channel)
156123
request = _empty_pb2.Empty()
157-
try:
158-
stub.PendingFn(request=request)
159-
except Exception as e:
160-
self.assertTrue("Got a runtime error from pending handler." in e.__str__())
161-
return
162-
self.fail("Expected an exception.")
124+
with self.assertRaises(grpc.RpcError) as resp_err:
125+
await stub.PendingFn(request=request)
126+
self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code())
127+
self.assertTrue("Got a runtime error from pending handler." in resp_err.exception.details())
163128

164-
def test_partition_error(self) -> None:
165-
with grpc.insecure_channel(server_port) as channel:
129+
async def test_partition_error(self) -> None:
130+
async with grpc.aio.insecure_channel(self.listen_addr) as channel:
166131
stub = source_pb2_grpc.SourceStub(channel)
167132
request = _empty_pb2.Empty()
168-
try:
169-
stub.PartitionsFn(request=request)
170-
except Exception as e:
171-
self.assertTrue("Got a runtime error from partition handler." in e.__str__())
172-
return
173-
self.fail("Expected an exception.")
174-
175-
def test_invalid_server_type(self) -> None:
133+
with self.assertRaises(grpc.RpcError) as resp_err:
134+
response = await stub.PartitionsFn(request=request)
135+
# Force evaluation of the response
136+
_ = response.result
137+
self.assertEqual(grpc.StatusCode.INTERNAL, resp_err.exception.code())
138+
self.assertTrue("Got a runtime error from partition handler." in resp_err.exception.details())
139+
140+
async def test_invalid_server_type(self) -> None:
176141
with self.assertRaises(TypeError):
177142
SourceAsyncServer()
178143

0 commit comments

Comments
 (0)