Skip to content

Commit 7f7ab40

Browse files
devin-ai-integration[bot]bot_apk
andcommitted
test(cdk): add comprehensive memory monitor tests and graceful shutdown integration test
- Complete test_memory_monitor.py with 24 tests across 4 test classes - TestMemoryMonitorInit: cgroup v1/v2 detection, lazy-init verification - TestMemoryMonitorCheckMemory: thresholds, error degradation, intervals, limit_bytes==0 - TestMemoryLimitExceeded: exception type and attribute validation - TestDefaultCheckInterval: constant value verification - Add test_memory_limit_exceeded_flushes_queued_messages to test_entrypoint.py verifying that try/finally in read() flushes queued state messages even when MemoryLimitExceeded propagates Co-Authored-By: bot_apk <apk@cognition.ai>
1 parent ec56943 commit 7f7ab40

3 files changed

Lines changed: 277 additions & 144 deletions

File tree

airbyte_cdk/utils/memory_monitor.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ class MemoryLimitExceeded(AirbyteTracedException):
3838
class MemoryMonitor:
3939
"""Monitors container memory usage via cgroup files and emits warnings before OOM kills.
4040
41-
On init, probes cgroup v2 then v1 files. Caches which version exists.
41+
Lazily probes cgroup v2 then v1 files on the first call to
42+
``check_memory_usage()``. Caches which version exists.
4243
If neither is found (local dev / CI), all subsequent calls are instant no-ops.
4344
"""
4445

@@ -55,8 +56,18 @@ def __init__(
5556
self._warning_emitted = False
5657
self._critical_raised = False
5758
self._cgroup_version: Optional[int] = None
59+
self._probed = False
60+
61+
def _probe_cgroup(self) -> None:
62+
"""Detect which cgroup version (if any) is available.
63+
64+
Called lazily on the first ``check_memory_usage()`` invocation so
65+
that ``spec`` and ``discover`` commands never incur filesystem I/O.
66+
"""
67+
if self._probed:
68+
return
69+
self._probed = True
5870

59-
# Probe cgroup version on init
6071
if _CGROUP_V2_CURRENT.exists() and _CGROUP_V2_MAX.exists():
6172
self._cgroup_version = 2
6273
elif _CGROUP_V1_USAGE.exists() and _CGROUP_V1_LIMIT.exists():
@@ -114,6 +125,7 @@ def check_memory_usage(self) -> None:
114125
Each threshold triggers at most once per sync to avoid log spam.
115126
This method is a no-op if cgroup files are unavailable.
116127
"""
128+
self._probe_cgroup()
117129
if self._cgroup_version is None:
118130
return
119131

unit_tests/test_entrypoint.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,99 @@ def test_handle_record_counts(
835835
)
836836

837837

838+
def test_memory_limit_exceeded_flushes_queued_messages(mocker, spec_mock, config_mock):
839+
"""When MemoryLimitExceeded is raised mid-read, queued messages should still be flushed.
840+
841+
The read() try/finally ensures _emit_queued_messages runs even when
842+
MemoryLimitExceeded propagates. The exception still surfaces to the
843+
caller, but all messages yielded before (records) and during (finally-
844+
block state messages) the exception are available to the consumer.
845+
"""
846+
queued_state = AirbyteMessage(
847+
type=Type.STATE,
848+
state=AirbyteStateMessage(
849+
type=AirbyteStateType.STREAM,
850+
stream=AirbyteStreamState(
851+
stream_descriptor=StreamDescriptor(name="stream"),
852+
stream_state=AirbyteStateBlob(updated_at="2026-01-01"),
853+
),
854+
),
855+
)
856+
message_repository = MagicMock()
857+
# consume_queue calls:
858+
# 1. run() preamble → initial queued control message
859+
# 2. read() finally block → queued state (the key assertion)
860+
# 3. run() outer finally → nothing
861+
message_repository.consume_queue.side_effect = [
862+
[MESSAGE_FROM_REPOSITORY],
863+
[queued_state],
864+
[],
865+
]
866+
mocker.patch.object(
867+
MockSource,
868+
"message_repository",
869+
new_callable=mocker.PropertyMock,
870+
return_value=message_repository,
871+
)
872+
entrypoint = AirbyteEntrypoint(MockSource())
873+
874+
record = AirbyteMessage(
875+
type=Type.RECORD,
876+
record=AirbyteRecordMessage(stream="stream", data={"id": "1"}, emitted_at=1),
877+
)
878+
mocker.patch.object(MockSource, "read_state", return_value={})
879+
mocker.patch.object(MockSource, "read_catalog", return_value={})
880+
mocker.patch.object(MockSource, "read", return_value=[record, record])
881+
882+
from airbyte_cdk.utils.memory_monitor import MemoryLimitExceeded
883+
884+
call_count = 0
885+
886+
def _raise_on_second_call() -> None:
887+
nonlocal call_count
888+
call_count += 1
889+
if call_count >= 2:
890+
raise MemoryLimitExceeded(
891+
internal_message="Memory at 96%",
892+
message="Source exceeded memory limit (96% used) and must shut down. "
893+
"Reduce the number of streams or increase memory allocation.",
894+
failure_type=FailureType.transient_error,
895+
)
896+
897+
mocker.patch.object(
898+
entrypoint._memory_monitor, "check_memory_usage", side_effect=_raise_on_second_call
899+
)
900+
901+
parsed_args = Namespace(
902+
command="read", config="config_path", state="statepath", catalog="catalogpath"
903+
)
904+
905+
# The generator yields messages until MemoryLimitExceeded propagates.
906+
# Collect everything yielded before the exception surfaces.
907+
messages: list[str] = []
908+
with pytest.raises(MemoryLimitExceeded):
909+
for msg in entrypoint.run(parsed_args):
910+
messages.append(msg)
911+
912+
# 1. The first record was yielded before the exception
913+
record_messages = [m for m in messages if "RECORD" in m]
914+
assert len(record_messages) >= 1, (
915+
"At least the first record should be yielded before MemoryLimitExceeded"
916+
)
917+
918+
# 2. The queued state message was flushed by the finally block
919+
state_messages = [m for m in messages if "STATE" in m]
920+
assert len(state_messages) >= 1, (
921+
"Queued state message should be flushed even after MemoryLimitExceeded"
922+
)
923+
924+
# 3. The flushed state has sourceStats.recordCount set by handle_record_counts.
925+
# Both records are yielded (and counted) before the second check_memory_usage
926+
# raises, so the counter is 2.0 at flush time.
927+
state_json = orjson.loads(state_messages[0])
928+
assert state_json["state"]["sourceStats"]["recordCount"] == 2.0
929+
930+
838931
def test_given_serialization_error_using_orjson_then_fallback_on_json(
839932
entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock
840933
):

0 commit comments

Comments
 (0)