Skip to content

Commit 0b94cbe

Browse files
devin-ai-integration[bot], bot_apkpnilan
authored
feat(cdk): enable fail-fast shutdown on memory threshold with dual-condition check (#962)
Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: bot_apk <apk@cognition.ai>
Co-authored-by: Patrick Nilan <nilan.patrick@gmail.com>
1 parent 5b6c307 commit 0b94cbe

File tree

4 files changed

+519
-73
lines changed

4 files changed

+519
-73
lines changed

airbyte_cdk/entrypoint.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,17 @@ def read(
281281
stream_message_counter: DefaultDict[HashableStreamDescriptor, float] = defaultdict(float)
282282
for message in self.source.read(self.logger, config, catalog, state):
283283
yield self.handle_record_counts(message, stream_message_counter)
284-
self._memory_monitor.check_memory_usage()
284+
try:
285+
self._memory_monitor.check_memory_usage()
286+
except AirbyteTracedException:
287+
# Flush queued messages (state checkpoints, logs) before propagating
288+
# the memory fail-fast exception, so the platform receives the last
289+
# committed state for the next sync.
290+
for queued_message in self._emit_queued_messages(self.source):
291+
yield self.handle_record_counts(queued_message, stream_message_counter)
292+
raise
293+
294+
# Flush queued messages after normal completion of the read loop.
285295
for message in self._emit_queued_messages(self.source):
286296
yield self.handle_record_counts(message, stream_message_counter)
287297

airbyte_cdk/utils/memory_monitor.py

Lines changed: 158 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,122 @@
22
# Copyright (c) 2026 Airbyte, Inc., all rights reserved.
33
#
44

5-
"""Source-side memory introspection to log memory usage approaching container limits."""
5+
"""Source-side memory introspection with fail-fast shutdown on memory threshold."""
66

77
import logging
88
from pathlib import Path
99
from typing import Optional
1010

11+
from airbyte_cdk.models import FailureType
12+
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
13+
1114
logger = logging.getLogger("airbyte")
1215

1316
# cgroup v2 paths
1417
_CGROUP_V2_CURRENT = Path("/sys/fs/cgroup/memory.current")
1518
_CGROUP_V2_MAX = Path("/sys/fs/cgroup/memory.max")
19+
_CGROUP_V2_STAT = Path("/sys/fs/cgroup/memory.stat")
1620

1721
# cgroup v1 paths — TODO: remove if all deployments are confirmed cgroup v2
1822
_CGROUP_V1_USAGE = Path("/sys/fs/cgroup/memory/memory.usage_in_bytes")
1923
_CGROUP_V1_LIMIT = Path("/sys/fs/cgroup/memory/memory.limit_in_bytes")
2024

21-
# Log when usage is at or above 90%
22-
_MEMORY_THRESHOLD = 0.90
25+
# Process-level anonymous RSS from /proc/self/status (Linux only, no extra dependency)
26+
_PROC_SELF_STATUS = Path("/proc/self/status")
27+
28+
# Raise AirbyteTracedException when BOTH conditions are met:
29+
# 1. cgroup usage >= critical threshold
30+
# 2. anonymous memory >= anon-share threshold of *current cgroup usage*
31+
# Comparing anon to usage (not limit) answers the more relevant question:
32+
# "is most of the near-OOM memory actually process-owned anonymous memory?"
33+
_CRITICAL_THRESHOLD = 0.98
34+
_ANON_SHARE_OF_USAGE_THRESHOLD = 0.85
2335

24-
# Check interval (every N messages)
36+
# Check interval (every N messages) — tightens after crossing high-pressure threshold
2537
_DEFAULT_CHECK_INTERVAL = 5000
38+
_HIGH_PRESSURE_CHECK_INTERVAL = 100
39+
_HIGH_PRESSURE_THRESHOLD = 0.95
40+
41+
42+
def _read_cgroup_v2_anon_bytes() -> Optional[int]:
    """Return the cgroup-level anonymous byte count from ``memory.stat``.

    The ``anon`` field in ``/sys/fs/cgroup/memory.stat`` accounts for all
    anonymous pages charged to the cgroup, which is a more accurate view of
    process-private memory pressure than per-process ``RssAnon`` in
    multi-process containers.

    Returns:
        Anonymous memory in bytes, or ``None`` if the file is missing,
        unreadable, or the ``anon`` value cannot be parsed as an integer.
    """
    # Read the whole stat file up front; an unreadable file means the
    # signal is simply unavailable, not an error worth surfacing.
    try:
        stat_text = _CGROUP_V2_STAT.read_text()
    except OSError:
        return None
    try:
        for stat_line in stat_text.splitlines():
            # Lines look like "anon 123456789"; take the first match.
            if stat_line.startswith("anon "):
                return int(stat_line.split()[1])
    except ValueError:
        # Malformed numeric field — treat the same as "unavailable".
        return None
    # No "anon" line present in the file.
    return None
58+
59+
60+
def _read_process_anon_rss_bytes() -> Optional[int]:
    """Return process-private anonymous resident memory from ``/proc/self/status``.

    Parses the ``RssAnon`` field, which covers private anonymous pages — the
    closest proxy for Python-heap memory pressure. Unlike ``VmRSS`` (which is
    ``RssAnon + RssFile + RssShmem``), ``RssAnon`` is not inflated by mmap'd
    file-backed or shared resident pages.

    Returns:
        Anonymous RSS in bytes, or ``None`` if unavailable (non-Linux host,
        permission error, or a kernel without the ``RssAnon`` field).
    """
    try:
        for status_line in _PROC_SELF_STATUS.read_text().splitlines():
            if not status_line.startswith("RssAnon:"):
                continue
            # Field format: "RssAnon:    12345 kB"
            fields = status_line.split()
            if len(fields) < 2:
                return None
            return int(fields[1]) * 1024  # kernel reports kB; convert to bytes
    except (OSError, ValueError):
        return None
    # Field not present at all.
    return None
2682

2783

2884
class MemoryMonitor:
29-
"""Monitors container memory usage via cgroup files and logs warnings when usage is high.
85+
"""Monitors container memory usage via cgroup files and raises on critical pressure.
3086
3187
Lazily probes cgroup v2 then v1 files on the first call to
3288
``check_memory_usage()``. Caches which version exists.
3389
If neither is found (local dev / CI), all subsequent calls are instant no-ops.
3490
35-
Logs a WARNING on every check interval (default 5000 messages) when memory
36-
usage is at or above 90% of the container limit. This gives breadcrumb
37-
trails showing whether memory is climbing, plateauing, or sawtoothing.
91+
**Logging (event-based, not periodic):**
92+
93+
- One INFO when high-pressure mode activates (usage first crosses 95%)
94+
- One INFO/WARNING when critical threshold (98%) is crossed but we do
95+
*not* raise (either anon share is below the fail-fast gate or the
96+
anonymous memory signal is unavailable)
97+
- No repeated per-check warnings — logging is driven by state
98+
transitions, not periodic sampling
99+
100+
**High-pressure polling:** Once cgroup usage first crosses 95%, the check
101+
interval permanently tightens from 5000 to 100 messages to narrow the race
102+
window near OOM.
103+
104+
**Fail-fast:** Raises ``AirbyteTracedException`` with
105+
``FailureType.system_error`` when *both*:
106+
107+
1. Cgroup usage >= 98% of the container limit (container is near OOM-kill)
108+
2. Anonymous memory >= 85% of *current cgroup usage* (most of the charged
109+
memory is process-private anonymous pages, not file-backed cache)
110+
111+
The anonymous memory signal is read from cgroup v2 ``memory.stat`` (``anon``
112+
field) when available, falling back to ``/proc/self/status`` ``RssAnon``.
113+
Comparing anonymous memory to current usage (not the container limit) answers
114+
the more relevant question: "is most of the near-OOM memory actually
115+
process-owned?" This avoids the brittleness of comparing to the full limit
116+
where anonymous memory can dominate usage yet still fall short of a
117+
limit-based percentage threshold.
118+
119+
If the anonymous memory signal is unavailable, the monitor logs a warning
120+
and skips fail-fast rather than falling back to cgroup-only raising.
38121
"""
39122

40123
def __init__(
@@ -47,6 +130,8 @@ def __init__(
47130
self._message_count = 0
48131
self._cgroup_version: Optional[int] = None
49132
self._probed = False
133+
self._high_pressure_mode = False
134+
self._critical_logged = False
50135

51136
def _probe_cgroup(self) -> None:
52137
"""Detect which cgroup version (if any) is available.
@@ -101,15 +186,33 @@ def _read_memory(self) -> Optional[tuple[int, int]]:
101186
logger.debug("Failed to read cgroup memory files; skipping memory check.")
102187
return None
103188

189+
def _read_anon_bytes(self) -> Optional[tuple[int, str]]:
    """Fetch anonymous-memory bytes from the best source available.

    Prefers the cgroup v2 ``memory.stat`` ``anon`` field (covers the whole
    container), falling back to the per-process ``RssAnon`` value from
    ``/proc/self/status``.

    Returns:
        A ``(bytes, source_label)`` tuple identifying where the reading came
        from, or ``None`` when neither source is readable.
    """
    # Only the v2 hierarchy exposes memory.stat at the path we probe.
    if self._cgroup_version == 2:
        anon_from_cgroup = _read_cgroup_v2_anon_bytes()
        if anon_from_cgroup is not None:
            return anon_from_cgroup, "cgroup memory.stat anon"

    anon_from_proc = _read_process_anon_rss_bytes()
    if anon_from_proc is None:
        return None
    return anon_from_proc, "process RssAnon"
206+
104207
def check_memory_usage(self) -> None:
105-
"""Check memory usage and log when above 90%.
208+
"""Check memory usage and raise at critical dual-condition.
106209
107210
Intended to be called on every message. The monitor internally tracks
108211
a message counter and only reads cgroup files every ``check_interval``
109-
messages (default 5000) to minimise I/O overhead.
212+
messages (default 5000). Once usage crosses 95%, the interval tightens
213+
to 100 messages for the remainder of the sync.
110214
111-
Logs a WARNING on every check above 90% to provide breadcrumb trails
112-
showing memory trends over the sync lifetime.
215+
Logging is event-based (one-shot on state transitions), not periodic.
113216
114217
This method is a no-op if cgroup files are unavailable.
115218
"""
@@ -118,7 +221,10 @@ def check_memory_usage(self) -> None:
118221
return
119222

120223
self._message_count += 1
121-
if self._message_count % self._check_interval != 0:
224+
interval = (
225+
_HIGH_PRESSURE_CHECK_INTERVAL if self._high_pressure_mode else self._check_interval
226+
)
227+
if self._message_count % interval != 0:
122228
return
123229

124230
memory_info = self._read_memory()
@@ -131,10 +237,43 @@ def check_memory_usage(self) -> None:
131237
usage_gb = usage_bytes / (1024**3)
132238
limit_gb = limit_bytes / (1024**3)
133239

134-
if usage_ratio >= _MEMORY_THRESHOLD:
135-
logger.warning(
136-
"Source memory usage at %d%% of container limit (%.2f / %.2f GB).",
137-
usage_percent,
138-
usage_gb,
139-
limit_gb,
240+
if usage_ratio >= _HIGH_PRESSURE_THRESHOLD and not self._high_pressure_mode:
241+
self._high_pressure_mode = True
242+
logger.info(
243+
"Memory usage crossed %d%%; tightening check interval from %d to %d messages.",
244+
int(_HIGH_PRESSURE_THRESHOLD * 100),
245+
self._check_interval,
246+
_HIGH_PRESSURE_CHECK_INTERVAL,
140247
)
248+
249+
# Fail-fast: dual-condition check
250+
if usage_ratio >= _CRITICAL_THRESHOLD:
251+
anon_info = self._read_anon_bytes()
252+
if anon_info is not None:
253+
anon_bytes, anon_source = anon_info
254+
anon_share = anon_bytes / usage_bytes
255+
if anon_share >= _ANON_SHARE_OF_USAGE_THRESHOLD:
256+
raise AirbyteTracedException(
257+
message=f"Source memory usage exceeded critical threshold ({usage_percent}% of container limit).",
258+
internal_message=(
259+
f"Cgroup memory: {usage_bytes} / {limit_bytes} bytes ({usage_percent}%). "
260+
f"Anonymous memory ({anon_source}): {anon_bytes} bytes "
261+
f"({int(anon_share * 100)}% of current cgroup usage). "
262+
f"Thresholds: cgroup >= {int(_CRITICAL_THRESHOLD * 100)}%, "
263+
f"anon share of usage >= {int(_ANON_SHARE_OF_USAGE_THRESHOLD * 100)}%."
264+
),
265+
failure_type=FailureType.system_error,
266+
)
267+
elif not self._critical_logged:
268+
self._critical_logged = True
269+
logger.info(
270+
"Cgroup usage crossed %d%% but anonymous memory is only %d%% of current cgroup usage; not raising.",
271+
int(_CRITICAL_THRESHOLD * 100),
272+
int(anon_share * 100),
273+
)
274+
elif not self._critical_logged:
275+
self._critical_logged = True
276+
logger.warning(
277+
"Cgroup usage crossed %d%% but anonymous memory signal unavailable; skipping fail-fast.",
278+
int(_CRITICAL_THRESHOLD * 100),
279+
)

unit_tests/test_entrypoint.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,3 +856,72 @@ def test_given_serialization_error_using_orjson_then_fallback_on_json(
856856
# There will be multiple messages here because the fixture `entrypoint` sets a control message. We only care about records here
857857
record_messages = list(filter(lambda message: "RECORD" in message, messages))
858858
assert len(record_messages) == 2
859+
860+
861+
def test_memory_failfast_flushes_queued_state_before_raising(mocker):
    """Record emitted → check_memory_usage raises → queued STATE flushed with recordCount → exception propagates."""
    # A state message sitting in the repository queue at the moment the
    # memory fail-fast fires; it must be yielded before the exception escapes.
    pending_state = AirbyteMessage(
        type=Type.STATE,
        state=AirbyteStateMessage(
            type=AirbyteStateType.STREAM,
            stream=AirbyteStreamState(
                stream_descriptor=StreamDescriptor(name="users", namespace=None),
                stream_state=AirbyteStateBlob({"cursor": "abc123"}),
            ),
        ),
    )

    repo_mock = MagicMock()
    repo_mock.consume_queue.side_effect = [
        [pending_state],  # flush during fail-fast exception handling
        [],  # normal end-of-loop flush (not reached)
    ]
    mocker.patch.object(
        MockSource,
        "message_repository",
        new_callable=mocker.PropertyMock,
        return_value=repo_mock,
    )

    emitted_record = AirbyteMessage(
        record=AirbyteRecordMessage(stream="users", data={"id": 1}, emitted_at=1),
        type=Type.RECORD,
    )
    mocker.patch.object(MockSource, "read_state", return_value={})
    mocker.patch.object(MockSource, "read_catalog", return_value={})
    mocker.patch.object(MockSource, "read", return_value=[emitted_record])

    memory_exception = AirbyteTracedException(
        message="Memory usage exceeded critical threshold (98%)",
        failure_type=FailureType.system_error,
    )

    entrypoint_under_test = AirbyteEntrypoint(MockSource())
    mocker.patch.object(
        entrypoint_under_test._memory_monitor,
        "check_memory_usage",
        side_effect=memory_exception,
    )

    spec = ConnectorSpecification(connectionSpecification={})
    config: dict[str, str] = {}

    # Call read() directly to get AirbyteMessage objects (not serialised strings)
    message_stream = entrypoint_under_test.read(spec, config, {}, [])

    # 1. The RECORD is yielded first.
    record_message = next(message_stream)
    assert record_message.type == Type.RECORD
    assert record_message.record.stream == "users"  # type: ignore[union-attr]

    # 2. The queued STATE is flushed before the exception propagates.
    state_message = next(message_stream)
    assert state_message.type == Type.STATE
    assert state_message.state.stream.stream_state == AirbyteStateBlob({"cursor": "abc123"})  # type: ignore[union-attr]

    # 3. The STATE passed through handle_record_counts → sourceStats.recordCount == 1.0.
    assert state_message.state.sourceStats is not None  # type: ignore[union-attr]
    assert state_message.state.sourceStats.recordCount == 1.0  # type: ignore[union-attr]

    # 4. The next pull re-raises the very same AirbyteTracedException instance.
    with pytest.raises(AirbyteTracedException) as raised:
        next(message_stream)
    assert raised.value is memory_exception

0 commit comments

Comments
 (0)