diff --git a/elementary/monitor/api/report/report.py b/elementary/monitor/api/report/report.py index 77850333e..802679fd5 100644 --- a/elementary/monitor/api/report/report.py +++ b/elementary/monitor/api/report/report.py @@ -1,7 +1,9 @@ from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor from typing import Dict, Iterable, List, Optional, Tuple, Union from elementary.clients.api.api_client import APIClient +from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner from elementary.monitor.api.filters.filters import FiltersAPI from elementary.monitor.api.groups.groups import GroupsAPI from elementary.monitor.api.groups.schema import GroupsSchema @@ -38,8 +40,11 @@ from elementary.monitor.api.totals_schema import TotalsSchema from elementary.monitor.data_monitoring.schema import SelectorFilterSchema from elementary.monitor.fetchers.tests.schema import NormalizedTestSchema +from elementary.utils.log import get_logger from elementary.utils.time import get_now_utc_iso_format +logger = get_logger(__name__) + class ReportAPI(APIClient): def _get_groups( @@ -68,6 +73,27 @@ def _get_exposures( ) -> Dict[str, NormalizedExposureSchema]: return models_api.get_exposures(upstream_node_ids=upstream_node_ids) + def _create_subprocess_runner(self) -> SubprocessDbtRunner: + """Create a SubprocessDbtRunner for thread-safe parallel execution. + + dbt's Python API (APIDbtRunner) is not thread-safe due to global + mutable state (GLOBAL_FLAGS, adapter FACTORY, etc.). + SubprocessDbtRunner spawns an independent dbt process per call, + making it safe to use from multiple threads. + """ + runner = self.dbt_runner + return SubprocessDbtRunner( + project_dir=runner.project_dir, + profiles_dir=runner.profiles_dir, + target=runner.target, + raise_on_failure=runner.raise_on_failure, + env_vars=getattr(runner, "env_vars", None), + vars=runner.vars, + secret_vars=runner.secret_vars, + allow_macros_without_package_prefix=runner.allow_macros_without_package_prefix, + run_deps_if_needed=False, + ) + def get_report_data( self, days_back: int = 7, @@ -79,6 +105,44 @@ def get_report_data( filter: SelectorFilterSchema = SelectorFilterSchema(), env: Optional[str] = None, warehouse_type: Optional[str] = None, + threads: int = 1, + ) -> Tuple[ReportDataSchema, Optional[Exception]]: + if threads > 1: + return self._get_report_data_parallel( + days_back=days_back, + test_runs_amount=test_runs_amount, + disable_passed_test_metrics=disable_passed_test_metrics, + exclude_elementary_models=exclude_elementary_models, + project_name=project_name, + disable_samples=disable_samples, + filter=filter, + env=env, + warehouse_type=warehouse_type, + threads=threads, + ) + return self._get_report_data_sequential( + days_back=days_back, + test_runs_amount=test_runs_amount, + disable_passed_test_metrics=disable_passed_test_metrics, + exclude_elementary_models=exclude_elementary_models, + project_name=project_name, + disable_samples=disable_samples, + filter=filter, + env=env, + warehouse_type=warehouse_type, + ) + + def _get_report_data_sequential( + self, + days_back: int = 7, + test_runs_amount: int = 720, + disable_passed_test_metrics: bool = False, + exclude_elementary_models: bool = False, + project_name: Optional[str] = None, + disable_samples: bool = False, + filter: SelectorFilterSchema = SelectorFilterSchema(), + env: Optional[str] = None, + warehouse_type: Optional[str] = None, ) -> Tuple[ReportDataSchema, Optional[Exception]]: try: tests_api = TestsAPI( @@ -112,15 +176,6 @@ def get_report_data( lineage_node_ids.extend(exposures.keys()) singular_tests = tests_api.get_singular_tests() - groups = self._get_groups( - models.values(), - sources.values(), - exposures.values(), - seeds.values(), - snapshots.values(), - singular_tests, - ) - models_runs = models_api.get_models_runs( days_back=days_back, exclude_elementary_models=exclude_elementary_models ) @@ -136,98 +191,278 @@ def get_report_data( source_freshness_results = ( source_freshnesses_api.get_source_freshness_results() ) - - union_test_results = { - x: test_results.get(x, []) + source_freshness_results.get(x, []) - for x in set(test_results).union(source_freshness_results) - } - - test_results_totals = get_total_test_results(union_test_results) - test_runs = tests_api.get_test_runs() source_freshness_runs = source_freshnesses_api.get_source_freshness_runs() - union_test_runs = dict() - for key in set(test_runs).union(source_freshness_runs): - test_run = test_runs.get(key, []) - source_freshness_run = ( - source_freshness_runs.get(key, []) if key is not None else [] - ) - union_test_runs[key] = test_run + source_freshness_run - - test_runs_totals = get_total_test_runs(union_test_runs) - lineage = lineage_api.get_lineage( lineage_node_ids, exclude_elementary_models ) - filters = filters_api.get_filters( - test_results_totals, - test_runs_totals, - models, - sources, - models_runs.runs, - seeds, - snapshots, + + return self._assemble_report_data( + days_back=days_back, + project_name=project_name, + env=env, + warehouse_type=warehouse_type, + seeds=seeds, + snapshots=snapshots, + models=models, + sources=sources, + exposures=exposures, + singular_tests=singular_tests, + models_runs=models_runs, + coverages=coverages, + tests=tests, + test_invocation=test_invocation, + test_results=test_results, + source_freshness_results=source_freshness_results, + test_runs=test_runs, + source_freshness_runs=source_freshness_runs, + lineage=lineage, + filters_api=filters_api, + models_latest_invocation=invocations_api.get_models_latest_invocation(), + invocations_data=invocations_api.get_models_latest_invocations_data(), ) + except Exception as error: + return ReportDataSchema(), error - serializable_groups = groups.dict() - serializable_models = self._serialize_models( - models, sources, exposures, seeds, snapshots + def _get_report_data_parallel( + self, + days_back: int = 7, + test_runs_amount: int = 720, + disable_passed_test_metrics: bool = False, + exclude_elementary_models: bool = False, + project_name: Optional[str] = None, + disable_samples: bool = False, + filter: SelectorFilterSchema = SelectorFilterSchema(), + env: Optional[str] = None, + warehouse_type: Optional[str] = None, + threads: int = 4, + ) -> Tuple[ReportDataSchema, Optional[Exception]]: + try: + parallel_runner = self._create_subprocess_runner() + logger.info( + "Fetching report data in parallel with %d threads", threads ) - serializable_model_runs = self._serialize_models_runs(models_runs.runs) - serializable_model_runs_totals = models_runs.dict(include={"totals"})[ - "totals" - ] - serializable_models_coverages = self._serialize_coverages(coverages) - serializable_tests = self._serialize_tests(tests) - serializable_test_results = self._serialize_test_results(union_test_results) - serializable_test_results_totals = self._serialize_totals( - test_results_totals + + def _new_models_api() -> ModelsAPI: + return ModelsAPI(dbt_runner=parallel_runner) + + def _new_tests_api() -> TestsAPI: + return TestsAPI( + dbt_runner=parallel_runner, + days_back=days_back, + invocations_per_test=test_runs_amount, + disable_passed_test_metrics=disable_passed_test_metrics, + ) + + def _new_freshness_api() -> SourceFreshnessesAPI: + return SourceFreshnessesAPI( + dbt_runner=parallel_runner, + days_back=days_back, + invocations_per_test=test_runs_amount, + ) + + def _new_invocations_api() -> InvocationsAPI: + return InvocationsAPI(dbt_runner=parallel_runner) + + # Phase 1: fetch all independent data in parallel + with ThreadPoolExecutor(max_workers=threads) as pool: + f_seeds = pool.submit(_new_models_api().get_seeds) + f_snapshots = pool.submit(_new_models_api().get_snapshots) + f_models = pool.submit( + _new_models_api().get_models, exclude_elementary_models + ) + f_sources = pool.submit(_new_models_api().get_sources) + f_singular_tests = pool.submit(_new_tests_api().get_singular_tests) + f_models_runs = pool.submit( + _new_models_api().get_models_runs, + days_back, + exclude_elementary_models, + ) + f_coverages = pool.submit(_new_models_api().get_test_coverages) + f_tests = pool.submit(_new_tests_api().get_tests) + f_test_invocation = pool.submit( + _new_invocations_api().get_test_invocation_from_filter, filter + ) + f_freshness_results = pool.submit( + _new_freshness_api().get_source_freshness_results + ) + f_test_runs = pool.submit(_new_tests_api().get_test_runs) + f_freshness_runs = pool.submit( + _new_freshness_api().get_source_freshness_runs + ) + f_latest_invocation = pool.submit( + _new_invocations_api().get_models_latest_invocation + ) + f_invocations_data = pool.submit( + _new_invocations_api().get_models_latest_invocations_data + ) + + seeds = f_seeds.result() + snapshots = f_snapshots.result() + models = f_models.result() + sources = f_sources.result() + singular_tests = f_singular_tests.result() + models_runs = f_models_runs.result() + coverages = f_coverages.result() + tests = f_tests.result() + test_invocation = f_test_invocation.result() + source_freshness_results = f_freshness_results.result() + test_runs = f_test_runs.result() + source_freshness_runs = f_freshness_runs.result() + models_latest_invocation = f_latest_invocation.result() + invocations_data = f_invocations_data.result() + + # Phase 2: fetch data that depends on Phase 1 results + lineage_node_ids: List[str] = ( + list(seeds.keys()) + + list(snapshots.keys()) + + list(models.keys()) + + list(sources.keys()) ) - serializable_test_runs = self._serialize_test_runs(union_test_runs) - serializable_test_runs_totals = self._serialize_totals(test_runs_totals) - serializable_invocation = test_invocation.dict() - serializable_filters = filters.dict() - serializable_lineage = lineage.dict() - - models_latest_invocation = invocations_api.get_models_latest_invocation() - invocations = invocations_api.get_models_latest_invocations_data() - - invocations_job_identification = defaultdict(list) - for invocation in invocations: - invocation_key = invocation.job_name or invocation.job_id - if invocation_key is not None: - invocations_job_identification[invocation_key].append( - invocation.invocation_id - ) - - report_data = ReportDataSchema( - creation_time=get_now_utc_iso_format(), + + with ThreadPoolExecutor(max_workers=threads) as pool: + f_exposures = pool.submit( + _new_models_api().get_exposures, + upstream_node_ids=lineage_node_ids, + ) + f_test_results = pool.submit( + _new_tests_api().get_test_results, + test_invocation.invocation_id, + disable_samples, + ) + + exposures = f_exposures.result() + test_results = f_test_results.result() + lineage_node_ids.extend(exposures.keys()) + + # Phase 3: lineage depends on all node IDs + lineage = LineageAPI(dbt_runner=parallel_runner).get_lineage( + lineage_node_ids, exclude_elementary_models + ) + + # Phase 4: pure computation (no dbt calls) + return self._assemble_report_data( days_back=days_back, - models=serializable_models, - groups=serializable_groups, - tests=serializable_tests, - invocation=serializable_invocation, - test_results=serializable_test_results, - test_results_totals=serializable_test_results_totals, - test_runs=serializable_test_runs, - test_runs_totals=serializable_test_runs_totals, - coverages=serializable_models_coverages, - model_runs=serializable_model_runs, - model_runs_totals=serializable_model_runs_totals, - filters=serializable_filters, - lineage=serializable_lineage, - invocations=invocations, - resources_latest_invocation=models_latest_invocation, - invocations_job_identification=invocations_job_identification, - env=ReportDataEnvSchema( - project_name=project_name, env=env, warehouse_type=warehouse_type - ), + project_name=project_name, + env=env, + warehouse_type=warehouse_type, + seeds=seeds, + snapshots=snapshots, + models=models, + sources=sources, + exposures=exposures, + singular_tests=singular_tests, + models_runs=models_runs, + coverages=coverages, + tests=tests, + test_invocation=test_invocation, + test_results=test_results, + source_freshness_results=source_freshness_results, + test_runs=test_runs, + source_freshness_runs=source_freshness_runs, + lineage=lineage, + filters_api=FiltersAPI(dbt_runner=parallel_runner), + models_latest_invocation=models_latest_invocation, + invocations_data=invocations_data, ) - return report_data, None except Exception as error: return ReportDataSchema(), error + def _assemble_report_data( + self, + days_back, + project_name, + env, + warehouse_type, + seeds, + snapshots, + models, + sources, + exposures, + singular_tests, + models_runs, + coverages, + tests, + test_invocation, + test_results, + source_freshness_results, + test_runs, + source_freshness_runs, + lineage, + filters_api, + models_latest_invocation, + invocations_data, + ) -> Tuple[ReportDataSchema, Optional[Exception]]: + groups = self._get_groups( + models.values(), + sources.values(), + exposures.values(), + seeds.values(), + snapshots.values(), + singular_tests, + ) + + union_test_results = { + x: test_results.get(x, []) + source_freshness_results.get(x, []) + for x in set(test_results).union(source_freshness_results) + } + test_results_totals = get_total_test_results(union_test_results) + + union_test_runs = dict() + for key in set(test_runs).union(source_freshness_runs): + test_run = test_runs.get(key, []) + source_freshness_run = ( + source_freshness_runs.get(key, []) if key is not None else [] + ) + union_test_runs[key] = test_run + source_freshness_run + test_runs_totals = get_total_test_runs(union_test_runs) + + filters = filters_api.get_filters( + test_results_totals, + test_runs_totals, + models, + sources, + models_runs.runs, + seeds, + snapshots, + ) + + invocations_job_identification = defaultdict(list) + for invocation in invocations_data: + invocation_key = invocation.job_name or invocation.job_id + if invocation_key is not None: + invocations_job_identification[invocation_key].append( + invocation.invocation_id + ) + + report_data = ReportDataSchema( + creation_time=get_now_utc_iso_format(), + days_back=days_back, + models=self._serialize_models( + models, sources, exposures, seeds, snapshots + ), + groups=groups.dict(), + tests=self._serialize_tests(tests), + invocation=test_invocation.dict(), + test_results=self._serialize_test_results(union_test_results), + test_results_totals=self._serialize_totals(test_results_totals), + test_runs=self._serialize_test_runs(union_test_runs), + test_runs_totals=self._serialize_totals(test_runs_totals), + coverages=self._serialize_coverages(coverages), + model_runs=self._serialize_models_runs(models_runs.runs), + model_runs_totals=models_runs.dict(include={"totals"})["totals"], + filters=filters.dict(), + lineage=lineage.dict(), + invocations=invocations_data, + resources_latest_invocation=models_latest_invocation, + invocations_job_identification=invocations_job_identification, + env=ReportDataEnvSchema( + project_name=project_name, env=env, warehouse_type=warehouse_type + ), + ) + return report_data, None + def _serialize_models( self, models: Dict[str, NormalizedModelSchema], diff --git a/elementary/monitor/cli.py b/elementary/monitor/cli.py index c86b4b992..dbbdd4416 100644 --- a/elementary/monitor/cli.py +++ b/elementary/monitor/cli.py @@ -449,6 +449,13 @@ def monitor( default=True, help="Whether to open the report in the browser.", ) +@click.option( + "--threads", + type=click.IntRange(min=1), + default=1, + help="Number of threads for fetching report data in parallel. " + "When set to >1, independent dbt operations run concurrently using subprocess-based runners.", +) @click.pass_context def report( ctx, @@ -464,6 +471,7 @@ def report( file_path, disable_passed_test_metrics, open_browser, + threads, exclude_elementary_models, disable_samples, project_name, @@ -511,6 +519,7 @@ def report( exclude_elementary_models=exclude_elementary_models, should_open_browser=open_browser, project_name=project_name, + threads=threads, ) anonymous_tracking.track_cli_end( Command.REPORT, data_monitoring.properties(), ctx.command.name @@ -660,6 +669,13 @@ def report( default=None, help="Include additional information at the test results summary message.\nCurrently only --include descriptions is supported.", ) +@click.option( + "--threads", + type=click.IntRange(min=1), + default=1, + help="Number of threads for fetching report data in parallel. " + "When set to >1, independent dbt operations run concurrently using subprocess-based runners.", +) @click.pass_context def send_report( ctx, @@ -701,6 +717,7 @@ def send_report( select, disable, include, + threads, target_path, quiet_logs, ssl_ca_bundle, @@ -784,6 +801,7 @@ def send_report( remote_file_path=bucket_file_path, disable_html_attachment=(disable == "html_attachment"), include_description=(include == "description"), + threads=threads, ) anonymous_tracking.track_cli_end( diff --git a/elementary/monitor/data_monitoring/report/data_monitoring_report.py b/elementary/monitor/data_monitoring/report/data_monitoring_report.py index 7493b96e7..005b41e92 100644 --- a/elementary/monitor/data_monitoring/report/data_monitoring_report.py +++ b/elementary/monitor/data_monitoring/report/data_monitoring_report.py @@ -61,6 +61,7 @@ def generate_report( should_open_browser: bool = True, exclude_elementary_models: bool = False, project_name: Optional[str] = None, + threads: int = 1, ) -> Tuple[bool, str]: html_path = self._get_report_file_path(file_path) output_data = self.get_report_data( @@ -69,6 +70,7 @@ def generate_report( disable_passed_test_metrics=disable_passed_test_metrics, exclude_elementary_models=exclude_elementary_models, project_name=project_name, + threads=threads, ) template_html_path = os.path.join(os.path.dirname(__file__), "index.html") @@ -110,6 +112,7 @@ def get_report_data( disable_passed_test_metrics: bool = False, exclude_elementary_models: bool = False, project_name: Optional[str] = None, + threads: int = 1, ): report_api = ReportAPI(self.internal_dbt_runner) report_data, error = report_api.get_report_data( @@ -122,6 +125,7 @@ def get_report_data( filter=self.selector_filter.to_selector_filter_schema(), env=self.config.env, warehouse_type=self.warehouse_info.type if self.warehouse_info else None, + threads=threads, ) self._add_report_tracking(report_data, error) if error: @@ -182,6 +186,7 @@ def send_report( remote_file_path: Optional[str] = None, disable_html_attachment: bool = False, include_description: bool = False, + threads: int = 1, ): # Generate the report generated_report_successfully, local_html_path = self.generate_report( @@ -192,6 +197,7 @@ def send_report( should_open_browser=should_open_browser, exclude_elementary_models=exclude_elementary_models, project_name=project_name, + threads=threads, ) if not generated_report_successfully: diff --git a/tests/unit/monitor/api/report/__init__.py b/tests/unit/monitor/api/report/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/monitor/api/report/test_report_parallel.py b/tests/unit/monitor/api/report/test_report_parallel.py new file mode 100644 index 000000000..9f5ecec8b --- /dev/null +++ b/tests/unit/monitor/api/report/test_report_parallel.py @@ -0,0 +1,125 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from elementary.monitor.api.report.report import ReportAPI + + +@pytest.fixture +def mock_dbt_runner(): + runner = MagicMock() + runner.project_dir = "/tmp/project" + runner.profiles_dir = "/tmp/profiles" + runner.target = "dev" + runner.raise_on_failure = True + runner.env_vars = {"KEY": "value"} + runner.vars = {} + runner.secret_vars = {} + runner.allow_macros_without_package_prefix = False + return runner + + +class TestCreateSubprocessRunner: + def test_creates_runner_with_correct_config(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with patch( + "elementary.monitor.api.report.report.SubprocessDbtRunner" + ) as mock_cls: + api._create_subprocess_runner() + mock_cls.assert_called_once_with( + project_dir="/tmp/project", + profiles_dir="/tmp/profiles", + target="dev", + raise_on_failure=True, + env_vars={"KEY": "value"}, + vars={}, + secret_vars={}, + allow_macros_without_package_prefix=False, + run_deps_if_needed=False, + ) + + def test_deps_not_run(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with patch( + "elementary.monitor.api.report.report.SubprocessDbtRunner" + ) as mock_cls: + api._create_subprocess_runner() + call_kwargs = mock_cls.call_args[1] + assert call_kwargs["run_deps_if_needed"] is False + + +class TestGetReportDataRouting: + def test_threads_1_uses_sequential(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with patch.object(api, "_get_report_data_sequential") as mock_seq: + mock_seq.return_value = (MagicMock(), None) + api.get_report_data(threads=1) + mock_seq.assert_called_once() + + def test_threads_gt1_uses_parallel(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with patch.object(api, "_get_report_data_parallel") as mock_par: + mock_par.return_value = (MagicMock(), None) + api.get_report_data(threads=4) + mock_par.assert_called_once() + + def test_threads_passed_to_parallel(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with patch.object(api, "_get_report_data_parallel") as mock_par: + mock_par.return_value = (MagicMock(), None) + api.get_report_data(threads=8) + call_kwargs = mock_par.call_args[1] + assert call_kwargs["threads"] == 8 + + +class TestGetReportDataParallel: + def test_uses_thread_pool_executor(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with ( + patch.object(api, "_create_subprocess_runner") as mock_create, + patch( + "elementary.monitor.api.report.report.ThreadPoolExecutor" + ) as mock_pool_cls, + patch( + "elementary.monitor.api.report.report.ModelsAPI" + ), + patch( + "elementary.monitor.api.report.report.TestsAPI" + ), + patch( + "elementary.monitor.api.report.report.SourceFreshnessesAPI" + ), + patch( + "elementary.monitor.api.report.report.InvocationsAPI" + ), + patch( + "elementary.monitor.api.report.report.LineageAPI" + ), + patch( + "elementary.monitor.api.report.report.FiltersAPI" + ), + patch.object(api, "_assemble_report_data") as mock_assemble, + ): + mock_create.return_value = MagicMock() + mock_pool = MagicMock() + mock_pool_cls.return_value.__enter__ = MagicMock(return_value=mock_pool) + mock_pool_cls.return_value.__exit__ = MagicMock(return_value=False) + mock_pool.submit.return_value.result.return_value = MagicMock( + invocation_id="inv-1" + ) + mock_assemble.return_value = (MagicMock(), None) + + _, err = api._get_report_data_parallel(threads=4) + + mock_pool_cls.assert_called_with(max_workers=4) + assert err is None + mock_assemble.assert_called_once() + + def test_error_propagation(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + error = RuntimeError("test error") + with patch.object( + api, "_create_subprocess_runner", side_effect=error + ): + _, err = api._get_report_data_parallel(threads=4) + assert err is error