Skip to content

Commit 855e3a9

Browse files
authored
Merge branch 'main' into clean-shutdown-unarymap
2 parents 072be7e + b37f25f commit 855e3a9

7 files changed

Lines changed: 263 additions & 46 deletions

File tree

packages/pynumaflow/pynumaflow/accumulator/_dtypes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ class WindowOperation(IntEnum):
1717
Enumerate the type of Window operation received.
1818
"""
1919

20-
OPEN = (0,)
21-
CLOSE = (1,)
22-
APPEND = (2,)
20+
OPEN = 0
21+
CLOSE = 1
22+
APPEND = 2
2323

2424

2525
@dataclass(init=False, slots=True)

packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
_AccumulatorBuilderClass,
1515
AccumulatorAsyncCallable,
1616
WindowOperation,
17+
AccumulatorRequest,
1718
)
1819
from pynumaflow.proto.accumulator import accumulator_pb2
1920
from pynumaflow.shared.asynciter import NonBlockingIterator
@@ -93,7 +94,7 @@ async def stream_send_eof(self):
9394
for unified_key in task_keys:
9495
await self.tasks[unified_key].iterator.put(STREAM_EOF)
9596

96-
async def close_task(self, req):
97+
async def close_task(self, req: AccumulatorRequest):
9798
"""
9899
Closes a running accumulator task for a given key.
99100
Based on the request we compute the unique key, and then
@@ -104,8 +105,9 @@ async def close_task(self, req):
104105
3. Wait for all the results from the task to be written to the global result queue
105106
4. Remove the task from the tracker
106107
"""
107-
d = req.payload
108-
keys = d.keys
108+
# Use keyed_window.keys for task lookup since payload.keys may be empty
109+
# (e.g., CLOSE operations don't carry data, so payload.keys is not populated).
110+
keys = req.keyed_window.keys
109111
unified_key = build_unique_key_name(keys)
110112
curr_task = self.tasks.get(unified_key, None)
111113

@@ -120,14 +122,16 @@ async def close_task(self, req):
120122
# Put the exception in the result queue
121123
await self.global_result_queue.put(err)
122124

123-
async def create_task(self, req):
125+
async def create_task(self, req: AccumulatorRequest):
124126
"""
125127
Creates a new accumulator task for the given request.
126128
Based on the request we compute a unique key, and then
127129
it creates a new task or appends the request to the existing task.
128130
"""
129131
d = req.payload
130-
keys = d.keys
132+
# Use keyed_window.keys for task lookup — the authoritative key identity
133+
# for the window, consistent across all operation types (OPEN, APPEND, CLOSE).
134+
keys = req.keyed_window.keys
131135
unified_key = build_unique_key_name(keys)
132136
curr_task = self.tasks.get(unified_key, None)
133137

@@ -138,7 +142,7 @@ async def create_task(self, req):
138142
# Create a new result queue for the current task
139143
# We create a new result queue for each task, so that
140144
# the results of the accumulator operation can be sent to the
141-
# the global result queue, which in turn sends the results
145+
# global result queue, which in turn sends the results
142146
# to the client.
143147
res_queue = NonBlockingIterator()
144148

@@ -172,13 +176,14 @@ async def create_task(self, req):
172176
# Put the request in the iterator
173177
await curr_task.iterator.put(d)
174178

175-
async def send_datum_to_task(self, req):
179+
async def send_datum_to_task(self, req: AccumulatorRequest):
176180
"""
177181
Appends the request to the existing accumulator task.
178182
If the task does not exist, create it.
179183
"""
180184
d = req.payload
181-
keys = d.keys
185+
# Use keyed_window.keys for task lookup to match the key used in create_task/close_task.
186+
keys = req.keyed_window.keys
182187
unified_key = build_unique_key_name(keys)
183188
result = self.tasks.get(unified_key, None)
184189
if not result:
@@ -215,9 +220,7 @@ async def __invoke_accumulator(
215220
# Put the exception in the result queue
216221
await self.global_result_queue.put(err)
217222

218-
async def process_input_stream(
219-
self, request_iterator: AsyncIterable[accumulator_pb2.AccumulatorRequest]
220-
):
223+
async def process_input_stream(self, request_iterator: AsyncIterable[AccumulatorRequest]):
221224
# Start iterating through the request iterator and create tasks
222225
# based on the operation type received.
223226
try:
@@ -226,15 +229,15 @@ async def process_input_stream(
226229
request_count += 1
227230
# check whether the request is an open, append, or close operation
228231
match request.operation:
229-
case int(WindowOperation.OPEN):
232+
case WindowOperation.OPEN:
230233
# create a new task for the open operation and
231234
# put the request in the task iterator
232235
await self.create_task(request)
233-
case int(WindowOperation.APPEND):
236+
case WindowOperation.APPEND:
234237
# append the task data to the existing task
235238
# if the task does not exist, create a new task
236239
await self.send_datum_to_task(request)
237-
case int(WindowOperation.CLOSE):
240+
case WindowOperation.CLOSE:
238241
# close the current task for req
239242
await self.close_task(request)
240243
case _:

packages/pynumaflow/pynumaflow/reducestreamer/async_server.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import asyncio
2+
import contextlib
13
import inspect
4+
import sys
25

36
import aiorun
47
import grpc
58

9+
from pynumaflow.info.server import write as info_server_write
610
from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION
711
from pynumaflow.proto.reducer import reduce_pb2_grpc
812

@@ -15,6 +19,7 @@
1519
REDUCE_STREAM_SOCK_PATH,
1620
REDUCE_STREAM_SERVER_INFO_FILE_PATH,
1721
MAX_NUM_THREADS,
22+
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
1823
)
1924

2025
from pynumaflow.reducestreamer._dtypes import (
@@ -23,7 +28,7 @@
2328
ReduceStreamer,
2429
)
2530

26-
from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server
31+
from pynumaflow.shared.server import NumaflowServer, check_instance
2732

2833

2934
def get_handler(
@@ -156,6 +161,7 @@ def __init__(
156161
]
157162
# Get the servicer instance for the async server
158163
self.servicer = AsyncReduceStreamServicer(self.reduce_stream_handler)
164+
self._error: BaseException | None = None
159165

160166
def start(self):
161167
"""
@@ -166,6 +172,9 @@ def start(self):
166172
"Starting Async Reduce Stream Server",
167173
)
168174
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
175+
if self._error:
176+
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
177+
sys.exit(1)
169178

170179
async def aexec(self):
171180
"""
@@ -178,15 +187,42 @@ async def aexec(self):
178187
# Create a new async server instance and add the servicer to it
179188
server = grpc.aio.server(options=self._server_options)
180189
server.add_insecure_port(self.sock_path)
190+
191+
# The asyncio.Event must be created here (inside aexec) rather than in __init__,
192+
# because it must be bound to the running event loop that aiorun creates.
193+
shutdown_event = asyncio.Event()
194+
self.servicer.set_shutdown_event(shutdown_event)
195+
181196
reduce_pb2_grpc.add_ReduceServicer_to_server(self.servicer, server)
182197

183198
serv_info = ServerInfo.get_default_server_info()
184199
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Reducestreamer]
185-
await start_async_server(
186-
server_async=server,
187-
sock_path=self.sock_path,
188-
max_threads=self.max_threads,
189-
cleanup_coroutines=list(),
190-
server_info_file=self.server_info_file,
191-
server_info=serv_info,
200+
201+
await server.start()
202+
info_server_write(server_info=serv_info, info_file=self.server_info_file)
203+
204+
_LOGGER.info(
205+
"Async GRPC Reduce Stream Server listening on: %s with max threads: %s",
206+
self.sock_path,
207+
self.max_threads,
192208
)
209+
210+
async def _watch_for_shutdown():
211+
"""Wait for the shutdown event and stop the server with a grace period."""
212+
await shutdown_event.wait()
213+
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
214+
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)
215+
216+
shutdown_task = asyncio.create_task(_watch_for_shutdown())
217+
await server.wait_for_termination()
218+
219+
# Propagate error so start() can exit with a non-zero code
220+
self._error = self.servicer._error
221+
222+
shutdown_task.cancel()
223+
with contextlib.suppress(asyncio.CancelledError):
224+
await shutdown_task
225+
226+
_LOGGER.info("Stopping event loop...")
227+
asyncio.get_event_loop().stop()
228+
_LOGGER.info("Event loop stopped")

packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from google.protobuf import empty_pb2 as _empty_pb2
55

6-
from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING
6+
from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING, _LOGGER
77
from pynumaflow.proto.reducer import reduce_pb2, reduce_pb2_grpc
88
from pynumaflow.reducestreamer._dtypes import (
99
Datum,
@@ -12,7 +12,7 @@
1212
ReduceRequest,
1313
)
1414
from pynumaflow.reducestreamer.servicer.task_manager import TaskManager
15-
from pynumaflow.shared.server import handle_async_error
15+
from pynumaflow.shared.server import update_context_err
1616
from pynumaflow.types import NumaflowServicerContext
1717

1818

@@ -47,6 +47,12 @@ def __init__(
4747
):
4848
# The Reduce handler can be a function or a builder class instance.
4949
self.__reduce_handler: ReduceStreamAsyncCallable | _ReduceStreamBuilderClass = handler
50+
self._shutdown_event: asyncio.Event | None = None
51+
self._error: BaseException | None = None
52+
53+
def set_shutdown_event(self, event: asyncio.Event):
54+
"""Wire up the shutdown event created by the server's aexec() coroutine."""
55+
self._shutdown_event = event
5056

5157
async def ReduceFn(
5258
self,
@@ -94,20 +100,50 @@ async def ReduceFn(
94100
async for msg in consumer:
95101
# If the message is an exception, report it on the context, signal shutdown, and stop streaming
96102
if isinstance(msg, BaseException):
97-
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
103+
err_msg = f"ReduceStreamError, {ERR_UDF_EXCEPTION_STRING}: {repr(msg)}"
104+
_LOGGER.critical(err_msg, exc_info=True)
105+
update_context_err(context, msg, err_msg)
106+
self._error = msg
107+
if self._shutdown_event is not None:
108+
self._shutdown_event.set()
98109
return
99110
# Send window EOF response or Window result response
100111
# back to the client
101112
else:
102113
yield msg
114+
except GeneratorExit:
115+
# ReduceFn is an async generator (it yields messages). When Numaflow closes a
116+
# window, gRPC calls .aclose() on this generator, throwing GeneratorExit at
117+
# the yield point. This is normal stream lifecycle — return cleanly.
118+
return
119+
except asyncio.CancelledError:
120+
# SIGTERM: aiorun cancelled all tasks. Signal the server to stop so
121+
# Server.__del__ doesn't try to schedule on a closed event loop.
122+
if self._shutdown_event is not None:
123+
self._shutdown_event.set()
124+
return
103125
except BaseException as e:
104-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
126+
err_msg = f"ReduceStreamError, {ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
127+
_LOGGER.critical(err_msg, exc_info=True)
128+
update_context_err(context, e, err_msg)
129+
self._error = e
130+
if self._shutdown_event is not None:
131+
self._shutdown_event.set()
105132
return
106133
# Wait for the process_input_stream task to finish for a clean exit
107134
try:
108135
await producer
136+
except asyncio.CancelledError:
137+
if self._shutdown_event is not None:
138+
self._shutdown_event.set()
139+
return
109140
except BaseException as e:
110-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
141+
err_msg = f"ReduceStreamError, {ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
142+
_LOGGER.critical(err_msg, exc_info=True)
143+
update_context_err(context, e, err_msg)
144+
self._error = e
145+
if self._shutdown_event is not None:
146+
self._shutdown_event.set()
111147
return
112148

113149
async def IsReady(

packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ async def __invoke_reduce(
195195
new_instance = self.__reduce_handler.create()
196196
try:
197197
_ = await new_instance(keys, request_iterator, output, md)
198+
except asyncio.CancelledError:
199+
_LOGGER.info("ReduceStream __invoke_reduce cancelled, returning cleanly")
200+
return
198201
# If there is an error in the reduce operation, log and
199202
# then send the error to the result queue
200203
except BaseException as err:
@@ -217,6 +220,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[reduce_pb2.
217220
# append the task data to the existing task
218221
# if the task does not exist, create a new task
219222
await self.send_datum_to_task(request)
223+
except asyncio.CancelledError:
224+
_LOGGER.info("ReduceStream process_input_stream cancelled, returning cleanly")
225+
return
220226
# If there is an error in the reduce operation, log and
221227
# then send the error to the result queue
222228
except BaseException as e:
@@ -261,6 +267,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[reduce_pb2.
261267

262268
# Once all tasks are completed, send EOF to the global result queue
263269
await self.global_result_queue.put(STREAM_EOF)
270+
except asyncio.CancelledError:
271+
_LOGGER.info("ReduceStream post-processing cancelled, returning cleanly")
272+
return
264273
except BaseException as e:
265274
err_msg = f"Reduce Streaming Error: {repr(e)}"
266275
_LOGGER.critical(err_msg, exc_info=True)

packages/pynumaflow/tests/accumulator/test_async_accumulator.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
def request_generator(count, request, resetkey: bool = False, send_close: bool = False):
3131
for i in range(count):
3232
if resetkey:
33-
# Clear previous keys and add new ones
33+
# Update keys on both payload and keyedWindow to match real platform behavior
3434
del request.payload.keys[:]
3535
request.payload.keys.extend([f"key-{i}"])
36+
del request.operation.keyedWindow.keys[:]
37+
request.operation.keyedWindow.keys.extend([f"key-{i}"])
3638

3739
# Set operation based on index - first is OPEN, rest are APPEND
3840
if i == 0:
@@ -52,9 +54,11 @@ def request_generator(count, request, resetkey: bool = False, send_close: bool =
5254
def request_generator_append_only(count, request, resetkey: bool = False):
5355
for i in range(count):
5456
if resetkey:
55-
# Clear previous keys and add new ones
57+
# Update keys on both payload and keyedWindow to match real platform behavior
5658
del request.payload.keys[:]
5759
request.payload.keys.extend([f"key-{i}"])
60+
del request.operation.keyedWindow.keys[:]
61+
request.operation.keyedWindow.keys.extend([f"key-{i}"])
5862

5963
# Set operation to APPEND for all requests
6064
request.operation.event = accumulator_pb2.AccumulatorRequest.WindowOperation.Event.APPEND
@@ -64,9 +68,11 @@ def request_generator_append_only(count, request, resetkey: bool = False):
6468
def request_generator_mixed(count, request, resetkey: bool = False):
6569
for i in range(count):
6670
if resetkey:
67-
# Clear previous keys and add new ones
71+
# Update keys on both payload and keyedWindow to match real platform behavior
6872
del request.payload.keys[:]
6973
request.payload.keys.extend([f"key-{i}"])
74+
del request.operation.keyedWindow.keys[:]
75+
request.operation.keyedWindow.keys.extend([f"key-{i}"])
7076

7177
if i % 2 == 0:
7278
# Set operation to APPEND for even requests
@@ -107,17 +113,26 @@ def start_request() -> accumulator_pb2.AccumulatorRequest:
107113

108114
def start_request_without_open() -> accumulator_pb2.AccumulatorRequest:
109115
event_time_timestamp, watermark_timestamp = get_time_args()
110-
116+
window = accumulator_pb2.KeyedWindow(
117+
start=mock_interval_window_start(),
118+
end=mock_interval_window_end(),
119+
slot="slot-0",
120+
keys=["test_key"],
121+
)
111122
payload = accumulator_pb2.Payload(
112123
keys=["test_key"],
113124
value=mock_message(),
114125
event_time=event_time_timestamp,
115126
watermark=watermark_timestamp,
116127
id="test_id",
117128
)
118-
129+
operation = accumulator_pb2.AccumulatorRequest.WindowOperation(
130+
event=accumulator_pb2.AccumulatorRequest.WindowOperation.Event.APPEND,
131+
keyedWindow=window,
132+
)
119133
request = accumulator_pb2.AccumulatorRequest(
120134
payload=payload,
135+
operation=operation,
121136
)
122137
return request
123138

0 commit comments

Comments
 (0)