|
1 | 1 | import asyncio |
| 2 | +import contextlib |
2 | 3 | from collections.abc import AsyncIterable |
3 | 4 |
|
4 | 5 | from google.protobuf import empty_pb2 as _empty_pb2 |
@@ -35,6 +36,7 @@ async def MapFn( |
35 | 36 | """ |
36 | 37 | # The proto repeated field (keys) is of type google._upb._message.RepeatedScalarContainer, |
37 | 38 | # so we need to explicitly convert it to a list. |
| 39 | + producer = None |
38 | 40 | try: |
39 | 41 | # The first message to be received should be a valid handshake |
40 | 42 | req = await request_iterator.__anext__() |
@@ -62,37 +64,57 @@ async def MapFn( |
62 | 64 | yield msg |
63 | 65 | # wait for the producer task to complete |
64 | 66 | await producer |
| 67 | + except GeneratorExit: |
| 68 | + _LOGGER.info("Client disconnected, generator closed.") |
| 69 | + raise |
65 | 70 | except BaseException as e: |
66 | 71 | _LOGGER.critical("UDFError, re-raising the error", exc_info=True) |
67 | 72 | await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, self.multiproc) |
68 | 73 | return |
| 74 | + finally: |
| 75 | + if producer and not producer.done(): |
| 76 | + producer.cancel() |
| 77 | + with contextlib.suppress(asyncio.CancelledError): |
| 78 | + await producer |
69 | 79 |
|
70 | | - async def _process_inputs( |
71 | | - self, |
72 | | - request_iterator: AsyncIterable[map_pb2.MapRequest], |
73 | | - result_queue: NonBlockingIterator, |
74 | | - ): |
75 | | - """ |
76 | | - Utility function for processing incoming MapRequests |
77 | | - """ |
| 80 | + async def _process_inputs(self, request_iterator, result_queue): |
78 | 81 | try: |
79 | | - # for each incoming request, create a background task to execute the |
80 | | - # UDF code |
81 | 82 | async for req in request_iterator: |
82 | | - msg_task = asyncio.create_task(self._invoke_map(req, result_queue)) |
83 | | - # save a reference to a set to store active tasks |
84 | | - self.background_tasks.add(msg_task) |
85 | | - msg_task.add_done_callback(self.background_tasks.discard) |
86 | | - |
87 | | - # wait for all tasks to complete |
88 | | - for task in self.background_tasks: |
89 | | - await task |
90 | | - |
91 | | - # send an EOF to result queue to indicate that all tasks have completed |
92 | | - await result_queue.put(STREAM_EOF) |
| 83 | + task = asyncio.create_task(self._invoke_map(req, result_queue)) |
| 84 | + self.background_tasks.add(task) |
| 85 | + task.add_done_callback(self.background_tasks.discard) |
93 | 86 |
|
| 87 | + await asyncio.gather(*self.background_tasks) |
94 | 88 | except BaseException: |
95 | | - _LOGGER.critical("MapFn Error, re-raising the error", exc_info=True) |
| 89 | + _LOGGER.critical("MapFn Error in _process_inputs", exc_info=True) |
| 90 | + finally: |
| 91 | + await result_queue.put(STREAM_EOF) |
| 92 | + # async def _process_inputs( |
| 93 | + # self, |
| 94 | + # request_iterator: AsyncIterable[map_pb2.MapRequest], |
| 95 | + # result_queue: NonBlockingIterator, |
| 96 | + # ): |
| 97 | + # """ |
| 98 | + # Utility function for processing incoming MapRequests |
| 99 | + # """ |
| 100 | + # try: |
| 101 | + # # for each incoming request, create a background task to execute the |
| 102 | + # # UDF code |
| 103 | + # async for req in request_iterator: |
| 104 | + # msg_task = asyncio.create_task(self._invoke_map(req, result_queue)) |
| 105 | + # # save a reference to a set to store active tasks |
| 106 | + # self.background_tasks.add(msg_task) |
| 107 | + # msg_task.add_done_callback(self.background_tasks.discard) |
| 108 | + # |
| 109 | + # # wait for all tasks to complete |
| 110 | + # for task in self.background_tasks: |
| 111 | + # await task |
| 112 | + # |
| 113 | + # # send an EOF to result queue to indicate that all tasks have completed |
| 114 | + # await result_queue.put(STREAM_EOF) |
| 115 | + # |
| 116 | + # except BaseException: |
| 117 | + # _LOGGER.critical("MapFn Error, re-raising the error", exc_info=True) |
96 | 118 |
|
97 | 119 | async def _invoke_map(self, req: map_pb2.MapRequest, result_queue: NonBlockingIterator): |
98 | 120 | """ |
|
0 commit comments