Skip to content

Commit 93412af

Browse files
committed
fix: Clean shutdown for Sink Async servers using asyncio.Event
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent cba0fdc commit 93412af

3 files changed

Lines changed: 56 additions & 16 deletions

File tree

packages/pynumaflow/pynumaflow/sinker/async_server.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import asyncio
2+
import contextlib
13
import os
4+
import sys
25

36
import aiorun
47
import grpc
58

9+
from pynumaflow.info.server import write as info_server_write
610
from pynumaflow.info.types import ContainerType, ServerInfo, MINIMUM_NUMAFLOW_VERSION
711
from pynumaflow.sinker.servicer.async_servicer import AsyncSinkServicer
812
from pynumaflow.proto.sinker import sink_pb2_grpc
@@ -24,7 +28,7 @@
2428
MAX_NUM_THREADS,
2529
)
2630

27-
from pynumaflow.shared.server import NumaflowServer, start_async_server
31+
from pynumaflow.shared.server import NumaflowServer
2832
from pynumaflow.sinker._dtypes import SinkAsyncCallable
2933

3034

@@ -118,13 +122,17 @@ def __init__(
118122
]
119123

120124
self.servicer = AsyncSinkServicer(sinker_instance)
125+
self._error: BaseException | None = None
121126

122127
def start(self):
123128
"""
124129
Starter function for the async server class; it needs a separate caller
125130
so that all the async coroutines can be started from a single context
126131
"""
127132
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
133+
if self._error:
134+
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
135+
sys.exit(1)
128136

129137
async def aexec(self):
130138
"""
@@ -133,17 +141,41 @@ async def aexec(self):
133141
# As the server is async, we need to create a new server instance in the
134142
# same thread as the event loop so that all the async calls are made in the
135143
# same context
136-
# Create a new server instance, add the servicer to it and start the server
137144
server = grpc.aio.server(options=self._server_options)
138145
server.add_insecure_port(self.sock_path)
146+
147+
# The asyncio.Event must be created here (inside aexec) rather than in __init__,
148+
# because it must be bound to the running event loop that aiorun creates.
149+
# At __init__ time no event loop exists yet.
150+
shutdown_event = asyncio.Event()
151+
self.servicer.set_shutdown_event(shutdown_event)
152+
139153
sink_pb2_grpc.add_SinkServicer_to_server(self.servicer, server)
154+
140155
serv_info = ServerInfo.get_default_server_info()
141156
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sinker]
142-
await start_async_server(
143-
server_async=server,
144-
sock_path=self.sock_path,
145-
max_threads=self.max_threads,
146-
cleanup_coroutines=list(),
147-
server_info_file=self.server_info_file,
148-
server_info=serv_info,
157+
158+
await server.start()
159+
info_server_write(server_info=serv_info, info_file=self.server_info_file)
160+
161+
_LOGGER.info(
162+
"Async GRPC Server listening on: %s with max threads: %s",
163+
self.sock_path,
164+
self.max_threads,
149165
)
166+
167+
async def _watch_for_shutdown():
168+
"""Wait for the shutdown event and stop the server with a grace period."""
169+
await shutdown_event.wait()
170+
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
171+
await server.stop(5)
172+
173+
shutdown_task = asyncio.create_task(_watch_for_shutdown())
174+
await server.wait_for_termination()
175+
176+
# Propagate error so start() can exit with a non-zero code
177+
self._error = self.servicer._error
178+
179+
shutdown_task.cancel()
180+
with contextlib.suppress(asyncio.CancelledError):
181+
await shutdown_task

packages/pynumaflow/pynumaflow/sinker/servicer/async_servicer.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from google.protobuf import empty_pb2 as _empty_pb2
55
from pynumaflow.shared.asynciter import NonBlockingIterator
66

7-
from pynumaflow.shared.server import handle_async_error
7+
from pynumaflow.shared.server import update_context_err
88
from pynumaflow.sinker._dtypes import Datum, SinkAsyncCallable
99
from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2
1010
from pynumaflow.sinker.servicer.utils import (
@@ -30,6 +30,12 @@ def __init__(
3030
self.background_tasks = set()
3131
self.__sink_handler: SinkAsyncCallable = handler
3232
self.cleanup_coroutines = []
33+
self._shutdown_event: asyncio.Event | None = None
34+
self._error: BaseException | None = None
35+
36+
def set_shutdown_event(self, event: asyncio.Event):
37+
"""Wire up the shutdown event created by the server's aexec() coroutine."""
38+
self._shutdown_event = event
3339

3440
async def SinkFn(
3541
self,
@@ -82,10 +88,14 @@ async def SinkFn(
8288
datum = datum_from_sink_req(d)
8389
await req_queue.put(datum)
8490
except BaseException as err:
85-
# if there is an exception, we will mark all the responses as a failure
86-
err_msg = f"UDSinkError: {repr(err)}"
91+
err_msg = f"UDSinkError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}"
8792
_LOGGER.critical(err_msg, exc_info=True)
88-
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
93+
update_context_err(context, err, err_msg)
94+
# Store the error and signal the server to shut down gracefully
95+
# instead of killing the process via SIGKILL.
96+
self._error = err
97+
if self._shutdown_event:
98+
self._shutdown_event.set()
8999
return
90100

91101
async def __invoke_sink(

packages/pynumaflow/tests/sink/test_async_sink.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import threading
44
import unittest
55
from collections.abc import AsyncIterable
6-
from unittest.mock import patch
76

87
import grpc
98
from google.protobuf import empty_pb2 as _empty_pb2
@@ -32,7 +31,7 @@
3231
mock_fallback_message,
3332
mockenv,
3433
)
35-
from tests.testing_utils import get_time_args, mock_terminate_on_stop
34+
from tests.testing_utils import get_time_args
3635

3736
LOGGER = setup_logging(__name__)
3837

@@ -128,7 +127,6 @@ async def start_server():
128127

129128

130129
# We mock psutil's Process.kill so that the program does not exit during testing
131-
@patch("psutil.Process.kill", mock_terminate_on_stop)
132130
class TestAsyncSink(unittest.TestCase):
133131
@classmethod
134132
def setUpClass(cls) -> None:

0 commit comments

Comments
 (0)