Skip to content

Commit e649a70

Browse files
devin-ai-integration[bot]bot_apk
andcommitted
test: enhance flush-before-raise test to verify sourceStats.recordCount and step-by-step generator iteration
Co-Authored-By: bot_apk <apk@cognition.ai>
1 parent 6cd3467 commit e649a70

1 file changed

Lines changed: 22 additions & 26 deletions

File tree

unit_tests/test_entrypoint.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -859,8 +859,7 @@ def test_given_serialization_error_using_orjson_then_fallback_on_json(
859859

860860

861861
def test_memory_failfast_flushes_queued_state_before_raising(mocker):
862-
"""Queued state messages are emitted before AirbyteTracedException propagates from memory monitor."""
863-
# Build a state message that will sit in the message repository queue
862+
"""Record emitted → check_memory_usage raises → queued STATE flushed with recordCount → exception propagates."""
864863
queued_state = AirbyteMessage(
865864
type=Type.STATE,
866865
state=AirbyteStateMessage(
@@ -872,12 +871,10 @@ def test_memory_failfast_flushes_queued_state_before_raising(mocker):
872871
),
873872
)
874873

875-
# Set up the message repository mock so consume_queue returns the state on first call
876874
message_repository = MagicMock()
877875
message_repository.consume_queue.side_effect = [
878-
[], # initial flush in run() before read()
879876
[queued_state], # flush during fail-fast exception handling
880-
[], # final flush in run() finally block
877+
[], # normal end-of-loop flush (not reached)
881878
]
882879
mocker.patch.object(
883880
MockSource,
@@ -886,7 +883,6 @@ def test_memory_failfast_flushes_queued_state_before_raising(mocker):
886883
return_value=message_repository,
887884
)
888885

889-
# Source emits one record before memory monitor raises
890886
record = AirbyteMessage(
891887
record=AirbyteRecordMessage(stream="users", data={"id": 1}, emitted_at=1),
892888
type=Type.RECORD,
@@ -900,32 +896,32 @@ def test_memory_failfast_flushes_queued_state_before_raising(mocker):
900896
failure_type=FailureType.system_error,
901897
)
902898

903-
config = {"username": "fake"}
904-
mocker.patch.object(MockSource, "read_config", return_value=config)
905-
mocker.patch.object(MockSource, "configure", return_value=config)
906-
907899
entrypoint_obj = AirbyteEntrypoint(MockSource())
908900
mocker.patch.object(
909901
entrypoint_obj._memory_monitor, "check_memory_usage", side_effect=fail_fast_exc
910902
)
911903

912-
mocker.patch.object(
913-
MockSource, "spec", return_value=ConnectorSpecification(connectionSpecification={})
914-
)
904+
spec = ConnectorSpecification(connectionSpecification={})
905+
config: dict[str, str] = {}
915906

916-
parsed_args = Namespace(
917-
command="read", config="config_path", state="statepath", catalog="catalogpath"
918-
)
907+
# Call read() directly to get AirbyteMessage objects (not serialised strings)
908+
gen = entrypoint_obj.read(spec, config, {}, [])
919909

920-
# Collect all yielded messages before the exception
921-
emitted: list[str] = []
922-
with pytest.raises(AirbyteTracedException) as exc_info:
923-
for msg in entrypoint_obj.run(parsed_args):
924-
emitted.append(msg)
910+
# 1. First yielded message is the RECORD
911+
first = next(gen)
912+
assert first.type == Type.RECORD
913+
assert first.record.stream == "users" # type: ignore[union-attr]
925914

926-
assert exc_info.value is fail_fast_exc
915+
# 2. Second yielded message is the queued STATE (flushed before exception)
916+
second = next(gen)
917+
assert second.type == Type.STATE
918+
assert second.state.stream.stream_state == AirbyteStateBlob({"cursor": "abc123"}) # type: ignore[union-attr]
927919

928-
# The record should be yielded first, then the queued state (flushed during exception handling)
929-
state_messages = [m for m in emitted if "STATE" in m]
930-
assert len(state_messages) == 1, "Queued state should be flushed before exception propagates"
931-
assert "abc123" in state_messages[0]
920+
# 3. The STATE passed through handle_record_counts, so sourceStats.recordCount == 1.0
921+
assert second.state.sourceStats is not None # type: ignore[union-attr]
922+
assert second.state.sourceStats.recordCount == 1.0 # type: ignore[union-attr]
923+
924+
# 4. Next iteration re-raises the AirbyteTracedException
925+
with pytest.raises(AirbyteTracedException) as exc_info:
926+
next(gen)
927+
assert exc_info.value is fail_fast_exc

0 commit comments

Comments
 (0)