Skip to content

Commit c40a1a3

Browse files
committed
fix: Clean shutdown for Sink threaded server using threading.Event
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent 23bc5d0 commit c40a1a3

5 files changed

Lines changed: 385 additions & 262 deletions

File tree

packages/pynumaflow/pynumaflow/shared/server.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import multiprocessing
55
import os
66
import socket
7+
import threading
78
import traceback
89

910
from google.protobuf import any_pb2
@@ -18,6 +19,7 @@
1819
from pynumaflow._constants import (
1920
_LOGGER,
2021
MULTIPROC_MAP_SOCK_ADDR,
22+
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
2123
UDFType,
2224
)
2325
from pynumaflow.exceptions import SocketError
@@ -57,6 +59,7 @@ def sync_server_start(
5759
server_options=None,
5860
server_info: ServerInfo | None = None,
5961
udf_type: str = UDFType.Map,
62+
shutdown_event: threading.Event | None = None,
6063
):
6164
"""
6265
Utility function to start a sync grpc server instance.
@@ -75,6 +78,7 @@ def sync_server_start(
7578
udf_type=udf_type,
7679
server_info_file=server_info_file,
7780
server_info=server_info,
81+
shutdown_event=shutdown_event,
7882
)
7983

8084

@@ -86,10 +90,15 @@ def _run_server(
8690
udf_type: str,
8791
server_info_file: str | None = None,
8892
server_info: ServerInfo | None = None,
93+
shutdown_event: threading.Event | None = None,
8994
) -> None:
9095
"""
9196
Starts the Synchronous server instance on the given UNIX socket
9297
with given max threads. Wait for the server to terminate.
98+
99+
If *shutdown_event* is provided, a background daemon thread will wait
100+
on it and then call ``server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)``
101+
for a cooperative graceful shutdown (no process kill).
93102
"""
94103
server = grpc.server(
95104
ThreadPoolExecutor(
@@ -115,10 +124,21 @@ def _run_server(
115124
server.add_insecure_port(bind_address)
116125
# start the gRPC server
117126
server.start()
127+
118128
# Add the server information to the server info file if provided
119129
if server_info and server_info_file:
120130
info_server_write(server_info=server_info, info_file=server_info_file)
121131

132+
if shutdown_event is not None:
133+
134+
def _watch_for_shutdown():
135+
shutdown_event.wait()
136+
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
137+
server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)
138+
139+
watcher = threading.Thread(target=_watch_for_shutdown, daemon=True)
140+
watcher.start()
141+
122142
_LOGGER.info("GRPC Server listening on: %s %d", bind_address, os.getpid())
123143
server.wait_for_termination()
124144

packages/pynumaflow/pynumaflow/shared/thread_with_return.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ class ThreadWithReturnValue(Thread):
55
"""
66
A custom Thread class that allows the target function to return a value.
77
This class extends the built-in threading.Thread class.
8+
Exceptions raised by the target are captured and re-raised on join().
89
"""
910

1011
def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, verbose=None):
@@ -23,32 +24,43 @@ def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, verbo
2324
Thread.__init__(self, group, target, name, args, kwargs)
2425
# Variable to store the return value of the target function
2526
self._return = None
27+
self._exception: BaseException | None = None
2628

2729
def run(self):
2830
"""
2931
Run the thread.
3032
3133
This method is overridden from the Thread class.
3234
It calls the target function and saves the return value.
35+
If the target raises, the exception is captured for re-raising on join().
3336
"""
3437
if self._target is not None:
35-
# Execute target and store the result
36-
self._return = self._target(*self._args, **self._kwargs)
38+
try:
39+
# Execute target and store the result
40+
self._return = self._target(*self._args, **self._kwargs)
41+
except BaseException as exc:
42+
self._exception = exc
3743

3844
def join(self, *args):
3945
"""
4046
Wait for the thread to complete and return the result.
4147
4248
This method is overridden from the Thread class.
43-
It calls the parent class's join() method and then returns the stored return value.
49+
It calls the parent class's join() method, re-raises any captured
50+
exception, and then returns the stored return value.
4451
4552
Parameters:
4653
*args: Variable length argument list to pass to the join() method.
4754
4855
Returns:
4956
The return value from the target function.
57+
58+
Raises:
59+
BaseException: If the target function raised during run().
5060
"""
5161
# Call the parent class's join() method to wait for the thread to finish
5262
Thread.join(self, *args)
63+
if self._exception is not None:
64+
raise self._exception
5365
# Return the result of the target function
5466
return self._return

packages/pynumaflow/pynumaflow/sinker/server.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sys
23

34
from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION
45
from pynumaflow.sinker.servicer.sync_servicer import SyncSinkServicer
@@ -120,6 +121,7 @@ def start(self):
120121
)
121122
serv_info = ServerInfo.get_default_server_info()
122123
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sinker]
124+
123125
# Start the server
124126
sync_server_start(
125127
servicer=self.servicer,
@@ -129,4 +131,9 @@ def start(self):
129131
server_options=self._server_options,
130132
udf_type=UDFType.Sink,
131133
server_info=serv_info,
134+
shutdown_event=self.servicer._shutdown_event,
132135
)
136+
137+
if self.servicer._error:
138+
_LOGGER.critical("Server exiting due to UDF error: %s", self.servicer._error)
139+
sys.exit(1)

packages/pynumaflow/pynumaflow/sinker/servicer/sync_servicer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import threading
12
from collections.abc import Iterator
23

34

4-
from pynumaflow._constants import _LOGGER, STREAM_EOF
5+
from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING
56
from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2
6-
from pynumaflow.shared.server import exit_on_error
7+
from pynumaflow.shared.server import update_context_err
78
from pynumaflow.shared.synciter import SyncIterator
89
from pynumaflow.shared.thread_with_return import ThreadWithReturnValue
910
from pynumaflow.sinker._dtypes import SinkSyncCallable
@@ -24,6 +25,8 @@ class SyncSinkServicer(sink_pb2_grpc.SinkServicer):
2425

2526
def __init__(self, handler: SinkSyncCallable):
2627
self.handler: SinkSyncCallable = handler
28+
self._shutdown_event: threading.Event = threading.Event()
29+
self._error: BaseException | None = None
2730

2831
def SinkFn(
2932
self, request_iterator: Iterator[sink_pb2.SinkRequest], context: NumaflowServicerContext
@@ -79,10 +82,11 @@ def SinkFn(
7982
cur_task.join()
8083

8184
except BaseException as err:
82-
# Handle exceptions
83-
err_msg = f"UDSinkError: {repr(err)}"
85+
err_msg = f"UDSinkError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}"
8486
_LOGGER.critical(err_msg, exc_info=True)
85-
exit_on_error(context, err_msg)
87+
update_context_err(context, err, err_msg)
88+
self._error = err
89+
self._shutdown_event.set()
8690
return
8791

8892
def _invoke_sink(self, request_queue: SyncIterator, context: NumaflowServicerContext):
@@ -93,7 +97,6 @@ def _invoke_sink(self, request_queue: SyncIterator, context: NumaflowServicerCon
9397
except BaseException as err:
9498
err_msg = f"UDSinkError: {repr(err)}"
9599
_LOGGER.critical(err_msg, exc_info=True)
96-
exit_on_error(context, err_msg)
97100
raise err
98101

99102
def IsReady(self, request, context: NumaflowServicerContext) -> sink_pb2.ReadyResponse:

0 commit comments

Comments
 (0)