@@ -346,6 +346,95 @@ async def requests():
346346 ctx .set_code .assert_called_once_with (grpc .StatusCode .INTERNAL )
347347
348348
349+ async def _blocking_handler (keys , datums , output , md ):
350+ """Handler that blocks forever reading datums (never finishes on its own)."""
351+ async for _ in datums :
352+ pass
353+ await output .put (Message (b"done" , keys = keys ))
354+
355+
356+ def test_cancel_and_await_remaining_tasks_on_post_processing_error ():
357+ """
358+ When a BaseException occurs during post-processing (after the input stream
359+ is exhausted), the TaskManager should cancel and await all remaining task
360+ futures to suppress 'never retrieved' warnings.
361+ """
362+ from unittest .mock import patch
363+
364+ tm = TaskManager (_blocking_handler )
365+
366+ request , _ = start_request (multiple_window = False )
367+ # Use OPEN so create_task is called
368+ request .operation .event = reduce_pb2 .ReduceRequest .WindowOperation .Event .OPEN
369+
370+ async def _run ():
371+ async def requests ():
372+ yield request
373+
374+ # Patch stream_send_eof to raise after the task is created but before
375+ # it completes, so the task futures are still running when the except
376+ # block executes.
377+ with patch .object (tm , "stream_send_eof" , side_effect = RuntimeError ("send_eof boom" )):
378+ await tm .process_input_stream (requests ())
379+
380+ # After process_input_stream returns, verify the error was placed in
381+ # the global result queue.
382+ reader = tm .global_result_queue .read_iterator ()
383+ first_item = await reader .__anext__ ()
384+ assert isinstance (first_item , RuntimeError )
385+ assert "send_eof boom" in str (first_item )
386+
387+ # Verify all task futures completed (cancelled or finished).
388+ for task in tm .get_tasks ():
389+ assert task .future .done (), "task.future should be done after cleanup"
390+ assert task .consumer_future .done (), "task.consumer_future should be done after cleanup"
391+
392+ asyncio .run (_run ())
393+
394+
395+ def test_cancel_and_await_with_already_done_futures ():
396+ """
397+ When post-processing fails but some futures are already done,
398+ the cleanup code should handle them gracefully (skip cancellation).
399+ """
400+ from unittest .mock import patch
401+
402+ async def _fast_handler (keys , datums , output , md ):
403+ """Handler that finishes immediately without reading datums."""
404+ await output .put (Message (b"fast" , keys = keys ))
405+
406+ tm = TaskManager (_fast_handler )
407+ request , _ = start_request (multiple_window = False )
408+ request .operation .event = reduce_pb2 .ReduceRequest .WindowOperation .Event .OPEN
409+
410+ async def _run ():
411+ async def requests ():
412+ yield request
413+
414+ # Let the real stream_send_eof run (which sends EOF to the handler),
415+ # then patch get_unique_windows to raise after all tasks complete.
416+ original_send_eof = tm .stream_send_eof
417+
418+ async def send_eof_then_wait_and_raise ():
419+ await original_send_eof ()
420+ # Wait for the task futures to finish
421+ for task in tm .get_tasks ():
422+ await task .future
423+ await task .result_queue .put ("__STREAM_EOF__" )
424+ await task .consumer_future
425+ raise RuntimeError ("late post-processing error" )
426+
427+ with patch .object (tm , "stream_send_eof" , side_effect = send_eof_then_wait_and_raise ):
428+ await tm .process_input_stream (requests ())
429+
430+ # Verify cleanup completed without issues
431+ for task in tm .get_tasks ():
432+ assert task .future .done ()
433+ assert task .consumer_future .done ()
434+
435+ asyncio .run (_run ())
436+
437+
349438if __name__ == "__main__" :
350439 logging .basicConfig (level = logging .DEBUG )
351440 unittest .main ()
0 commit comments