diff --git a/elementary/clients/dbt/api_dbt_runner.py b/elementary/clients/dbt/api_dbt_runner.py
index 79768d90d..7dcf729ae 100644
--- a/elementary/clients/dbt/api_dbt_runner.py
+++ b/elementary/clients/dbt/api_dbt_runner.py
@@ -49,15 +49,24 @@ def collect_dbt_command_logs(event):
         with with_chdir(self.project_dir):
             res: dbtRunnerResult = dbt.invoke(dbt_command_args)
         output = "\n".join(dbt_logs) or None
+        # Surface the exception text so that transient-error detection in
+        # _inner_run_command_with_retries can match against it. The dbt
+        # Python API doesn't write to stderr, so we repurpose that field
+        # for the exception string (analogous to how SubprocessDbtRunner
+        # captures subprocess stderr).
+        exception_text = str(res.exception) if res.exception else None
         if self.raise_on_failure and not res.success:
             raise DbtCommandError(
                 base_command_args=dbt_command_args,
-                err_msg=(str(res.exception) if res.exception else output),
+                err_msg=(exception_text or output),
                 logs=[DbtLog.from_log_line(log) for log in dbt_logs],
             )
 
         return APIDbtCommandResult(
-            success=res.success, output=output, stderr=None, result_obj=res
+            success=res.success,
+            output=output,
+            stderr=exception_text,
+            result_obj=res,
         )
 
     def _parse_ls_command_result(
diff --git a/tests/unit/clients/dbt_runner/test_retry_logic.py b/tests/unit/clients/dbt_runner/test_retry_logic.py
index c1e60133f..3a1b16d97 100644
--- a/tests/unit/clients/dbt_runner/test_retry_logic.py
+++ b/tests/unit/clients/dbt_runner/test_retry_logic.py
@@ -201,3 +201,120 @@ def test_successful_command_not_retried(self, mock_subprocess_run):
 
         assert mock_subprocess_run.call_count == 1
         assert result is True
+
+
+def _make_api_runner(**kwargs):
+    """Create an APIDbtRunner with deps/packages stubbed out."""
+    defaults = dict(
+        project_dir="fake_project",
+        profiles_dir="fake_profiles",
+        target=None,
+        raise_on_failure=False,
+        run_deps_if_needed=False,
+    )
+    defaults.update(kwargs)
+    from elementary.clients.dbt.api_dbt_runner import APIDbtRunner
+
+    with mock.patch.object(APIDbtRunner, "_run_deps_if_needed"):
+        return APIDbtRunner(**defaults)
+
+
+@_ZERO_WAIT
+class TestAPIDbtRunnerTransientDetection:
+    """Test that APIDbtRunner surfaces exception text for transient error detection.
+
+    The dbt Python API (APIDbtRunner) only captures JinjaLogInfo and
+    RunningOperationCaughtError events into ``output``. Transient errors
+    like RemoteDisconnected appear as ``res.exception`` — not in the
+    captured output. Without surfacing this, the retry logic has nothing
+    to match against and never fires.
+
+    ``with_chdir`` is patched with a context manager whose ``__exit__``
+    returns False, so an exception raised inside the ``with`` block
+    propagates instead of being silently swallowed by the mock.
+    """
+
+    @mock.patch(
+        "elementary.clients.dbt.api_dbt_runner.with_chdir",
+        return_value=mock.MagicMock(
+            __enter__=mock.MagicMock(), __exit__=mock.MagicMock(return_value=False)
+        ),
+    )
+    @mock.patch("elementary.clients.dbt.api_dbt_runner.dbtRunner")
+    def test_transient_exception_triggers_retry(self, mock_dbt_runner_cls, _mock_chdir):
+        """A transient exception in res.exception should be retried."""
+        # Simulate dbtRunnerResult with a transient exception.
+        fail_result = mock.MagicMock()
+        fail_result.success = False
+        fail_result.exception = ConnectionError(
+            "('Connection aborted.', "
+            "RemoteDisconnected('Remote end closed connection without response'))"
+        )
+
+        success_result = mock.MagicMock()
+        success_result.success = True
+        success_result.exception = None
+
+        # dbtRunner().invoke returns fail first, then success.
+        mock_dbt_instance = mock.MagicMock()
+        mock_dbt_instance.invoke.side_effect = [fail_result, success_result]
+        mock_dbt_runner_cls.return_value = mock_dbt_instance
+
+        runner = _make_api_runner(raise_on_failure=False)
+        result = runner.seed()
+
+        assert mock_dbt_instance.invoke.call_count == 2
+        assert result is True
+
+    @mock.patch(
+        "elementary.clients.dbt.api_dbt_runner.with_chdir",
+        return_value=mock.MagicMock(
+            __enter__=mock.MagicMock(), __exit__=mock.MagicMock(return_value=False)
+        ),
+    )
+    @mock.patch("elementary.clients.dbt.api_dbt_runner.dbtRunner")
+    def test_non_transient_exception_not_retried(
+        self, mock_dbt_runner_cls, _mock_chdir
+    ):
+        """A non-transient exception should NOT be retried."""
+        fail_result = mock.MagicMock()
+        fail_result.success = False
+        fail_result.exception = Exception("Compilation Error in model foo")
+
+        mock_dbt_instance = mock.MagicMock()
+        mock_dbt_instance.invoke.return_value = fail_result
+        mock_dbt_runner_cls.return_value = mock_dbt_instance
+
+        runner = _make_api_runner(raise_on_failure=False)
+        result = runner.seed()
+
+        assert mock_dbt_instance.invoke.call_count == 1
+        assert result is False
+
+    @mock.patch(
+        "elementary.clients.dbt.api_dbt_runner.with_chdir",
+        return_value=mock.MagicMock(
+            __enter__=mock.MagicMock(), __exit__=mock.MagicMock(return_value=False)
+        ),
+    )
+    @mock.patch("elementary.clients.dbt.api_dbt_runner.dbtRunner")
+    def test_transient_exception_exhausts_retries(
+        self, mock_dbt_runner_cls, _mock_chdir
+    ):
+        """After exhausting retries, the last failed result is returned."""
+        fail_result = mock.MagicMock()
+        fail_result.success = False
+        fail_result.exception = ConnectionError(
+            "('Connection aborted.', "
+            "RemoteDisconnected('Remote end closed connection without response'))"
+        )
+
+        mock_dbt_instance = mock.MagicMock()
+        mock_dbt_instance.invoke.return_value = fail_result
+        mock_dbt_runner_cls.return_value = mock_dbt_instance
+
+        runner = _make_api_runner(raise_on_failure=False)
+        result = runner.seed()
+
+        assert mock_dbt_instance.invoke.call_count == _TRANSIENT_MAX_RETRIES
+        assert result is False