diff --git a/elementary/clients/dbt/api_dbt_runner.py b/elementary/clients/dbt/api_dbt_runner.py index 886b8caa9..79768d90d 100644 --- a/elementary/clients/dbt/api_dbt_runner.py +++ b/elementary/clients/dbt/api_dbt_runner.py @@ -27,7 +27,6 @@ class APIDbtRunner(CommandLineDbtRunner): def _inner_run_command( self, dbt_command_args: List[str], - capture_output: bool, quiet: bool, log_output: bool, log_format: str, diff --git a/elementary/clients/dbt/command_line_dbt_runner.py b/elementary/clients/dbt/command_line_dbt_runner.py index a7261c73a..a450e429a 100644 --- a/elementary/clients/dbt/command_line_dbt_runner.py +++ b/elementary/clients/dbt/command_line_dbt_runner.py @@ -5,9 +5,17 @@ from typing import Any, Dict, List, Optional import yaml +from tenacity import ( + RetryCallState, + retry, + retry_if_exception, + stop_after_attempt, + wait_exponential, +) from elementary.clients.dbt.base_dbt_runner import BaseDbtRunner from elementary.clients.dbt.dbt_log import parse_dbt_output +from elementary.clients.dbt.transient_errors import is_transient_error from elementary.exceptions.exceptions import DbtCommandError, DbtLsCommandError from elementary.monitor.dbt_project_utils import is_dbt_package_up_to_date from elementary.utils.env_vars import is_debug @@ -15,6 +23,32 @@ logger = get_logger(__name__) +# Retry configuration for transient errors. +_TRANSIENT_MAX_RETRIES = 3 +_TRANSIENT_WAIT_MULTIPLIER = 10 # seconds +_TRANSIENT_WAIT_MAX = 60 # seconds + + +class DbtTransientError(Exception): + """Raised internally to signal a transient dbt failure that should be retried.""" + + def __init__(self, result: "DbtCommandResult", message: str) -> None: + super().__init__(message) + self.result = result + + +def _before_retry_log(retry_state: RetryCallState) -> None: + """Log before each retry. 
Reads log_command_args from the retried call.""" + log_command_args = retry_state.kwargs.get("log_command_args", []) + attempt = retry_state.attempt_number + logger.warning( + "Transient error detected for dbt command '%s' (attempt %d/%d). Retrying...", + " ".join(log_command_args), + attempt, + _TRANSIENT_MAX_RETRIES, + ) + + MACRO_RESULT_PATTERN = re.compile( "Elementary: --ELEMENTARY-MACRO-OUTPUT-START--(.*)--ELEMENTARY-MACRO-OUTPUT-END--" ) @@ -50,6 +84,7 @@ def __init__( secret_vars, allow_macros_without_package_prefix, ) + self.adapter_type = self._get_adapter_type() self.raise_on_failure = raise_on_failure self.env_vars = env_vars if force_dbt_deps: @@ -57,10 +92,70 @@ def __init__( elif run_deps_if_needed: self._run_deps_if_needed() + def _get_adapter_type(self) -> Optional[str]: + """Resolve the adapter type from ``profiles.yml``. + + Reads the profile name from ``dbt_project.yml``, then looks up the + selected target in ``profiles.yml`` to extract its ``type`` field + (e.g. ``"bigquery"``, ``"snowflake"``). + + Returns ``None`` when profiles.yml or the expected keys are missing. + """ + profiles_dir = ( + self.profiles_dir + if self.profiles_dir + else os.path.join(os.path.expanduser("~"), ".dbt") + ) + profiles_path = os.path.join(profiles_dir, "profiles.yml") + if not os.path.exists(profiles_path): + logger.debug("profiles.yml not found at %s", profiles_path) + return None + + with open(profiles_path) as f: + profiles = yaml.safe_load(f) + + # Read dbt_project.yml to get the profile name. 
+        dbt_project_path = os.path.join(self.project_dir, "dbt_project.yml")
+        if not os.path.exists(dbt_project_path):
+            logger.debug("dbt_project.yml not found at %s", dbt_project_path)
+            return None
+
+        with open(dbt_project_path) as f:
+            dbt_project = yaml.safe_load(f)
+
+        # Guard against an empty dbt_project.yml (yaml.safe_load returns None),
+        # mirroring the `profiles` guard below.
+        profile_name = dbt_project.get("profile") if dbt_project else None
+        if not profile_name:
+            logger.debug("No profile name found in dbt_project.yml")
+            return None
+
+        profile = profiles.get(profile_name) if profiles else None
+        if not profile:
+            logger.debug("Profile '%s' not found in profiles.yml", profile_name)
+            return None
+
+        # Determine which target to use.
+        target_name = self.target or profile.get("target")
+        if not target_name:
+            logger.debug("No target specified and no default target in profile")
+            return None
+
+        target_config = profile.get("outputs", {}).get(target_name)
+        if not target_config:
+            logger.debug("Target '%s' not found in profile outputs", target_name)
+            return None
+
+        adapter_type = target_config.get("type")
+        if adapter_type:
+            logger.debug(
+                "Resolved adapter type '%s' for target '%s'",
+                adapter_type,
+                target_name,
+            )
+        return adapter_type
+
     def _inner_run_command(
         self,
         dbt_command_args: List[str],
-        capture_output: bool,
         quiet: bool,
         log_output: bool,
         log_format: str,
@@ -75,15 +170,13 @@ def _parse_ls_command_result(
     def _run_command(
         self,
         command_args: List[str],
-        capture_output: bool = False,
         log_format: str = "json",
         vars: Optional[dict] = None,
         quiet: bool = False,
         log_output: bool = True,
     ) -> DbtCommandResult:
         dbt_command_args = []
-        if capture_output:
-            dbt_command_args.extend(["--log-format", log_format])
+        dbt_command_args.extend(["--log-format", log_format])
         dbt_command_args.extend(command_args)
         dbt_command_args.extend(["--project-dir", os.path.abspath(self.project_dir)])
         if self.profiles_dir:
@@ -112,28 +205,108 @@ def _run_command(
         else:
             logger.debug(log_msg)
 
-        result = self._inner_run_command(
-            dbt_command_args,
-            capture_output=capture_output,
-            quiet=quiet,
-
log_output=log_output, - log_format=log_format, - ) - - if capture_output and result.output: + try: + return self._inner_run_command_with_retries( + dbt_command_args=dbt_command_args, + log_command_args=log_command_args, + quiet=quiet, + log_output=log_output, + log_format=log_format, + ) + except DbtTransientError as exc: + logger.exception( + "dbt command '%s' failed after %d attempts due to transient errors.", + " ".join(log_command_args), + _TRANSIENT_MAX_RETRIES, + ) + if isinstance(exc.__cause__, DbtCommandError): + raise exc.__cause__ from exc + return exc.result + + @retry( + retry=retry_if_exception(lambda exc: isinstance(exc, DbtTransientError)), + stop=stop_after_attempt(_TRANSIENT_MAX_RETRIES), + wait=wait_exponential( + multiplier=_TRANSIENT_WAIT_MULTIPLIER, + max=_TRANSIENT_WAIT_MAX, + ), + before_sleep=_before_retry_log, + reraise=True, + ) + def _inner_run_command_with_retries( + self, + dbt_command_args: List[str], + log_command_args: List[str], + quiet: bool, + log_output: bool, + log_format: str, + ) -> DbtCommandResult: + """Run one dbt command attempt. 
Raises DbtTransientError for transient failures so tenacity can retry.""" + try: + result = self._inner_run_command( + dbt_command_args, + quiet=quiet, + log_output=log_output, + log_format=log_format, + ) + except DbtCommandError as exc: + output_text = str(exc) + stderr_text: Optional[str] = None + if exc.proc_err is not None: + if exc.proc_err.output: + output_text = ( + exc.proc_err.output.decode() + if isinstance(exc.proc_err.output, bytes) + else str(exc.proc_err.output) + ) + if exc.proc_err.stderr: + stderr_text = ( + exc.proc_err.stderr.decode() + if isinstance(exc.proc_err.stderr, bytes) + else str(exc.proc_err.stderr) + ) + if is_transient_error( + self.adapter_type, output=output_text, stderr=stderr_text + ): + raise DbtTransientError( + result=DbtCommandResult( + success=False, + output=output_text, + stderr=stderr_text, + ), + message=f"Transient error during dbt command: {exc}", + ) from exc + raise + + if result.output: logger.debug( - f"Result bytes size for command '{log_command_args}' is {len(result.output)}" + "Result bytes size for command '%s' is %d", + " ".join(log_command_args), + len(result.output), ) if log_output or is_debug(): for log in parse_dbt_output(result.output, log_format): logger.info(log.msg) + if not result.success and is_transient_error( + self.adapter_type, output=result.output, stderr=result.stderr + ): + raise DbtTransientError( + result=result, + message=( + f"Transient error during dbt command: " + f"{' '.join(log_command_args)}" + ), + ) + return result - def deps(self, quiet: bool = False, capture_output: bool = True) -> bool: - result = self._run_command( - command_args=["deps"], quiet=quiet, capture_output=capture_output - ) + def deps( + self, + quiet: bool = False, + capture_output: bool = True, # Deprecated: no-op, kept for backward compatibility. 
+ ) -> bool: + result = self._run_command(command_args=["deps"], quiet=quiet) return result.success def seed(self, select: Optional[str] = None, full_refresh: bool = False) -> bool: @@ -152,7 +325,7 @@ def snapshot(self) -> bool: def run_operation( self, macro_name: str, - capture_output: bool = True, + capture_output: bool = True, # Deprecated: no-op, kept for backward compatibility. macro_args: Optional[dict] = None, log_errors: bool = True, vars: Optional[dict] = None, @@ -177,7 +350,6 @@ def run_operation( command_args.extend(["--args", json_args]) result = self._run_command( command_args=command_args, - capture_output=capture_output, vars=vars, quiet=quiet, log_output=log_output, @@ -191,23 +363,22 @@ def run_operation( log_pattern = ( RAW_EDR_LOGS_PATTERN if return_raw_edr_logs else MACRO_RESULT_PATTERN ) - if capture_output: - if result.output is not None: - for log in parse_dbt_output(result.output): - if log_errors and log.level == "error": - logger.error(log.msg) - continue - - if log.msg: - match = log_pattern.match(log.msg) - if match: - run_operation_results.append(match.group(1)) - - if result.stderr is not None and log_errors: - for log in parse_dbt_output(result.stderr): - if log.level == "error": - logger.error(log.msg) - continue + if result.output is not None: + for log in parse_dbt_output(result.output): + if log_errors and log.level == "error": + logger.error(log.msg) + continue + + if log.msg: + match = log_pattern.match(log.msg) + if match: + run_operation_results.append(match.group(1)) + + if result.stderr is not None and log_errors: + for log in parse_dbt_output(result.stderr): + if log.level == "error": + logger.error(log.msg) + continue return run_operation_results @@ -218,7 +389,7 @@ def run( full_refresh: bool = False, vars: Optional[dict] = None, quiet: bool = False, - capture_output: bool = False, + capture_output: bool = False, # Deprecated: no-op, kept for backward compatibility. 
) -> bool: command_args = ["run"] if full_refresh: @@ -231,7 +402,6 @@ def run( command_args=command_args, vars=vars, quiet=quiet, - capture_output=capture_output, ) return result.success @@ -240,7 +410,7 @@ def test( select: Optional[str] = None, vars: Optional[dict] = None, quiet: bool = False, - capture_output: bool = False, + capture_output: bool = False, # Deprecated: no-op, kept for backward compatibility. ) -> bool: command_args = ["test"] if select: @@ -249,7 +419,6 @@ def test( command_args=command_args, vars=vars, quiet=quiet, - capture_output=capture_output, ) return result.success @@ -266,9 +435,7 @@ def ls(self, select: Optional[str] = None) -> list: if select: command_args.extend(["-s", select]) try: - result = self._run_command( - command_args=command_args, capture_output=True, log_format="text" - ) + result = self._run_command(command_args=command_args, log_format="text") return self._parse_ls_command_result(select, result) except DbtCommandError: raise DbtLsCommandError(select) diff --git a/elementary/clients/dbt/subprocess_dbt_runner.py b/elementary/clients/dbt/subprocess_dbt_runner.py index 18d74cf23..ca33eb208 100644 --- a/elementary/clients/dbt/subprocess_dbt_runner.py +++ b/elementary/clients/dbt/subprocess_dbt_runner.py @@ -19,7 +19,6 @@ class SubprocessDbtRunner(CommandLineDbtRunner): def _inner_run_command( self, dbt_command_args: List[str], - capture_output: bool, quiet: bool, log_output: bool, log_format: str, @@ -28,7 +27,7 @@ def _inner_run_command( result = subprocess.run( [self._get_dbt_command_name()] + dbt_command_args, check=self.raise_on_failure, - capture_output=capture_output or quiet, + capture_output=True, env=self._get_command_env(), cwd=self.project_dir, ) @@ -43,7 +42,7 @@ def _inner_run_command( if err.output else [] ) - if capture_output and (log_output or is_debug()): + if log_output or is_debug(): for log in logs: logger.info(log.msg) raise DbtCommandError( diff --git a/elementary/clients/dbt/transient_errors.py 
b/elementary/clients/dbt/transient_errors.py new file mode 100644 index 000000000..7ea625b90 --- /dev/null +++ b/elementary/clients/dbt/transient_errors.py @@ -0,0 +1,159 @@ +"""Per-adapter transient error patterns for automatic retry. + +Each adapter may produce transient errors that are safe to retry. This +module centralises those patterns so that the runner can decide whether a +failed dbt command should be retried transparently. + +To add patterns for a new adapter, append a new entry to +``_ADAPTER_PATTERNS`` with the **adapter type** as key (e.g. +``"bigquery"``, ``"snowflake"``) and a tuple of **plain, lowercase** +substrings that appear in the error output. Matching is +case-insensitive substring search so regex is not needed. + +The ``adapter_type`` argument accepted by :func:`is_transient_error` +should be the dbt **adapter type** (e.g. ``"bigquery"``, ``"snowflake"``), +as resolved from ``profiles.yml``. When the value does not match any +known adapter key (or is ``None``), **all** adapter patterns are checked +defensively so that transient errors are never missed. +""" + +from typing import Dict, Optional, Sequence, Tuple + +# --------------------------------------------------------------------------- +# Per-adapter transient error substrings (all lowercase). +# +# A command failure is considered *transient* when the dbt output +# (stdout + stderr, lowercased) contains **any** of the substrings +# listed for the active adapter **or** in the ``_COMMON`` list. +# --------------------------------------------------------------------------- + +_COMMON: Tuple[str, ...] = ( + # Generic connection / HTTP errors that any adapter can surface. + "connection reset by peer", + "connection was closed", + "remotedisconnected", + "connectionerror", + "brokenpipeerror", + "connection aborted", + "read timed out", +) + +_DATABRICKS_PATTERNS: Tuple[str, ...] 
= (
+    "temporarily_unavailable",
+    "504 gateway timeout",
+    "502 bad gateway",
+    "service unavailable",
+)
+
+_ADAPTER_PATTERNS: Dict[str, Tuple[str, ...]] = {
+    "bigquery": (
+        # Streaming-buffer delay after a streaming insert.
+        "streaming data from",
+        "is temporarily unavailable",
+        # Generic transient backend error (500).
+        "retrying may solve the problem",
+        "backenderror",
+        # Rate-limit / quota errors.
+        "exceeded rate limits",
+        "ratelimitexceeded",
+        "quota exceeded",
+        # Duplicate job ID (409 conflict) — seen with dbt-fusion + xdist.
+        "error 409",
+        # Internal errors surfaced as 503 / "internal error".
+        "internal error encountered",
+        "503 service unavailable",
+        "http 503",
+    ),
+    "snowflake": (
+        "could not connect to snowflake backend",
+        "authentication token has expired",
+        "incident id:",
+        "service is unavailable",
+    ),
+    "redshift": (
+        "connection timed out",
+        "could not connect to the server",
+        "ssl syscall error",
+    ),
+    "databricks": _DATABRICKS_PATTERNS,
+    "athena": (
+        "throttlingexception",
+        "toomanyrequestsexception",
+        "service unavailable",
+    ),
+    "dremio": (
+        # Common patterns (remotedisconnected, connection was closed) already
+        # cover the most frequent Dremio transient errors. Add Dremio-specific
+        # patterns here as they are identified.
+    ),
+    "postgres": (
+        "could not connect to server",
+        "connection timed out",
+        "server closed the connection unexpectedly",
+        "ssl syscall error",
+    ),
+    "trino": (
+        "service unavailable",
+        "server returned http response code: 503",
+    ),
+    "clickhouse": (
+        "connection timed out",
+        "broken pipe",
+    ),
+}
+
+# Pre-computed union of all adapter-specific patterns for the fallback path
+# when adapter_type is unknown. Built once at import time.
+_ALL_ADAPTER_PATTERNS: Tuple[str, ...]
= tuple( + pattern for patterns in _ADAPTER_PATTERNS.values() for pattern in patterns +) + + +def is_transient_error( + adapter_type: Optional[str], + output: Optional[str] = None, + stderr: Optional[str] = None, +) -> bool: + """Return ``True`` if *output*/*stderr* contain a known transient error. + + Parameters + ---------- + adapter_type: + The dbt adapter type (e.g. ``"bigquery"``, ``"snowflake"``), + typically resolved from ``profiles.yml`` by the runner. + When the value matches a key in ``_ADAPTER_PATTERNS``, only that + adapter's patterns (plus ``_COMMON``) are used. When it does + **not** match any known adapter (or is ``None``), **all** adapter + patterns are checked defensively to avoid missing transient errors. + output: + The captured stdout of the dbt command (may be ``None``). + stderr: + The captured stderr of the dbt command (may be ``None``). + """ + haystack = _build_haystack(output, stderr) + if not haystack: + return False + + if isinstance(adapter_type, str): + adapter_patterns = _ADAPTER_PATTERNS.get(adapter_type.lower()) + if adapter_patterns is not None: + # Known adapter — use common + adapter-specific patterns. + patterns: Sequence[str] = (*_COMMON, *adapter_patterns) + else: + # Unknown adapter type. Check all adapters defensively. + patterns = (*_COMMON, *_ALL_ADAPTER_PATTERNS) + else: + # No adapter type provided; check all adapters defensively. 
patterns = (*_COMMON, *_ALL_ADAPTER_PATTERNS)
+
+    return any(pattern in haystack for pattern in patterns)
+
+
+def _build_haystack(output: Optional[str] = None, stderr: Optional[str] = None) -> str:
+    """Concatenate and lowercase *output* + *stderr* for matching."""
+    parts = []
+    if output and isinstance(output, str):
+        parts.append(output)
+    if stderr and isinstance(stderr, str):
+        parts.append(stderr)
+    return "\n".join(parts).lower()
diff --git a/pyproject.toml b/pyproject.toml
index eb056c0a8..ff7b1fc91 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ packaging = ">=20.9"
 azure-storage-blob = ">=12.11.0"
 pymsteams = ">=0.2.2,<1.0.0"
 tabulate = ">= 0.9.0"
+tenacity = ">=8.0,<10.0"
 pytz = ">= 2025.1"
 
 dbt-snowflake = {version = ">=0.20,<2.0.0", optional = true}
diff --git a/tests/unit/clients/dbt_runner/test_retry_logic.py b/tests/unit/clients/dbt_runner/test_retry_logic.py
new file mode 100644
index 000000000..c1e60133f
--- /dev/null
+++ b/tests/unit/clients/dbt_runner/test_retry_logic.py
@@ -0,0 +1,203 @@
+"""Unit tests for transient-error retry logic in _inner_run_command_with_retries."""
+
+import subprocess
+from unittest import mock
+
+import pytest
+from tenacity import wait_none
+
+from elementary.clients.dbt.command_line_dbt_runner import (
+    CommandLineDbtRunner,
+    _TRANSIENT_MAX_RETRIES,
+)
+from elementary.exceptions.exceptions import DbtCommandError
+
+# Tenacity evaluates the @retry(...) arguments once, at import time, so patching
+# the module-level _TRANSIENT_WAIT_MULTIPLIER constant would NOT affect the
+# already-built wait_exponential object — the tests would really sleep 10-60s
+# per retry. Patch the wait strategy on the decorated method's attached
+# Retrying object instead; mock.patch.object restores it after each test.
+_ZERO_WAIT = mock.patch.object(
+    CommandLineDbtRunner._inner_run_command_with_retries.retry,
+    "wait",
+    wait_none(),
+)
+
+
+def _make_runner(**kwargs):
+    """Create a SubprocessDbtRunner with deps/packages stubbed out."""
+    defaults = dict(
+        project_dir="/tmp/fake_project",
+        profiles_dir="/tmp/fake_profiles",
+        target=None,
+        raise_on_failure=True,
+        run_deps_if_needed=False,
+    )
+    defaults.update(kwargs)
+    # Use SubprocessDbtRunner but stub out _run_deps_if_needed so it
+    # doesn't touch the filesystem.
+ from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner + + with mock.patch.object(SubprocessDbtRunner, "_run_deps_if_needed"): + return SubprocessDbtRunner(**defaults) + + +@_ZERO_WAIT +class TestTransientRetryDbtCommandError: + """Test retry behaviour when _inner_run_command raises DbtCommandError + (raise_on_failure=True path).""" + + @mock.patch( + "elementary.clients.dbt.command_line_dbt_runner.is_transient_error", + return_value=True, + ) + @mock.patch("subprocess.run") + def test_retries_and_reraises_dbt_command_error( + self, mock_subprocess_run, mock_is_transient + ): + """After exhausting retries on a transient DbtCommandError, the original + DbtCommandError should be re-raised (preserving raise_on_failure contract).""" + proc_err = subprocess.CalledProcessError( + 1, "dbt run", output=b"connection reset by peer", stderr=b"" + ) + mock_subprocess_run.side_effect = proc_err + + runner = _make_runner(raise_on_failure=True) + + with pytest.raises(DbtCommandError): + runner.run() + + # _inner_run_command should have been called _TRANSIENT_MAX_RETRIES times + assert mock_subprocess_run.call_count == _TRANSIENT_MAX_RETRIES + + @mock.patch( + "elementary.clients.dbt.command_line_dbt_runner.is_transient_error", + return_value=True, + ) + @mock.patch("subprocess.run") + def test_retry_count_matches_config(self, mock_subprocess_run, mock_is_transient): + """Verify exactly _TRANSIENT_MAX_RETRIES attempts are made.""" + proc_err = subprocess.CalledProcessError( + 1, "dbt test", output=b"connection reset by peer", stderr=b"" + ) + mock_subprocess_run.side_effect = proc_err + + runner = _make_runner(raise_on_failure=True) + + with pytest.raises(DbtCommandError): + runner.test() + + assert mock_subprocess_run.call_count == _TRANSIENT_MAX_RETRIES + + +@_ZERO_WAIT +class TestTransientRetryFailedResult: + """Test retry behaviour when command returns non-success result + (raise_on_failure=False path).""" + + @mock.patch( + 
"elementary.clients.dbt.command_line_dbt_runner.is_transient_error", + return_value=True, + ) + @mock.patch("subprocess.run") + def test_retries_and_returns_last_result( + self, mock_subprocess_run, mock_is_transient + ): + """After exhausting retries on a transient failed result, the last + DbtCommandResult should be returned (not an exception).""" + fake_result = mock.MagicMock() + fake_result.returncode = 1 + fake_result.stdout = b"service unavailable" + fake_result.stderr = b"" + mock_subprocess_run.return_value = fake_result + + runner = _make_runner(raise_on_failure=False) + result = runner.run() + + # Should have retried _TRANSIENT_MAX_RETRIES times + assert mock_subprocess_run.call_count == _TRANSIENT_MAX_RETRIES + # Result should indicate failure (not raise) + assert result is False + + @mock.patch( + "elementary.clients.dbt.command_line_dbt_runner.is_transient_error", + return_value=True, + ) + @mock.patch("subprocess.run") + def test_retry_succeeds_on_second_attempt( + self, mock_subprocess_run, mock_is_transient + ): + """A transient failure followed by success should return the successful result.""" + fail_result = mock.MagicMock() + fail_result.returncode = 1 + fail_result.stdout = b"service unavailable" + fail_result.stderr = b"" + + success_result = mock.MagicMock() + success_result.returncode = 0 + success_result.stdout = b"ok" + success_result.stderr = b"" + + mock_subprocess_run.side_effect = [fail_result, success_result] + + runner = _make_runner(raise_on_failure=False) + result = runner.run() + + assert mock_subprocess_run.call_count == 2 + assert result is True + + +@_ZERO_WAIT +class TestNonTransientNotRetried: + """Test that non-transient failures are NOT retried.""" + + @mock.patch( + "elementary.clients.dbt.command_line_dbt_runner.is_transient_error", + return_value=False, + ) + @mock.patch("subprocess.run") + def test_non_transient_error_not_retried( + self, mock_subprocess_run, mock_is_transient + ): + """A non-transient DbtCommandError 
should propagate immediately + without any retries.""" + proc_err = subprocess.CalledProcessError( + 1, "dbt run", output=b"syntax error in model", stderr=b"" + ) + mock_subprocess_run.side_effect = proc_err + + runner = _make_runner(raise_on_failure=True) + + with pytest.raises(DbtCommandError): + runner.run() + + # Only called once — no retry + assert mock_subprocess_run.call_count == 1 + + @mock.patch( + "elementary.clients.dbt.command_line_dbt_runner.is_transient_error", + return_value=False, + ) + @mock.patch("subprocess.run") + def test_non_transient_failed_result_not_retried( + self, mock_subprocess_run, mock_is_transient + ): + """A non-transient failed result should be returned immediately + without any retries.""" + fake_result = mock.MagicMock() + fake_result.returncode = 1 + fake_result.stdout = b"compilation error" + fake_result.stderr = b"" + mock_subprocess_run.return_value = fake_result + + runner = _make_runner(raise_on_failure=False) + result = runner.run() + + # Only called once — no retry + assert mock_subprocess_run.call_count == 1 + assert result is False + + @mock.patch("subprocess.run") + def test_successful_command_not_retried(self, mock_subprocess_run): + """A successful command should return immediately without retries.""" + fake_result = mock.MagicMock() + fake_result.returncode = 0 + fake_result.stdout = b"ok" + fake_result.stderr = b"" + mock_subprocess_run.return_value = fake_result + + runner = _make_runner(raise_on_failure=False) + result = runner.run() + + assert mock_subprocess_run.call_count == 1 + assert result is True diff --git a/tests/unit/monitor/fetchers/alerts/test_alerts_fetcher.py b/tests/unit/monitor/fetchers/alerts/test_alerts_fetcher.py index fb9dec9bb..a860feb8b 100644 --- a/tests/unit/monitor/fetchers/alerts/test_alerts_fetcher.py +++ b/tests/unit/monitor/fetchers/alerts/test_alerts_fetcher.py @@ -33,10 +33,11 @@ def test_update_sent_alerts( calls_args = mock_subprocess_run.call_args_list for call_args in 
calls_args: # Test that update_sent_alerts has been called with alert_ids as arguments. - assert call_args[0][0][1] == "run" - assert call_args[0][0][2] == "-s" - assert call_args[0][0][3] == "elementary_cli.update_alerts.update_sent_alerts" - dbt_run_params = json.loads(call_args[0][0][9]) + # Indices account for --log-format json being prepended to all dbt commands. + assert call_args[0][0][3] == "run" + assert call_args[0][0][4] == "-s" + assert call_args[0][0][5] == "elementary_cli.update_alerts.update_sent_alerts" + dbt_run_params = json.loads(call_args[0][0][11]) assert "alert_ids" in dbt_run_params assert "sent_at" in dbt_run_params @@ -57,12 +58,13 @@ def test_skip_alerts(mock_subprocess_run, alerts_fetcher_mock: MockAlertsFetcher calls_args = mock_subprocess_run.call_args_list for call_args in calls_args: # Test that update_skipped_alerts has been called with alert_ids as arguments. - assert call_args[0][0][1] == "run" - assert call_args[0][0][2] == "-s" + # Indices account for --log-format json being prepended to all dbt commands. + assert call_args[0][0][3] == "run" + assert call_args[0][0][4] == "-s" assert ( - call_args[0][0][3] == "elementary_cli.update_alerts.update_skipped_alerts" + call_args[0][0][5] == "elementary_cli.update_alerts.update_skipped_alerts" ) - dbt_run_params = json.loads(call_args[0][0][9]) + dbt_run_params = json.loads(call_args[0][0][11]) assert "alert_ids" in dbt_run_params