Skip to content

Commit 9949d58

Browse files
committed
feat: parallelized mapstream
Signed-off-by: kohlisid <sidhant.kohli@gmail.com>
1 parent 17e063b commit 9949d58

1 file changed

Lines changed: 94 additions & 41 deletions

File tree

Lines changed: 94 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,27 @@
1+
import asyncio
12
from collections.abc import AsyncIterable
23

34
from google.protobuf import empty_pb2 as _empty_pb2
45

6+
from pynumaflow.shared.asynciter import NonBlockingIterator
7+
from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING
58
from pynumaflow.mapstreamer import Datum
69
from pynumaflow.mapstreamer._dtypes import MapStreamCallable, MapStreamError
710
from pynumaflow.proto.mapper import map_pb2_grpc, map_pb2
811
from pynumaflow.shared.server import handle_async_error
912
from pynumaflow.types import NumaflowServicerContext
10-
from pynumaflow._constants import _LOGGER, ERR_UDF_EXCEPTION_STRING
1113

1214

1315
class AsyncMapStreamServicer(map_pb2_grpc.MapServicer):
1416
"""
15-
This class is used to create a new grpc Map Stream Servicer instance.
16-
It implements the SyncMapServicer interface from the proto
17-
map_pb2_grpc.py file.
18-
Provides the functionality for the required rpc methods.
17+
Concurrent gRPC Map Stream Servicer.
18+
Spawns one background task per incoming MapRequest; each task streams
19+
results as produced and finally emits an EOT for that request.
1920
"""
2021

21-
def __init__(
22-
self,
23-
handler: MapStreamCallable,
24-
):
22+
def __init__(self, handler: MapStreamCallable):
2523
self.__map_stream_handler: MapStreamCallable = handler
24+
self._background_tasks: set[asyncio.Task] = set()
2625

2726
async def MapFn(
2827
self,
@@ -31,51 +30,105 @@ async def MapFn(
3130
) -> AsyncIterable[map_pb2.MapResponse]:
3231
"""
3332
Applies a map function to a datum stream in streaming mode.
34-
The pascal case function name comes from the proto map_pb2_grpc.py file.
33+
The PascalCase name comes from the generated map_pb2_grpc.py file.
3534
"""
3635
try:
37-
# The first message to be received should be a valid handshake
38-
req = await request_iterator.__anext__()
39-
# check if it is a valid handshake req
40-
if not (req.handshake and req.handshake.sot):
36+
# First message must be a handshake
37+
first = await request_iterator.__anext__()
38+
if not (first.handshake and first.handshake.sot):
4139
raise MapStreamError("MapStreamFn: expected handshake as the first message")
40+
# Acknowledge handshake
4241
yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))
4342

44-
# read for each input request
45-
async for req in request_iterator:
46-
# yield messages as received from the UDF
47-
async for res in self.__invoke_map_stream(
48-
list(req.request.keys),
49-
Datum(
50-
keys=list(req.request.keys),
51-
value=req.request.value,
52-
event_time=req.request.event_time.ToDatetime(),
53-
watermark=req.request.watermark.ToDatetime(),
54-
headers=dict(req.request.headers),
55-
),
56-
):
57-
yield map_pb2.MapResponse(results=[res], id=req.id)
58-
# send EOT to indicate end of transmission for a given message
59-
yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id)
60-
except BaseException as err:
43+
# Global non-blocking queue for outbound responses / errors
44+
global_result_queue = NonBlockingIterator()
45+
46+
# Start producer that turns each inbound request into a background task
47+
producer = asyncio.create_task(
48+
self._process_inputs(request_iterator, global_result_queue)
49+
)
50+
51+
# Consume results as they arrive and stream them to the client
52+
async for msg in global_result_queue.read_iterator():
53+
if isinstance(msg, BaseException):
54+
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
55+
return
56+
else:
57+
# msg is a map_pb2.MapResponse, already formed
58+
yield msg
59+
60+
# Ensure producer has finished (covers graceful shutdown)
61+
await producer
62+
63+
except BaseException as e:
6164
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
62-
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
65+
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
6366
return
6467

65-
async def __invoke_map_stream(self, keys: list[str], req: Datum):
68+
async def _process_inputs(
69+
self,
70+
request_iterator: AsyncIterable[map_pb2.MapRequest],
71+
result_queue: NonBlockingIterator,
72+
) -> None:
73+
"""
74+
Reads MapRequests from the client and spawns a background task per request.
75+
Each task streams results to result_queue as they are produced.
76+
"""
6677
try:
67-
# Invoke the user handler for map stream
68-
async for msg in self.__map_stream_handler(keys, req):
69-
yield map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
78+
async for req in request_iterator:
79+
task = asyncio.create_task(self._invoke_map_stream(req, result_queue))
80+
self._background_tasks.add(task)
81+
# Remove from the set when done to avoid memory growth
82+
task.add_done_callback(self._background_tasks.discard)
83+
84+
# Wait for all in-flight tasks to complete
85+
if self._background_tasks:
86+
await asyncio.gather(*list(self._background_tasks), return_exceptions=False)
87+
88+
# Signal end-of-stream to the consumer
89+
await result_queue.put(STREAM_EOF)
90+
91+
except BaseException as e:
92+
_LOGGER.critical("MapFn Error, re-raising the error", exc_info=True)
93+
# Surface the error to the consumer; MapFn will handle and close the RPC
94+
await result_queue.put(e)
95+
96+
async def _invoke_map_stream(
97+
self,
98+
req: map_pb2.MapRequest,
99+
result_queue: NonBlockingIterator,
100+
) -> None:
101+
"""
102+
Invokes the user-provided async generator for a single request and
103+
pushes each result onto the global queue, followed by an EOT for this id.
104+
"""
105+
try:
106+
datum = Datum(
107+
keys=list(req.request.keys),
108+
value=req.request.value,
109+
event_time=req.request.event_time.ToDatetime(),
110+
watermark=req.request.watermark.ToDatetime(),
111+
headers=dict(req.request.headers),
112+
)
113+
114+
# Stream results from the user handler as they are produced
115+
async for msg in self.__map_stream_handler(list(req.request.keys), datum):
116+
res = map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
117+
await result_queue.put(map_pb2.MapResponse(results=[res], id=req.id))
118+
119+
# Emit EOT for this request id
120+
await result_queue.put(
121+
map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id)
122+
)
123+
70124
except BaseException as err:
71125
_LOGGER.critical("MapFn handler error", exc_info=True)
72-
raise err
126+
# Surface handler error to the main producer;
127+
# it will call handle_async_error and end the RPC
128+
await result_queue.put(err)
73129

74130
async def IsReady(
75131
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
76132
) -> map_pb2.ReadyResponse:
77-
"""
78-
IsReady is the heartbeat endpoint for gRPC.
79-
The pascal case function name comes from the proto map_pb2_grpc.py file.
80-
"""
133+
"""Heartbeat endpoint for gRPC."""
81134
return map_pb2.ReadyResponse(ready=True)

0 commit comments

Comments
 (0)