|
| 1 | +""" |
| 2 | +Shutdown-event tests for the multiproc SourceTransform servicer. |
| 3 | +
|
| 4 | +These tests verify that the SourceTransformServicer (as used by |
| 5 | +SourceTransformMultiProcServer) correctly sets shutdown_event on error, |
| 6 | +enabling coordinated graceful shutdown across all worker processes via |
| 7 | +the shared multiprocessing.Event. |
| 8 | +""" |
| 9 | + |
| 10 | +from unittest import mock |
| 11 | + |
| 12 | +import grpc |
| 13 | +from grpc import StatusCode |
| 14 | +from grpc_testing import server_from_dictionary, strict_real_time |
| 15 | + |
| 16 | +from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer |
| 17 | +from pynumaflow.proto.sourcetransformer import transform_pb2 |
| 18 | +from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums |
| 19 | + |
| 20 | + |
| 21 | +def test_shutdown_event_set_on_handler_error(): |
| 22 | + """When the UDF handler raises, the servicer must signal the shutdown event.""" |
| 23 | + server = SourceTransformMultiProcServer(source_transform_instance=err_transform_handler) |
| 24 | + servicer = server.servicer |
| 25 | + |
| 26 | + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer} |
| 27 | + test_server = server_from_dictionary(services, strict_real_time()) |
| 28 | + |
| 29 | + test_datums = get_test_datums(handshake=True) |
| 30 | + |
| 31 | + method = test_server.invoke_stream_stream( |
| 32 | + method_descriptor=( |
| 33 | + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ |
| 34 | + "SourceTransformFn" |
| 35 | + ] |
| 36 | + ), |
| 37 | + invocation_metadata={}, |
| 38 | + timeout=2, |
| 39 | + ) |
| 40 | + |
| 41 | + for d in test_datums: |
| 42 | + method.send_request(d) |
| 43 | + method.requests_closed() |
| 44 | + |
| 45 | + while True: |
| 46 | + try: |
| 47 | + method.take_response() |
| 48 | + except ValueError: |
| 49 | + break |
| 50 | + |
| 51 | + _, code, _ = method.termination() |
| 52 | + assert code == StatusCode.INTERNAL |
| 53 | + assert servicer.shutdown_event.is_set() |
| 54 | + assert servicer.error is not None |
| 55 | + |
| 56 | + |
| 57 | +def test_shutdown_event_set_on_handshake_error(): |
| 58 | + """Missing handshake must also signal the shutdown event.""" |
| 59 | + server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) |
| 60 | + servicer = server.servicer |
| 61 | + |
| 62 | + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer} |
| 63 | + test_server = server_from_dictionary(services, strict_real_time()) |
| 64 | + |
| 65 | + test_datums = get_test_datums(handshake=False) |
| 66 | + |
| 67 | + method = test_server.invoke_stream_stream( |
| 68 | + method_descriptor=( |
| 69 | + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ |
| 70 | + "SourceTransformFn" |
| 71 | + ] |
| 72 | + ), |
| 73 | + invocation_metadata={}, |
| 74 | + timeout=1, |
| 75 | + ) |
| 76 | + |
| 77 | + for d in test_datums: |
| 78 | + method.send_request(d) |
| 79 | + method.requests_closed() |
| 80 | + |
| 81 | + while True: |
| 82 | + try: |
| 83 | + method.take_response() |
| 84 | + except ValueError: |
| 85 | + break |
| 86 | + |
| 87 | + _, code, details = method.termination() |
| 88 | + assert code == StatusCode.INTERNAL |
| 89 | + assert "SourceTransformFn: expected handshake message" in details |
| 90 | + assert servicer.shutdown_event.is_set() |
| 91 | + assert servicer.error is not None |
| 92 | + |
| 93 | + |
| 94 | +def test_shutdown_event_set_on_stream_close_before_handshake(): |
| 95 | + """grpc.RpcError on the first read (before handshake): shutdown_event set, |
| 96 | + result_queue is None so close is skipped.""" |
| 97 | + server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) |
| 98 | + servicer = server.servicer |
| 99 | + |
| 100 | + def _cancelled_iter(): |
| 101 | + raise grpc.RpcError() |
| 102 | + yield # make it a generator |
| 103 | + |
| 104 | + responses = list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock())) |
| 105 | + |
| 106 | + assert responses == [] |
| 107 | + assert servicer.shutdown_event.is_set() |
| 108 | + assert servicer.error is None |
| 109 | + |
| 110 | + |
| 111 | +def test_shutdown_event_set_on_stream_close_mid_processing(): |
| 112 | + """grpc.RpcError mid-processing: result_queue is closed (unblocking the handler |
| 113 | + thread) and shutdown_event is set.""" |
| 114 | + server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) |
| 115 | + servicer = server.servicer |
| 116 | + |
| 117 | + test_datums = get_test_datums(handshake=True) |
| 118 | + |
| 119 | + def _cancelled_iter(): |
| 120 | + yield test_datums[0] # handshake |
| 121 | + yield test_datums[1] # first data message |
| 122 | + raise grpc.RpcError() |
| 123 | + |
| 124 | + responses = list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock())) |
| 125 | + |
| 126 | + assert responses[0].handshake.sot |
| 127 | + assert servicer.shutdown_event.is_set() |
| 128 | + assert servicer.error is None |
0 commit comments