@@ -835,6 +835,99 @@ def test_handle_record_counts(
835835 )
836836
837837
838+ def test_memory_limit_exceeded_flushes_queued_messages (mocker , spec_mock , config_mock ):
839+ """When MemoryLimitExceeded is raised mid-read, queued messages should still be flushed.
840+
841+ The read() try/finally ensures _emit_queued_messages runs even when
842+ MemoryLimitExceeded propagates. The exception still surfaces to the
843+ caller, but all messages yielded before (records) and during (finally-
844+ block state messages) the exception are available to the consumer.
845+ """
846+ queued_state = AirbyteMessage (
847+ type = Type .STATE ,
848+ state = AirbyteStateMessage (
849+ type = AirbyteStateType .STREAM ,
850+ stream = AirbyteStreamState (
851+ stream_descriptor = StreamDescriptor (name = "stream" ),
852+ stream_state = AirbyteStateBlob (updated_at = "2026-01-01" ),
853+ ),
854+ ),
855+ )
856+ message_repository = MagicMock ()
857+ # consume_queue calls:
858+ # 1. run() preamble → initial queued control message
859+ # 2. read() finally block → queued state (the key assertion)
860+ # 3. run() outer finally → nothing
861+ message_repository .consume_queue .side_effect = [
862+ [MESSAGE_FROM_REPOSITORY ],
863+ [queued_state ],
864+ [],
865+ ]
866+ mocker .patch .object (
867+ MockSource ,
868+ "message_repository" ,
869+ new_callable = mocker .PropertyMock ,
870+ return_value = message_repository ,
871+ )
872+ entrypoint = AirbyteEntrypoint (MockSource ())
873+
874+ record = AirbyteMessage (
875+ type = Type .RECORD ,
876+ record = AirbyteRecordMessage (stream = "stream" , data = {"id" : "1" }, emitted_at = 1 ),
877+ )
878+ mocker .patch .object (MockSource , "read_state" , return_value = {})
879+ mocker .patch .object (MockSource , "read_catalog" , return_value = {})
880+ mocker .patch .object (MockSource , "read" , return_value = [record , record ])
881+
882+ from airbyte_cdk .utils .memory_monitor import MemoryLimitExceeded
883+
884+ call_count = 0
885+
886+ def _raise_on_second_call () -> None :
887+ nonlocal call_count
888+ call_count += 1
889+ if call_count >= 2 :
890+ raise MemoryLimitExceeded (
891+ internal_message = "Memory at 96%" ,
892+ message = "Source exceeded memory limit (96% used) and must shut down. "
893+ "Reduce the number of streams or increase memory allocation." ,
894+ failure_type = FailureType .transient_error ,
895+ )
896+
897+ mocker .patch .object (
898+ entrypoint ._memory_monitor , "check_memory_usage" , side_effect = _raise_on_second_call
899+ )
900+
901+ parsed_args = Namespace (
902+ command = "read" , config = "config_path" , state = "statepath" , catalog = "catalogpath"
903+ )
904+
905+ # The generator yields messages until MemoryLimitExceeded propagates.
906+ # Collect everything yielded before the exception surfaces.
907+ messages : list [str ] = []
908+ with pytest .raises (MemoryLimitExceeded ):
909+ for msg in entrypoint .run (parsed_args ):
910+ messages .append (msg )
911+
912+ # 1. The first record was yielded before the exception
913+ record_messages = [m for m in messages if "RECORD" in m ]
914+ assert len (record_messages ) >= 1 , (
915+ "At least the first record should be yielded before MemoryLimitExceeded"
916+ )
917+
918+ # 2. The queued state message was flushed by the finally block
919+ state_messages = [m for m in messages if "STATE" in m ]
920+ assert len (state_messages ) >= 1 , (
921+ "Queued state message should be flushed even after MemoryLimitExceeded"
922+ )
923+
924+ # 3. The flushed state has sourceStats.recordCount set by handle_record_counts.
925+ # Both records are yielded (and counted) before the second check_memory_usage
926+ # raises, so the counter is 2.0 at flush time.
927+ state_json = orjson .loads (state_messages [0 ])
928+ assert state_json ["state" ]["sourceStats" ]["recordCount" ] == 2.0
929+
930+
838931def test_given_serialization_error_using_orjson_then_fallback_on_json (
839932 entrypoint : AirbyteEntrypoint , mocker , spec_mock , config_mock
840933):
0 commit comments