Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
caa0c82
feat(cdk): add source-side memory introspection to emit controlled error messages before OOM kills
devin-ai-integration[bot] Mar 9, 2026
8d059cf
style: fix ruff format and import sorting
devin-ai-integration[bot] Mar 9, 2026
b2e233f
fix(cdk): add error handling to _read_memory() for graceful degradation
devin-ai-integration[bot] Mar 9, 2026
ee93565
fix(cdk): wrap read() loop in try/finally to flush queued messages on…
devin-ai-integration[bot] Mar 9, 2026
cca6c50
refactor(cdk): encapsulate check interval inside MemoryMonitor
devin-ai-integration[bot] Mar 9, 2026
ec56943
refactor(cdk): move MemoryMonitor to AirbyteEntrypoint.__init__
devin-ai-integration[bot] Mar 9, 2026
7f7ab40
test(cdk): add comprehensive memory monitor tests and graceful shutdown coverage
devin-ai-integration[bot] Mar 9, 2026
a021eb7
fix(cdk): change MemoryLimitExceeded to system_error and update user-…
devin-ai-integration[bot] Mar 9, 2026
fc71aba
refactor(cdk): make DEFAULT_CHECK_INTERVAL private and drop circular …
devin-ai-integration[bot] Mar 9, 2026
5f7e16d
refactor(cdk): move memory check before yield, test observable behavi…
devin-ai-integration[bot] Mar 9, 2026
0485fa2
refactor(cdk): move memory check back to after yield for zero data loss
devin-ai-integration[bot] Mar 9, 2026
1621a39
style: fix ruff format in test_entrypoint.py
devin-ai-integration[bot] Mar 9, 2026
4dda57c
refactor(cdk): display memory usage in GB instead of raw bytes
devin-ai-integration[bot] Mar 10, 2026
c96825f
refactor(cdk): remove MemoryLimitExceeded subclass, raise AirbyteTrac…
devin-ai-integration[bot] Mar 10, 2026
cdc518c
fix(cdk): validate check_interval >= 1 to prevent ZeroDivisionError
devin-ai-integration[bot] Mar 10, 2026
bf9db5f
refactor(cdk): switch memory monitor to logging-only trial mode
devin-ai-integration[bot] Mar 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from airbyte_cdk.utils import is_cloud_environment, message_utils
from airbyte_cdk.utils.airbyte_secrets_utils import get_secrets, update_secrets
from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH
from airbyte_cdk.utils.memory_monitor import MemoryMonitor
from airbyte_cdk.utils.traced_exception import AirbyteTracedException

logger = init_logger("airbyte")
Expand All @@ -60,6 +61,7 @@ def __init__(self, source: Source):

self.source = source
self.logger = logging.getLogger(f"airbyte.{getattr(source, 'name', '')}")
self._memory_monitor = MemoryMonitor()

@staticmethod
def parse_args(args: List[str]) -> argparse.Namespace:
Expand Down Expand Up @@ -277,10 +279,13 @@ def read(

# The Airbyte protocol dictates that counts be expressed as float/double to better protect against integer overflows
stream_message_counter: DefaultDict[HashableStreamDescriptor, float] = defaultdict(float)
for message in self.source.read(self.logger, config, catalog, state):
yield self.handle_record_counts(message, stream_message_counter)
for message in self._emit_queued_messages(self.source):
yield self.handle_record_counts(message, stream_message_counter)
try:
for message in self.source.read(self.logger, config, catalog, state):
yield self.handle_record_counts(message, stream_message_counter)
self._memory_monitor.check_memory_usage()
finally:
Comment thread
pnilan marked this conversation as resolved.
Outdated
for message in self._emit_queued_messages(self.source):
yield self.handle_record_counts(message, stream_message_counter)

@staticmethod
def handle_record_counts(
Expand Down
162 changes: 162 additions & 0 deletions airbyte_cdk/utils/memory_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#
# Copyright (c) 2026 Airbyte, Inc., all rights reserved.
#

"""Source-side memory introspection to emit controlled error messages before OOM kills."""

import logging
from pathlib import Path
from typing import Optional

from airbyte_cdk.models import FailureType
from airbyte_cdk.utils.traced_exception import AirbyteTracedException

logger = logging.getLogger("airbyte")

# cgroup v2 exposes memory accounting as flat files at the cgroup root.
_CGROUP_V2_CURRENT = Path("/sys/fs/cgroup/memory.current")
_CGROUP_V2_MAX = Path("/sys/fs/cgroup/memory.max")

# cgroup v1 keeps the same counters under the memory controller hierarchy.
_CGROUP_V1_USAGE = Path("/sys/fs/cgroup/memory/memory.usage_in_bytes")
_CGROUP_V1_LIMIT = Path("/sys/fs/cgroup/memory/memory.limit_in_bytes")

# Fractions of the container limit at which the monitor warns / raises.
_DEFAULT_WARNING_THRESHOLD = 0.85
_DEFAULT_CRITICAL_THRESHOLD = 0.95

# Throttle: the cgroup files are read only once every N processed messages.
_DEFAULT_CHECK_INTERVAL = 1000


class MemoryLimitExceeded(AirbyteTracedException):
    """Raised when connector memory usage exceeds critical threshold."""


class MemoryMonitor:
    """Monitors container memory usage via cgroup files and emits warnings before OOM kills.

    Lazily probes cgroup v2 then v1 files on the first call to
    ``check_memory_usage()``. Caches which version exists.
    If neither is found (local dev / CI), all subsequent calls are instant no-ops.
    """

    def __init__(
        self,
        warning_threshold: float = _DEFAULT_WARNING_THRESHOLD,
        critical_threshold: float = _DEFAULT_CRITICAL_THRESHOLD,
        check_interval: int = _DEFAULT_CHECK_INTERVAL,
    ) -> None:
        """Initialize the monitor.

        Args:
            warning_threshold: Usage ratio (0-1) at which a warning is logged.
            critical_threshold: Usage ratio (0-1) at which MemoryLimitExceeded is raised.
            check_interval: Read cgroup files once every N messages (must be >= 1).

        Raises:
            ValueError: If ``check_interval`` is less than 1. A value of 0 would
                cause a ZeroDivisionError in the ``%`` throttle inside
                ``check_memory_usage()``; negative values are equally invalid.
        """
        if check_interval < 1:
            raise ValueError(f"check_interval must be >= 1, got {check_interval}")
        self._warning_threshold = warning_threshold
        self._critical_threshold = critical_threshold
        self._check_interval = check_interval
        self._message_count = 0
        # Each threshold fires at most once per sync to avoid log spam.
        self._warning_emitted = False
        self._critical_raised = False
        # None until probed; stays None when no cgroup files exist (monitoring disabled).
        self._cgroup_version: Optional[int] = None
        self._probed = False

    def _probe_cgroup(self) -> None:
        """Detect which cgroup version (if any) is available.

        Called lazily on the first ``check_memory_usage()`` invocation so
        that ``spec`` and ``discover`` commands never incur filesystem I/O.
        """
        if self._probed:
            return
        self._probed = True

        if _CGROUP_V2_CURRENT.exists() and _CGROUP_V2_MAX.exists():
            self._cgroup_version = 2
        elif _CGROUP_V1_USAGE.exists() and _CGROUP_V1_LIMIT.exists():
            self._cgroup_version = 1

        if self._cgroup_version is None:
            logger.debug(
                "No cgroup memory files found. Memory monitoring disabled (likely local dev / CI)."
            )

    def _read_memory(self) -> Optional[tuple[int, int]]:
        """Read current memory usage and limit from cgroup files.

        Returns a tuple of (usage_bytes, limit_bytes) or None if unavailable.
        Best-effort: failures to read memory info never crash a sync.
        """
        if self._cgroup_version is None:
            return None

        try:
            if self._cgroup_version == 2:
                usage_path = _CGROUP_V2_CURRENT
                limit_path = _CGROUP_V2_MAX
            else:
                usage_path = _CGROUP_V1_USAGE
                limit_path = _CGROUP_V1_LIMIT

            limit_text = limit_path.read_text().strip()
            # cgroup v2 memory.max can be the literal string "max" (unlimited)
            if limit_text == "max":
                return None

            usage_bytes = int(usage_path.read_text().strip())
            limit_bytes = int(limit_text)

            # A non-positive limit would make the usage ratio meaningless.
            if limit_bytes <= 0:
                return None

            return usage_bytes, limit_bytes
        except (OSError, ValueError):
            logger.debug("Failed to read cgroup memory files; skipping memory check.")
            return None

    def check_memory_usage(self) -> None:
        """Check memory usage against thresholds.

        Intended to be called on every message. The monitor internally tracks
        a message counter and only reads cgroup files every ``check_interval``
        messages (default 1000) to minimise I/O overhead.

        At the warning threshold (default 85%), logs a warning message.
        At the critical threshold (default 95%), raises MemoryLimitExceeded to
        trigger a graceful shutdown with an actionable error message.

        Each threshold triggers at most once per sync to avoid log spam.
        This method is a no-op if cgroup files are unavailable.

        Raises:
            MemoryLimitExceeded: When usage crosses the critical threshold
                (raised at most once per sync).
        """
        self._probe_cgroup()
        if self._cgroup_version is None:
            return

        # Throttle filesystem reads: only every `check_interval`-th call proceeds.
        # __init__ guarantees _check_interval >= 1, so this modulo is safe.
        self._message_count += 1
        if self._message_count % self._check_interval != 0:
            return

        memory_info = self._read_memory()
        if memory_info is None:
            return

        usage_bytes, limit_bytes = memory_info
        usage_ratio = usage_bytes / limit_bytes
        usage_percent = int(usage_ratio * 100)
        usage_gb = usage_bytes / (1024**3)
        limit_gb = limit_bytes / (1024**3)

        if usage_ratio >= self._critical_threshold and not self._critical_raised:
            self._critical_raised = True
            raise MemoryLimitExceeded(
                internal_message=f"Memory usage is {usage_percent}% ({usage_gb:.2f} / {limit_gb:.2f} GB). "
                f"Critical threshold is {int(self._critical_threshold * 100)}%.",
                message=f"Source exceeded memory limit ({usage_percent}% used) and must shut down to avoid an out-of-memory crash.",
                failure_type=FailureType.system_error,
            )

        if usage_ratio >= self._warning_threshold and not self._warning_emitted:
            self._warning_emitted = True
            logger.warning(
                "Source memory usage reached %d%% of container limit (%.2f / %.2f GB).",
                usage_percent,
                usage_gb,
                limit_gb,
            )
91 changes: 91 additions & 0 deletions unit_tests/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,97 @@ def test_handle_record_counts(
)


def test_memory_limit_exceeded_flushes_queued_messages(mocker, spec_mock, config_mock):
    """When MemoryLimitExceeded is raised mid-read, queued messages should still be flushed.

    The read() try/finally ensures _emit_queued_messages runs even when
    MemoryLimitExceeded propagates. The exception still surfaces to the
    caller, but all messages yielded before (records) and during (finally-
    block state messages) the exception are available to the consumer.
    """
    # A state message sitting in the source's message-repository queue; it must
    # still reach the consumer even though the sync aborts mid-read.
    queued_state = AirbyteMessage(
        type=Type.STATE,
        state=AirbyteStateMessage(
            type=AirbyteStateType.STREAM,
            stream=AirbyteStreamState(
                stream_descriptor=StreamDescriptor(name="stream"),
                stream_state=AirbyteStateBlob(updated_at="2026-01-01"),
            ),
        ),
    )
    message_repository = MagicMock()
    # consume_queue calls:
    # 1. run() preamble → initial queued control message
    # 2. read() finally block → queued state (the key assertion)
    # 3. run() outer finally → nothing
    # NOTE(review): this ordering mirrors the internal call sequence of
    # entrypoint.run()/read(); changing that sequence invalidates the list.
    message_repository.consume_queue.side_effect = [
        [MESSAGE_FROM_REPOSITORY],
        [queued_state],
        [],
    ]
    mocker.patch.object(
        MockSource,
        "message_repository",
        new_callable=mocker.PropertyMock,
        return_value=message_repository,
    )
    entrypoint = AirbyteEntrypoint(MockSource())

    record = AirbyteMessage(
        type=Type.RECORD,
        record=AirbyteRecordMessage(stream="stream", data={"id": "1"}, emitted_at=1),
    )
    mocker.patch.object(MockSource, "read_state", return_value={})
    mocker.patch.object(MockSource, "read_catalog", return_value={})
    mocker.patch.object(MockSource, "read", return_value=[record, record])

    from airbyte_cdk.utils.memory_monitor import MemoryLimitExceeded

    call_count = 0

    # check_memory_usage is invoked once per yielded message; let the first
    # record through, then raise on the second call to simulate crossing the
    # critical memory threshold mid-sync.
    def _raise_on_second_call() -> None:
        nonlocal call_count
        call_count += 1
        if call_count >= 2:
            raise MemoryLimitExceeded(
                internal_message="Memory at 96%",
                message="Source exceeded memory limit (96% used) and must shut down to avoid an out-of-memory crash.",
                failure_type=FailureType.system_error,
            )

    mocker.patch.object(
        entrypoint._memory_monitor, "check_memory_usage", side_effect=_raise_on_second_call
    )

    parsed_args = Namespace(
        command="read", config="config_path", state="statepath", catalog="catalogpath"
    )

    # The generator yields messages until MemoryLimitExceeded propagates.
    # Collect everything yielded before the exception surfaces.
    messages: list[str] = []
    with pytest.raises(MemoryLimitExceeded):
        for msg in entrypoint.run(parsed_args):
            messages.append(msg)

    # 1. Both records were yielded before the exception — the memory check
    # runs after yield so every message pulled from the source is emitted.
    record_messages = [m for m in messages if "RECORD" in m]
    assert len(record_messages) == 2, "Both records should be yielded before MemoryLimitExceeded"

    # 2. The queued state message was flushed by the finally block
    state_messages = [m for m in messages if "STATE" in m]
    assert len(state_messages) >= 1, (
        "Queued state message should be flushed even after MemoryLimitExceeded"
    )

    # 3. The flushed state has sourceStats.recordCount set by handle_record_counts.
    # Both records are yielded (and counted) before the second check_memory_usage
    # raises, so the counter is 2.0 at flush time.
    state_json = orjson.loads(state_messages[0])
    assert state_json["state"]["sourceStats"]["recordCount"] == 2.0
Comment thread
pnilan marked this conversation as resolved.
Outdated


def test_given_serialization_error_using_orjson_then_fallback_on_json(
entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock
):
Expand Down
Loading
Loading