diff --git a/elementary/clients/dbt/api_dbt_runner.py b/elementary/clients/dbt/api_dbt_runner.py index 7dcf729ae..1235f557c 100644 --- a/elementary/clients/dbt/api_dbt_runner.py +++ b/elementary/clients/dbt/api_dbt_runner.py @@ -24,6 +24,10 @@ class APIDbtCommandResult(DbtCommandResult): class APIDbtRunner(CommandLineDbtRunner): + def __init__(self, *args, **kwargs): + self._manifest = None + super().__init__(*args, **kwargs) + def _inner_run_command( self, dbt_command_args: List[str], @@ -45,9 +49,11 @@ def collect_dbt_command_logs(event): dbt_logs.append(event_dump) with env_vars_context(self.env_vars): - dbt = dbtRunner(callbacks=[collect_dbt_command_logs]) + dbt = dbtRunner(manifest=self._manifest, callbacks=[collect_dbt_command_logs]) with with_chdir(self.project_dir): res: dbtRunnerResult = dbt.invoke(dbt_command_args) + if self._manifest is None and res.success: + self._manifest = dbt.manifest output = "\n".join(dbt_logs) or None # Surface the exception text so that transient-error detection in # _inner_run_command_with_retries can match against it. The dbt diff --git a/tests/unit/clients/dbt_runner/test_api_dbt_runner.py b/tests/unit/clients/dbt_runner/test_api_dbt_runner.py new file mode 100644 index 000000000..5a825ef76 --- /dev/null +++ b/tests/unit/clients/dbt_runner/test_api_dbt_runner.py @@ -0,0 +1,89 @@ +from contextlib import contextmanager +from unittest import mock + +from dbt.cli.main import dbtRunnerResult + +from elementary.clients.dbt.api_dbt_runner import APIDbtRunner + + +def _make_result(success=True, exception=None): + return dbtRunnerResult( + success=success, + result=None, + exception=exception, + ) + + +def _make_runner(): + runner = APIDbtRunner.__new__(APIDbtRunner) + runner._manifest = None + runner.project_dir = "/tmp/fake" + runner.env_vars = None + runner.raise_on_failure = False + return runner + + +@contextmanager +def _noop_context(*args, **kwargs): + yield + + +_PATCH_CHDIR = mock.patch("elementary.clients.dbt.api_dbt_runner.with_chdir", _noop_context) +_PATCH_ENV = mock.patch("elementary.clients.dbt.api_dbt_runner.env_vars_context", _noop_context) + + +@_PATCH_ENV +@_PATCH_CHDIR +@mock.patch("elementary.clients.dbt.api_dbt_runner.dbtRunner") +def test_manifest_cached_after_first_success(mock_dbt_runner_cls): + fake_manifest = object() + mock_instance = mock.MagicMock() + mock_instance.invoke.return_value = _make_result(success=True) + mock_instance.manifest = fake_manifest + mock_dbt_runner_cls.return_value = mock_instance + + runner = _make_runner() + runner._inner_run_command(["run-operation", "foo"], quiet=True, log_output=False, log_format="json") + + assert runner._manifest is fake_manifest + mock_dbt_runner_cls.assert_called_once_with(manifest=None, callbacks=mock.ANY) + + +@_PATCH_ENV +@_PATCH_CHDIR +@mock.patch("elementary.clients.dbt.api_dbt_runner.dbtRunner") +def test_manifest_not_cached_on_failure(mock_dbt_runner_cls): + mock_instance = mock.MagicMock() + mock_instance.invoke.return_value = _make_result(success=False) + mock_instance.manifest = object() + mock_dbt_runner_cls.return_value = mock_instance + + runner = _make_runner() + runner._inner_run_command(["run-operation", "foo"], quiet=True, log_output=False, log_format="json") + + assert runner._manifest is None + + +@_PATCH_ENV +@_PATCH_CHDIR +@mock.patch("elementary.clients.dbt.api_dbt_runner.dbtRunner") +def test_cached_manifest_reused_on_subsequent_calls(mock_dbt_runner_cls): + fake_manifest = object() + mock_instance = mock.MagicMock() + mock_instance.invoke.return_value = _make_result(success=True) + mock_instance.manifest = fake_manifest + mock_dbt_runner_cls.return_value = mock_instance + + runner = _make_runner() + + runner._inner_run_command(["run-operation", "foo"], quiet=True, log_output=False, log_format="json") + assert runner._manifest is fake_manifest + + new_manifest = object() + mock_instance.manifest = new_manifest + mock_dbt_runner_cls.reset_mock() + + runner._inner_run_command(["run-operation", "bar"], quiet=True, log_output=False, log_format="json") + + mock_dbt_runner_cls.assert_called_once_with(manifest=fake_manifest, callbacks=mock.ANY) + assert runner._manifest is fake_manifest