Skip to content

Commit f5f73ce

Browse files
committed
more tests, remove unused variables
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent abecb26 commit f5f73ce

7 files changed

Lines changed: 323 additions & 8 deletions

File tree

packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@ class SyncMapServicer(map_pb2_grpc.MapServicer):
2121
Provides the functionality for the required rpc methods.
2222
"""
2323

24-
def __init__(self, handler: MapSyncCallable, multiproc: bool = False):
24+
def __init__(self, handler: MapSyncCallable):
2525
self.__map_handler: MapSyncCallable = handler
26-
# This indicates whether the grpc server attached is multiproc or not
27-
self.multiproc = multiproc
2826
# create a thread pool for executing UDF code
2927
self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT)
3028
# Graceful shutdown: when set, a watcher thread in _run_server() calls

packages/pynumaflow/pynumaflow/mapper/multiproc_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def handler(self, keys: list[str], datum: Datum) -> Messages:
106106
# Setting the max value to 2 * CPU count
107107
# Used for multiproc server
108108
self._process_count = min(server_count, 2 * _PROCESS_COUNT)
109-
self.servicer = SyncMapServicer(handler=mapper_instance, multiproc=True)
109+
self.servicer = SyncMapServicer(handler=mapper_instance)
110110

111111
# Shared event across all worker processes for coordinated shutdown.
112112
# When any worker's servicer sets this event, all workers' watcher

packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def my_handler(keys: list[str], datum: Datum) -> Messages:
131131
# Setting the max value to 2 * CPU count
132132
# Used for multiproc server
133133
self._process_count = min(server_count, 2 * _PROCESS_COUNT)
134-
self.servicer = SourceTransformServicer(handler=source_transform_instance, multiproc=True)
134+
self.servicer = SourceTransformServicer(handler=source_transform_instance)
135135

136136
# Shared event across all worker processes for coordinated shutdown.
137137
# When any worker's servicer sets this event, all workers' watcher

packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,8 @@ class SourceTransformServicer(transform_pb2_grpc.SourceTransformServicer):
4141
Provides the functionality for the required rpc methods.
4242
"""
4343

44-
def __init__(self, handler: SourceTransformCallable, multiproc: bool = False):
44+
def __init__(self, handler: SourceTransformCallable):
4545
self.__transform_handler: SourceTransformCallable = handler
46-
# This indicates whether the grpc server attached is multiproc or not
47-
self.multiproc = multiproc
4846
# create a thread pool for executing UDF code
4947
self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT)
5048
# Graceful shutdown: when set, a watcher thread in _run_server() calls
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
Shutdown-event tests for the multiproc Map servicer.
3+
4+
These tests verify that the SyncMapServicer (as used by MapMultiprocServer)
5+
correctly sets shutdown_event on error, enabling coordinated graceful shutdown
6+
across all worker processes via the shared multiprocessing.Event.
7+
"""
8+
9+
from unittest import mock
10+
11+
import grpc
12+
from grpc import StatusCode
13+
from grpc_testing import server_from_dictionary, strict_real_time
14+
15+
from pynumaflow.mapper import MapMultiprocServer
16+
from pynumaflow.proto.mapper import map_pb2
17+
from tests.map.utils import map_handler, err_map_handler, get_test_datums
18+
19+
20+
def test_shutdown_event_set_on_handler_error():
21+
"""When the UDF handler raises, the servicer must signal the shutdown event."""
22+
server = MapMultiprocServer(mapper_instance=err_map_handler)
23+
servicer = server.servicer
24+
25+
services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer}
26+
test_server = server_from_dictionary(services, strict_real_time())
27+
28+
test_datums = get_test_datums(handshake=True)
29+
30+
method = test_server.invoke_stream_stream(
31+
method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]),
32+
invocation_metadata={},
33+
timeout=2,
34+
)
35+
36+
for d in test_datums:
37+
method.send_request(d)
38+
method.requests_closed()
39+
40+
while True:
41+
try:
42+
method.take_response()
43+
except ValueError:
44+
break
45+
46+
_, code, _ = method.termination()
47+
assert code == StatusCode.INTERNAL
48+
assert servicer.shutdown_event.is_set()
49+
assert servicer.error is not None
50+
51+
52+
def test_shutdown_event_set_on_handshake_error():
53+
"""Missing handshake must also signal the shutdown event."""
54+
server = MapMultiprocServer(mapper_instance=map_handler)
55+
servicer = server.servicer
56+
57+
services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer}
58+
test_server = server_from_dictionary(services, strict_real_time())
59+
60+
test_datums = get_test_datums(handshake=False)
61+
62+
method = test_server.invoke_stream_stream(
63+
method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]),
64+
invocation_metadata={},
65+
timeout=1,
66+
)
67+
68+
for d in test_datums:
69+
method.send_request(d)
70+
method.requests_closed()
71+
72+
while True:
73+
try:
74+
method.take_response()
75+
except ValueError:
76+
break
77+
78+
_, code, details = method.termination()
79+
assert code == StatusCode.INTERNAL
80+
assert "MapFn: expected handshake as the first message" in details
81+
assert servicer.shutdown_event.is_set()
82+
assert servicer.error is not None
83+
84+
85+
def test_shutdown_event_set_on_stream_close_before_handshake():
86+
"""grpc.RpcError on the first read (before handshake): shutdown_event set,
87+
result_queue is None so close is skipped."""
88+
server = MapMultiprocServer(mapper_instance=map_handler)
89+
servicer = server.servicer
90+
91+
def _cancelled_iter():
92+
raise grpc.RpcError()
93+
yield # make it a generator
94+
95+
responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock()))
96+
97+
assert responses == []
98+
assert servicer.shutdown_event.is_set()
99+
assert servicer.error is None
100+
101+
102+
def test_shutdown_event_set_on_stream_close_mid_processing():
103+
"""grpc.RpcError mid-processing: result_queue is closed (unblocking the handler
104+
thread) and shutdown_event is set."""
105+
server = MapMultiprocServer(mapper_instance=map_handler)
106+
servicer = server.servicer
107+
108+
test_datums = get_test_datums(handshake=True)
109+
110+
def _cancelled_iter():
111+
yield test_datums[0] # handshake
112+
yield test_datums[1] # first data message
113+
raise grpc.RpcError()
114+
115+
responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock()))
116+
117+
assert responses[0].handshake.sot
118+
assert servicer.shutdown_event.is_set()
119+
assert servicer.error is None
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""
2+
Shutdown-event tests for the SideInput servicer.
3+
4+
Verifies that the servicer sets shutdown_event and captures the error when the
5+
UDF handler raises, enabling graceful server stop via the watcher thread in
6+
_run_server() instead of a hard process kill.
7+
"""
8+
9+
from grpc import StatusCode
10+
from grpc_testing import server_from_dictionary, strict_real_time
11+
from google.protobuf import empty_pb2 as _empty_pb2
12+
13+
from pynumaflow.sideinput.servicer.servicer import SideInputServicer
14+
from pynumaflow.proto.sideinput import sideinput_pb2
15+
16+
17+
def _ok_handler():
18+
from pynumaflow.sideinput import Response
19+
20+
return Response.broadcast_message(b"test")
21+
22+
23+
def _err_handler():
24+
raise RuntimeError("Something is fishy!")
25+
26+
27+
def test_shutdown_event_set_on_handler_error():
28+
"""When the UDF handler raises, the servicer must signal the shutdown event."""
29+
servicer = SideInputServicer(handler=_err_handler)
30+
31+
services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: servicer}
32+
test_server = server_from_dictionary(services, strict_real_time())
33+
34+
method = test_server.invoke_unary_unary(
35+
method_descriptor=(
36+
sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name[
37+
"RetrieveSideInput"
38+
]
39+
),
40+
invocation_metadata={},
41+
request=_empty_pb2.Empty(),
42+
timeout=1,
43+
)
44+
45+
_, _, code, _ = method.termination()
46+
assert code == StatusCode.INTERNAL
47+
assert servicer.shutdown_event.is_set()
48+
assert servicer.error is not None
49+
50+
51+
def test_shutdown_event_not_set_on_success():
52+
"""On a successful call, shutdown_event must remain unset."""
53+
servicer = SideInputServicer(handler=_ok_handler)
54+
55+
services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: servicer}
56+
test_server = server_from_dictionary(services, strict_real_time())
57+
58+
method = test_server.invoke_unary_unary(
59+
method_descriptor=(
60+
sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name[
61+
"RetrieveSideInput"
62+
]
63+
),
64+
invocation_metadata={},
65+
request=_empty_pb2.Empty(),
66+
timeout=1,
67+
)
68+
69+
_, _, code, _ = method.termination()
70+
assert code == StatusCode.OK
71+
assert not servicer.shutdown_event.is_set()
72+
assert servicer.error is None
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""
2+
Shutdown-event tests for the multiproc SourceTransform servicer.
3+
4+
These tests verify that the SourceTransformServicer (as used by
5+
SourceTransformMultiProcServer) correctly sets shutdown_event on error,
6+
enabling coordinated graceful shutdown across all worker processes via
7+
the shared multiprocessing.Event.
8+
"""
9+
10+
from unittest import mock
11+
12+
import grpc
13+
from grpc import StatusCode
14+
from grpc_testing import server_from_dictionary, strict_real_time
15+
16+
from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer
17+
from pynumaflow.proto.sourcetransformer import transform_pb2
18+
from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums
19+
20+
21+
def test_shutdown_event_set_on_handler_error():
22+
"""When the UDF handler raises, the servicer must signal the shutdown event."""
23+
server = SourceTransformMultiProcServer(source_transform_instance=err_transform_handler)
24+
servicer = server.servicer
25+
26+
services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer}
27+
test_server = server_from_dictionary(services, strict_real_time())
28+
29+
test_datums = get_test_datums(handshake=True)
30+
31+
method = test_server.invoke_stream_stream(
32+
method_descriptor=(
33+
transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[
34+
"SourceTransformFn"
35+
]
36+
),
37+
invocation_metadata={},
38+
timeout=2,
39+
)
40+
41+
for d in test_datums:
42+
method.send_request(d)
43+
method.requests_closed()
44+
45+
while True:
46+
try:
47+
method.take_response()
48+
except ValueError:
49+
break
50+
51+
_, code, _ = method.termination()
52+
assert code == StatusCode.INTERNAL
53+
assert servicer.shutdown_event.is_set()
54+
assert servicer.error is not None
55+
56+
57+
def test_shutdown_event_set_on_handshake_error():
58+
"""Missing handshake must also signal the shutdown event."""
59+
server = SourceTransformMultiProcServer(source_transform_instance=transform_handler)
60+
servicer = server.servicer
61+
62+
services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer}
63+
test_server = server_from_dictionary(services, strict_real_time())
64+
65+
test_datums = get_test_datums(handshake=False)
66+
67+
method = test_server.invoke_stream_stream(
68+
method_descriptor=(
69+
transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[
70+
"SourceTransformFn"
71+
]
72+
),
73+
invocation_metadata={},
74+
timeout=1,
75+
)
76+
77+
for d in test_datums:
78+
method.send_request(d)
79+
method.requests_closed()
80+
81+
while True:
82+
try:
83+
method.take_response()
84+
except ValueError:
85+
break
86+
87+
_, code, details = method.termination()
88+
assert code == StatusCode.INTERNAL
89+
assert "SourceTransformFn: expected handshake message" in details
90+
assert servicer.shutdown_event.is_set()
91+
assert servicer.error is not None
92+
93+
94+
def test_shutdown_event_set_on_stream_close_before_handshake():
95+
"""grpc.RpcError on the first read (before handshake): shutdown_event set,
96+
result_queue is None so close is skipped."""
97+
server = SourceTransformMultiProcServer(source_transform_instance=transform_handler)
98+
servicer = server.servicer
99+
100+
def _cancelled_iter():
101+
raise grpc.RpcError()
102+
yield # make it a generator
103+
104+
responses = list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock()))
105+
106+
assert responses == []
107+
assert servicer.shutdown_event.is_set()
108+
assert servicer.error is None
109+
110+
111+
def test_shutdown_event_set_on_stream_close_mid_processing():
112+
"""grpc.RpcError mid-processing: result_queue is closed (unblocking the handler
113+
thread) and shutdown_event is set."""
114+
server = SourceTransformMultiProcServer(source_transform_instance=transform_handler)
115+
servicer = server.servicer
116+
117+
test_datums = get_test_datums(handshake=True)
118+
119+
def _cancelled_iter():
120+
yield test_datums[0] # handshake
121+
yield test_datums[1] # first data message
122+
raise grpc.RpcError()
123+
124+
responses = list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock()))
125+
126+
assert responses[0].handshake.sot
127+
assert servicer.shutdown_event.is_set()
128+
assert servicer.error is None

0 commit comments

Comments
 (0)