Skip to content

Commit 388ec38

Browse files
committed
graceful shutdown for all UDFs
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent bcdc4c3 commit 388ec38

25 files changed

Lines changed: 795 additions & 209 deletions

File tree

packages/pynumaflow/pynumaflow/accumulator/async_server.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import asyncio
2+
import contextlib
13
import inspect
4+
import sys
25

36
import aiorun
47
import grpc
58

69
from pynumaflow.accumulator.servicer.async_servicer import AsyncAccumulatorServicer
10+
from pynumaflow.info.server import write as info_server_write
711
from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION
812
from pynumaflow.proto.accumulator import accumulator_pb2_grpc
913

@@ -15,6 +19,7 @@
1519
MAX_NUM_THREADS,
1620
ACCUMULATOR_SOCK_PATH,
1721
ACCUMULATOR_SERVER_INFO_FILE_PATH,
22+
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
1823
)
1924

2025
from pynumaflow.accumulator._dtypes import (
@@ -23,7 +28,7 @@
2328
Accumulator,
2429
)
2530

26-
from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server
31+
from pynumaflow.shared.server import NumaflowServer, check_instance
2732

2833

2934
def get_handler(
@@ -157,6 +162,7 @@ def __init__(
157162
]
158163
# Get the servicer instance for the async server
159164
self.servicer = AsyncAccumulatorServicer(self.accumulator_handler)
165+
self._error: BaseException | None = None
160166

161167
def start(self):
162168
"""
@@ -167,6 +173,9 @@ def start(self):
167173
"Starting Async Accumulator Server",
168174
)
169175
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
176+
if self._error:
177+
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
178+
sys.exit(1)
170179

171180
async def aexec(self):
172181
"""
@@ -176,18 +185,52 @@ async def aexec(self):
176185
# As the server is async, we need to create a new server instance in the
177186
# same thread as the event loop so that all the async calls are made in the
178187
# same context
179-
# Create a new async server instance and add the servicer to it
180188
server = grpc.aio.server(options=self._server_options)
181189
server.add_insecure_port(self.sock_path)
190+
191+
# The asyncio.Event must be created here (inside aexec) rather than in __init__,
192+
# because it must be bound to the running event loop that aiorun creates.
193+
# At __init__ time no event loop exists yet.
194+
shutdown_event = asyncio.Event()
195+
self.servicer.set_shutdown_event(shutdown_event)
196+
182197
accumulator_pb2_grpc.add_AccumulatorServicer_to_server(self.servicer, server)
183198

184199
serv_info = ServerInfo.get_default_server_info()
185200
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Accumulator]
186-
await start_async_server(
187-
server_async=server,
188-
sock_path=self.sock_path,
189-
max_threads=self.max_threads,
190-
cleanup_coroutines=list(),
191-
server_info_file=self.server_info_file,
192-
server_info=serv_info,
201+
202+
await server.start()
203+
info_server_write(server_info=serv_info, info_file=self.server_info_file)
204+
205+
_LOGGER.info(
206+
"Async GRPC Server listening on: %s with max threads: %s",
207+
self.sock_path,
208+
self.max_threads,
193209
)
210+
211+
async def _watch_for_shutdown():
212+
"""Wait for the shutdown event and stop the server with a grace period."""
213+
await shutdown_event.wait()
214+
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
215+
# Stop accepting new requests and wait for a maximum of
216+
# NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete
217+
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)
218+
219+
shutdown_task = asyncio.create_task(_watch_for_shutdown())
220+
await server.wait_for_termination()
221+
222+
# Propagate error so start() can exit with a non-zero code
223+
self._error = self.servicer._error
224+
225+
shutdown_task.cancel()
226+
with contextlib.suppress(asyncio.CancelledError):
227+
await shutdown_task
228+
229+
_LOGGER.info("Stopping event loop...")
230+
# We use aiorun to manage the event loop. The aiorun.run() runs
231+
# forever until loop.stop() is called. If we don't stop the
232+
# event loop explicitly here, the python process will not exit.
233+
# It remains stuck for 5 minutes until the liveness and readiness probes
234+
# fail enough times and k8s sends a SIGTERM
235+
asyncio.get_event_loop().stop()
236+
_LOGGER.info("Event loop stopped")

packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from google.protobuf import empty_pb2 as _empty_pb2
55

6-
from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING
6+
from pynumaflow._constants import _LOGGER, ERR_UDF_EXCEPTION_STRING
77
from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc
88
from pynumaflow.accumulator._dtypes import (
99
Datum,
@@ -13,7 +13,7 @@
1313
KeyedWindow,
1414
)
1515
from pynumaflow.accumulator.servicer.task_manager import TaskManager
16-
from pynumaflow.shared.server import handle_async_error
16+
from pynumaflow.shared.server import update_context_err
1717
from pynumaflow.types import NumaflowServicerContext
1818

1919

@@ -57,6 +57,12 @@ def __init__(
5757
):
5858
# The accumulator handler can be a function or a builder class instance.
5959
self.__accumulator_handler: AccumulatorAsyncCallable | _AccumulatorBuilderClass = handler
60+
self._shutdown_event: asyncio.Event | None = None
61+
self._error: BaseException | None = None
62+
63+
def set_shutdown_event(self, event: asyncio.Event):
64+
"""Wire up the shutdown event created by the server's aexec() coroutine."""
65+
self._shutdown_event = event
6066

6167
async def AccumulateFn(
6268
self,
@@ -104,20 +110,35 @@ async def AccumulateFn(
104110
async for msg in consumer:
105111
# If the message is an exception, we raise the exception
106112
if isinstance(msg, BaseException):
107-
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
113+
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}"
114+
_LOGGER.critical(err_msg, exc_info=True)
115+
update_context_err(context, msg, err_msg)
116+
self._error = msg
117+
if self._shutdown_event is not None:
118+
self._shutdown_event.set()
108119
return
109120
# Send window EOF response or Window result response
110121
# back to the client
111122
else:
112123
yield msg
113124
except BaseException as e:
114-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
125+
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
126+
_LOGGER.critical(err_msg, exc_info=True)
127+
update_context_err(context, e, err_msg)
128+
self._error = e
129+
if self._shutdown_event is not None:
130+
self._shutdown_event.set()
115131
return
116132
# Wait for the process_input_stream task to finish for a clean exit
117133
try:
118134
await producer
119135
except BaseException as e:
120-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
136+
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
137+
_LOGGER.critical(err_msg, exc_info=True)
138+
update_context_err(context, e, err_msg)
139+
self._error = e
140+
if self._shutdown_event is not None:
141+
self._shutdown_event.set()
121142
return
122143

123144
async def IsReady(

packages/pynumaflow/pynumaflow/batchmapper/async_server.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import asyncio
2+
import contextlib
3+
import sys
4+
15
import aiorun
26
import grpc
37

@@ -8,9 +12,11 @@
812
BATCH_MAP_SOCK_PATH,
913
MAP_SERVER_INFO_FILE_PATH,
1014
MAX_NUM_THREADS,
15+
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
1116
)
1217
from pynumaflow.batchmapper._dtypes import BatchMapCallable
1318
from pynumaflow.batchmapper.servicer.async_servicer import AsyncBatchMapServicer
19+
from pynumaflow.info.server import write as info_server_write
1420
from pynumaflow.info.types import (
1521
ServerInfo,
1622
MAP_MODE_KEY,
@@ -19,7 +25,7 @@
1925
ContainerType,
2026
)
2127
from pynumaflow.proto.mapper import map_pb2_grpc
22-
from pynumaflow.shared.server import NumaflowServer, start_async_server
28+
from pynumaflow.shared.server import NumaflowServer
2329

2430

2531
class BatchMapAsyncServer(NumaflowServer):
@@ -92,13 +98,17 @@ async def handler(
9298
]
9399

94100
self.servicer = AsyncBatchMapServicer(handler=self.batch_mapper_instance)
101+
self._error: BaseException | None = None
95102

96103
def start(self):
97104
"""
98105
Starter function for the Async Batch Map server, we need a separate caller
99106
to the aexec so that all the async coroutines can be started from a single context
100107
"""
101108
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
109+
if self._error:
110+
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
111+
sys.exit(1)
102112

103113
async def aexec(self):
104114
"""
@@ -108,25 +118,54 @@ async def aexec(self):
108118
# As the server is async, we need to create a new server instance in the
109119
# same thread as the event loop so that all the async calls are made in the
110120
# same context
111-
# Create a new async server instance and add the servicer to it
112121
server = grpc.aio.server(options=self._server_options)
113122
server.add_insecure_port(self.sock_path)
114-
map_pb2_grpc.add_MapServicer_to_server(
115-
self.servicer,
116-
server,
117-
)
118-
_LOGGER.info("Starting Batch Map Server")
123+
124+
# The asyncio.Event must be created here (inside aexec) rather than in __init__,
125+
# because it must be bound to the running event loop that aiorun creates.
126+
# At __init__ time no event loop exists yet.
127+
shutdown_event = asyncio.Event()
128+
self.servicer.set_shutdown_event(shutdown_event)
129+
130+
map_pb2_grpc.add_MapServicer_to_server(self.servicer, server)
131+
119132
serv_info = ServerInfo.get_default_server_info()
120133
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper]
121134
# Add the MAP_MODE metadata to the server info for the correct map mode
122135
serv_info.metadata[MAP_MODE_KEY] = MapMode.BatchMap
123136

124-
# Start the async server
125-
await start_async_server(
126-
server_async=server,
127-
sock_path=self.sock_path,
128-
max_threads=self.max_threads,
129-
cleanup_coroutines=list(),
130-
server_info_file=self.server_info_file,
131-
server_info=serv_info,
137+
await server.start()
138+
info_server_write(server_info=serv_info, info_file=self.server_info_file)
139+
140+
_LOGGER.info(
141+
"Async GRPC Server listening on: %s with max threads: %s",
142+
self.sock_path,
143+
self.max_threads,
132144
)
145+
146+
async def _watch_for_shutdown():
147+
"""Wait for the shutdown event and stop the server with a grace period."""
148+
await shutdown_event.wait()
149+
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
150+
# Stop accepting new requests and wait for a maximum of
151+
# NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete
152+
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)
153+
154+
shutdown_task = asyncio.create_task(_watch_for_shutdown())
155+
await server.wait_for_termination()
156+
157+
# Propagate error so start() can exit with a non-zero code
158+
self._error = self.servicer._error
159+
160+
shutdown_task.cancel()
161+
with contextlib.suppress(asyncio.CancelledError):
162+
await shutdown_task
163+
164+
_LOGGER.info("Stopping event loop...")
165+
# We use aiorun to manage the event loop. The aiorun.run() runs
166+
# forever until loop.stop() is called. If we don't stop the
167+
# event loop explicitly here, the python process will not exit.
168+
# It remains stuck for 5 minutes until the liveness and readiness probes
169+
# fail enough times and k8s sends a SIGTERM
170+
asyncio.get_event_loop().stop()
171+
_LOGGER.info("Event loop stopped")

packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pynumaflow.batchmapper._dtypes import BatchMapCallable, BatchMapError
88
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
99
from pynumaflow.shared.asynciter import NonBlockingIterator
10-
from pynumaflow.shared.server import handle_async_error
10+
from pynumaflow.shared.server import update_context_err
1111
from pynumaflow.types import NumaflowServicerContext
1212
from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING
1313

@@ -26,6 +26,12 @@ def __init__(
2626
):
2727
self.background_tasks = set()
2828
self.__batch_map_handler: BatchMapCallable = handler
29+
self._shutdown_event: asyncio.Event | None = None
30+
self._error: BaseException | None = None
31+
32+
def set_shutdown_event(self, event: asyncio.Event):
33+
"""Wire up the shutdown event created by the server's aexec() coroutine."""
34+
self._shutdown_event = event
2935

3036
async def MapFn(
3137
self,
@@ -97,8 +103,12 @@ async def MapFn(
97103
await req_queue.put(datum)
98104

99105
except BaseException as err:
100-
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
101-
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
106+
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}"
107+
_LOGGER.critical(err_msg, exc_info=True)
108+
update_context_err(context, err, err_msg)
109+
self._error = err
110+
if self._shutdown_event is not None:
111+
self._shutdown_event.set()
102112
return
103113

104114
async def IsReady(

packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pynumaflow.mapper._dtypes import MapAsyncCallable, Datum, MapError, Message, Messages
99
from pynumaflow._metadata import _user_and_system_metadata_from_proto
1010
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
11-
from pynumaflow.shared.server import handle_async_error
11+
from pynumaflow.shared.server import update_context_err
1212
from pynumaflow.types import NumaflowServicerContext
1313

1414

@@ -25,6 +25,12 @@ def __init__(
2525
):
2626
self.background_tasks = set()
2727
self.__map_handler: MapAsyncCallable = handler
28+
self._shutdown_event: asyncio.Event | None = None
29+
self._error: BaseException | None = None
30+
31+
def set_shutdown_event(self, event: asyncio.Event):
32+
"""Wire up the shutdown event created by the server's aexec() coroutine."""
33+
self._shutdown_event = event
2834

2935
async def MapFn(
3036
self,
@@ -57,16 +63,25 @@ async def MapFn(
5763
async for msg in consumer:
5864
# If the message is an exception, we raise the exception
5965
if isinstance(msg, BaseException):
60-
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
66+
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}"
67+
_LOGGER.critical(err_msg, exc_info=True)
68+
update_context_err(context, msg, err_msg)
69+
self._error = msg
70+
if self._shutdown_event is not None:
71+
self._shutdown_event.set()
6172
return
6273
# Send window response back to the client
6374
else:
6475
yield msg
6576
# wait for the producer task to complete
6677
await producer
6778
except BaseException as e:
68-
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
69-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
79+
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
80+
_LOGGER.critical(err_msg, exc_info=True)
81+
update_context_err(context, e, err_msg)
82+
self._error = e
83+
if self._shutdown_event is not None:
84+
self._shutdown_event.set()
7085
return
7186

7287
async def _process_inputs(

0 commit comments

Comments
 (0)