Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 94 additions & 41 deletions pynumaflow/mapstreamer/servicer/async_servicer.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
import asyncio
from collections.abc import AsyncIterable

from google.protobuf import empty_pb2 as _empty_pb2

from pynumaflow.shared.asynciter import NonBlockingIterator
from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING
from pynumaflow.mapstreamer import Datum
from pynumaflow.mapstreamer._dtypes import MapStreamCallable, MapStreamError
from pynumaflow.proto.mapper import map_pb2_grpc, map_pb2
from pynumaflow.shared.server import handle_async_error
from pynumaflow.types import NumaflowServicerContext
from pynumaflow._constants import _LOGGER, ERR_UDF_EXCEPTION_STRING


class AsyncMapStreamServicer(map_pb2_grpc.MapServicer):
"""
This class is used to create a new grpc Map Stream Servicer instance.
It implements the SyncMapServicer interface from the proto
map_pb2_grpc.py file.
Provides the functionality for the required rpc methods.
Concurrent gRPC Map Stream Servicer.
Spawns one background task per incoming MapRequest; each task streams
results as produced and finally emits an EOT for that request.
"""

def __init__(
self,
handler: MapStreamCallable,
):
def __init__(self, handler: MapStreamCallable):
self.__map_stream_handler: MapStreamCallable = handler
self._background_tasks: set[asyncio.Task] = set()

async def MapFn(
self,
Expand All @@ -31,51 +30,105 @@ async def MapFn(
) -> AsyncIterable[map_pb2.MapResponse]:
"""
Applies a map function to a datum stream in streaming mode.
The pascal case function name comes from the proto map_pb2_grpc.py file.
The PascalCase name comes from the generated map_pb2_grpc.py file.
"""
try:
# The first message to be received should be a valid handshake
req = await request_iterator.__anext__()
# check if it is a valid handshake req
if not (req.handshake and req.handshake.sot):
# First message must be a handshake
first = await request_iterator.__anext__()
if not (first.handshake and first.handshake.sot):
raise MapStreamError("MapStreamFn: expected handshake as the first message")
# Acknowledge handshake
yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))

# read for each input request
async for req in request_iterator:
# yield messages as received from the UDF
async for res in self.__invoke_map_stream(
list(req.request.keys),
Datum(
keys=list(req.request.keys),
value=req.request.value,
event_time=req.request.event_time.ToDatetime(),
watermark=req.request.watermark.ToDatetime(),
headers=dict(req.request.headers),
),
):
yield map_pb2.MapResponse(results=[res], id=req.id)
# send EOT to indicate end of transmission for a given message
yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id)
except BaseException as err:
# Global non-blocking queue for outbound responses / errors
global_result_queue = NonBlockingIterator()

# Start producer that turns each inbound request into a background task
producer = asyncio.create_task(
self._process_inputs(request_iterator, global_result_queue)
)

# Consume results as they arrive and stream them to the client
async for msg in global_result_queue.read_iterator():
if isinstance(msg, BaseException):
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
return
else:
# msg is a map_pb2.MapResponse, already formed
yield msg

# Ensure producer has finished (covers graceful shutdown)
await producer

except BaseException as e:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
return

async def __invoke_map_stream(self, keys: list[str], req: Datum):
async def _process_inputs(
self,
request_iterator: AsyncIterable[map_pb2.MapRequest],
result_queue: NonBlockingIterator,
) -> None:
"""
Reads MapRequests from the client and spawns a background task per request.
Each task streams results to result_queue as they are produced.
"""
try:
# Invoke the user handler for map stream
async for msg in self.__map_stream_handler(keys, req):
yield map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
async for req in request_iterator:
task = asyncio.create_task(self._invoke_map_stream(req, result_queue))
self._background_tasks.add(task)
# Remove from the set when done to avoid memory growth
task.add_done_callback(self._background_tasks.discard)

# Wait for all in-flight tasks to complete
if self._background_tasks:
await asyncio.gather(*list(self._background_tasks), return_exceptions=False)

# Signal end-of-stream to the consumer
await result_queue.put(STREAM_EOF)

except BaseException as e:
_LOGGER.critical("MapFn Error, re-raising the error", exc_info=True)
# Surface the error to the consumer; MapFn will handle and close the RPC
await result_queue.put(e)

async def _invoke_map_stream(
self,
req: map_pb2.MapRequest,
result_queue: NonBlockingIterator,
) -> None:
"""
Invokes the user-provided async generator for a single request and
pushes each result onto the global queue, followed by an EOT for this id.
"""
try:
datum = Datum(
keys=list(req.request.keys),
value=req.request.value,
event_time=req.request.event_time.ToDatetime(),
watermark=req.request.watermark.ToDatetime(),
headers=dict(req.request.headers),
)

# Stream results from the user handler as they are produced
async for msg in self.__map_stream_handler(list(req.request.keys), datum):
res = map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
await result_queue.put(map_pb2.MapResponse(results=[res], id=req.id))

# Emit EOT for this request id
await result_queue.put(
map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id)
)

except BaseException as err:
_LOGGER.critical("MapFn handler error", exc_info=True)
raise err
# Surface handler error to the main producer;
# it will call handle_async_error and end the RPC
await result_queue.put(err)

async def IsReady(
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
) -> map_pb2.ReadyResponse:
"""
IsReady is the heartbeat endpoint for gRPC.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
"""Heartbeat endpoint for gRPC."""
return map_pb2.ReadyResponse(ready=True)
76 changes: 52 additions & 24 deletions tests/mapstream/test_async_map_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,39 +95,67 @@ def tearDownClass(cls) -> None:

def test_map_stream(self) -> None:
stub = self.__stub()
generator_response = None

# Send >1 requests
req_count = 3
try:
generator_response = stub.MapFn(request_iterator=request_generator(count=1, session=1))
generator_response = stub.MapFn(
request_iterator=request_generator(count=req_count, session=1)
)
except grpc.RpcError as e:
logging.error(e)
self.fail(f"RPC failed: {e}")

# First message must be the handshake
handshake = next(generator_response)
# assert that handshake response is received.
self.assertTrue(handshake.handshake.sot)
data_resp = []
for r in generator_response:
data_resp.append(r)

self.assertEqual(11, len(data_resp))

idx = 0
while idx < len(data_resp) - 1:
# Expected: 10 results per request + 1 EOT per request
expected_result_msgs = req_count * 10
expected_eots = req_count

# Prepare expected payload
expected_payload = bytes(
"payload:test_mock_message "
"event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00",
encoding="utf-8",
)

from collections import Counter

id_counter = Counter()
result_msg_count = 0
eot_count = 0

for msg in generator_response:
# Count EOTs wherever they show up
if hasattr(msg, "status") and msg.status.eot:
eot_count += 1
continue

# Otherwise, it's a data/result message; validate payload and tally by id
self.assertTrue(msg.results, "Expected results in MapResponse.")
self.assertEqual(expected_payload, msg.results[0].value)
id_counter[msg.id] += 1
result_msg_count += 1

# Validate totals
self.assertEqual(
expected_result_msgs,
result_msg_count,
f"Expected {expected_result_msgs} result messages, got {result_msg_count}",
)
self.assertEqual(
expected_eots, eot_count, f"Expected {expected_eots} EOT messages, got {eot_count}"
)

# Validate 10 messages per request id: test-id-0..test-id-(req_count-1)
for i in range(req_count):
self.assertEqual(
bytes(
"payload:test_mock_message "
"event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00",
encoding="utf-8",
),
data_resp[idx].results[0].value,
10,
id_counter[f"test-id-{i}"],
f"Expected 10 results for test-id-{i}, got {id_counter[f'test-id-{i}']}",
)
_id = data_resp[idx].id
self.assertEqual(_id, "test-id-0")
# capture the output from the SinkFn generator and assert.
idx += 1
# EOT Response
self.assertEqual(data_resp[len(data_resp) - 1].status.eot, True)
# 10 sink responses + 1 EOT response
self.assertEqual(11, len(data_resp))

def test_is_ready(self) -> None:
with grpc.insecure_channel("unix:///tmp/async_map_stream.sock") as channel:
Expand Down