Skip to content

Commit 6cd3467

Browse files
devin-ai-integration[bot] and bot_apk committed
refactor: narrow exception catch to AirbyteTracedException, add flush-before-raise test
Co-Authored-By: bot_apk <apk@cognition.ai>
1 parent 94b664f commit 6cd3467

2 files changed

Lines changed: 76 additions & 3 deletions

File tree

airbyte_cdk/entrypoint.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -283,10 +283,10 @@ def read(
283283
yield self.handle_record_counts(message, stream_message_counter)
284284
try:
285285
self._memory_monitor.check_memory_usage()
286-
except Exception:
286+
except AirbyteTracedException:
287287
# Flush queued messages (state checkpoints, logs) before propagating
288-
# a memory fail-fast (or other) exception, so the platform receives
289-
# the last committed state for the next sync.
288+
# the memory fail-fast exception, so the platform receives the last
289+
# committed state for the next sync.
290290
for queued_message in self._emit_queued_messages(self.source):
291291
yield self.handle_record_counts(queued_message, stream_message_counter)
292292
raise

unit_tests/test_entrypoint.py

Lines changed: 73 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -856,3 +856,76 @@ def test_given_serialization_error_using_orjson_then_fallback_on_json(
856856
# There will be multiple messages here because the fixture `entrypoint` sets a control message. We only care about records here
857857
record_messages = list(filter(lambda message: "RECORD" in message, messages))
858858
assert len(record_messages) == 2
859+
860+
861+
def test_memory_failfast_flushes_queued_state_before_raising(mocker):
862+
"""Queued state messages are emitted before AirbyteTracedException propagates from memory monitor."""
863+
# Build a state message that will sit in the message repository queue
864+
queued_state = AirbyteMessage(
865+
type=Type.STATE,
866+
state=AirbyteStateMessage(
867+
type=AirbyteStateType.STREAM,
868+
stream=AirbyteStreamState(
869+
stream_descriptor=StreamDescriptor(name="users", namespace=None),
870+
stream_state=AirbyteStateBlob({"cursor": "abc123"}),
871+
),
872+
),
873+
)
874+
875+
# Set up the message repository mock so consume_queue returns the state on its second call (the fail-fast flush)
876+
message_repository = MagicMock()
877+
message_repository.consume_queue.side_effect = [
878+
[], # initial flush in run() before read()
879+
[queued_state], # flush during fail-fast exception handling
880+
[], # final flush in run() finally block
881+
]
882+
mocker.patch.object(
883+
MockSource,
884+
"message_repository",
885+
new_callable=mocker.PropertyMock,
886+
return_value=message_repository,
887+
)
888+
889+
# Source emits one record before memory monitor raises
890+
record = AirbyteMessage(
891+
record=AirbyteRecordMessage(stream="users", data={"id": 1}, emitted_at=1),
892+
type=Type.RECORD,
893+
)
894+
mocker.patch.object(MockSource, "read_state", return_value={})
895+
mocker.patch.object(MockSource, "read_catalog", return_value={})
896+
mocker.patch.object(MockSource, "read", return_value=[record])
897+
898+
fail_fast_exc = AirbyteTracedException(
899+
message="Memory usage exceeded critical threshold (98%)",
900+
failure_type=FailureType.system_error,
901+
)
902+
903+
config = {"username": "fake"}
904+
mocker.patch.object(MockSource, "read_config", return_value=config)
905+
mocker.patch.object(MockSource, "configure", return_value=config)
906+
907+
entrypoint_obj = AirbyteEntrypoint(MockSource())
908+
mocker.patch.object(
909+
entrypoint_obj._memory_monitor, "check_memory_usage", side_effect=fail_fast_exc
910+
)
911+
912+
mocker.patch.object(
913+
MockSource, "spec", return_value=ConnectorSpecification(connectionSpecification={})
914+
)
915+
916+
parsed_args = Namespace(
917+
command="read", config="config_path", state="statepath", catalog="catalogpath"
918+
)
919+
920+
# Collect all yielded messages before the exception
921+
emitted: list[str] = []
922+
with pytest.raises(AirbyteTracedException) as exc_info:
923+
for msg in entrypoint_obj.run(parsed_args):
924+
emitted.append(msg)
925+
926+
assert exc_info.value is fail_fast_exc
927+
928+
# The record should be yielded first, then the queued state (flushed during exception handling)
929+
state_messages = [m for m in emitted if "STATE" in m]
930+
assert len(state_messages) == 1, "Queued state should be flushed before exception propagates"
931+
assert "abc123" in state_messages[0]

0 commit comments

Comments
 (0)