"""Direct database query execution via dbt adapter connection.

Bypasses ``run_operation`` log-parsing entirely so that query results are
never lost due to intermittent log-capture issues in the CLI / fusion
runners.
"""

import json
import multiprocessing
import os
import re
from datetime import date, datetime, time
from decimal import Decimal
from pathlib import Path
from typing import Any, Dict, List, Optional

from dbt.adapters.base import BaseAdapter
from logger import get_logger

logger = get_logger(__name__)


class UnsupportedJinjaError(Exception):
    """Raised when a query contains Jinja expressions beyond ref()/source()."""

    def __init__(self, query: str) -> None:
        # Keep the offending query around so callers can log / fall back with it.
        self.query = query
        super().__init__(
            "Query contains Jinja expressions beyond {{ ref() }} / {{ source() }} "
            "which cannot be executed via the direct adapter path. "
            "Use the run_operation fallback instead."
        )


# Pattern that matches {{ ref('name') }} or {{ ref("name") }} with optional whitespace
_REF_PATTERN = re.compile(r"\{\{\s*ref\(\s*['\"]([^'\"]+)['\"]\s*\)\s*\}\}")

# Pattern that matches {{ source('source_name', 'table_name') }}
_SOURCE_PATTERN = re.compile(
    r"\{\{\s*source\(\s*['\"]([^'\"]+)['\"]\s*,\s*['\"]([^'\"]+)['\"]\s*\)\s*\}\}"
)

# Pattern that matches any Jinja expression {{ ... }}.
# DOTALL so expressions spanning multiple lines are still detected.
_JINJA_EXPR_PATTERN = re.compile(r"\{\{.*?\}\}", re.DOTALL)

# Pattern that matches Jinja statement blocks {% ... %} and comments {# ... #}.
# The direct adapter path can never execute these, so their presence must
# force the run_operation fallback (see has_non_ref_jinja).
_JINJA_STMT_PATTERN = re.compile(r"\{%.*?%\}|\{#.*?#\}", re.DOTALL)


def _serialize_value(val: Any) -> Any:
    """Mimic elementary's ``agate_to_dicts`` serialisation.

    * ``Decimal`` → ``int`` (no fractional part) or ``float``
    * ``datetime`` / ``date`` / ``time`` → ISO-format string
    * Everything else is returned unchanged.
    """
    if isinstance(val, Decimal):
        # Match the Jinja macro: normalize, then int or float.
        # normalize() may yield e.g. Decimal('1E+2'); a non-negative exponent
        # means there is no fractional part, so int() is lossless.
        normalized = val.normalize()
        if normalized.as_tuple().exponent >= 0:
            return int(normalized)
        return float(normalized)
    if isinstance(val, (datetime, date, time)):
        return val.isoformat()
    return val


class AdapterQueryRunner:
    """Execute SQL directly through a dbt adapter connection.

    Parameters
    ----------
    project_dir : str
        Path to the dbt project directory.
    target : str
        Name of the dbt target / profile output to use.
    """

    def __init__(self, project_dir: str, target: str) -> None:
        self._project_dir = project_dir
        self._target = target
        self._adapter: BaseAdapter = self._create_adapter(project_dir, target)
        # ref/source → relation-name maps, loaded lazily from the manifest
        # on first use (see _ensure_maps_loaded).
        self._ref_map: Optional[Dict[str, str]] = None
        self._source_map: Optional[Dict[tuple, str]] = None

    # ------------------------------------------------------------------
    # Adapter bootstrap
    # ------------------------------------------------------------------

    @staticmethod
    def _create_adapter(project_dir: str, target: str) -> BaseAdapter:
        """Build and register a dbt adapter for *project_dir* / *target*.

        Imports dbt lazily so importing this module stays cheap.
        """
        from argparse import Namespace

        from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters
        from dbt.config.runtime import RuntimeConfig
        from dbt.flags import set_from_args

        profiles_dir = os.environ.get("DBT_PROFILES_DIR", os.path.expanduser("~/.dbt"))
        # Both lower- and upper-case attribute spellings are provided because
        # different dbt code paths read different casings from the args object.
        args = Namespace(
            project_dir=project_dir,
            profiles_dir=profiles_dir,
            target=target,
            threads=1,
            vars={},
            profile=None,
            PROFILES_DIR=profiles_dir,
            PROJECT_DIR=project_dir,
        )
        set_from_args(args, None)
        config = RuntimeConfig.from_args(args)

        reset_adapters()
        mp_context = multiprocessing.get_context("spawn")
        register_adapter(config, mp_context)
        return get_adapter(config)

    # ------------------------------------------------------------------
    # Ref resolution
    # ------------------------------------------------------------------

    def _load_manifest_maps(self) -> None:
        """Load ref and source maps from the dbt manifest."""
        manifest_path = Path(self._project_dir) / "target" / "manifest.json"
        if not manifest_path.exists():
            raise FileNotFoundError(
                f"Manifest not found at {manifest_path}. "
                "Run `dbt run` or `dbt compile` first."
            )
        with open(manifest_path) as fh:
            manifest = json.load(fh)

        ref_map: Dict[str, str] = {}
        for node in manifest.get("nodes", {}).values():
            relation_name = node.get("relation_name")
            name = node.get("name")
            if relation_name and name:
                ref_map[name] = relation_name

        source_map: Dict[tuple, str] = {}
        for source in manifest.get("sources", {}).values():
            relation_name = source.get("relation_name")
            name = source.get("name")
            source_name = source.get("source_name")
            if relation_name and source_name and name:
                source_map[(source_name, name)] = relation_name
                # Also register source tables by name for simple ref() lookups;
                # setdefault so a real model with the same name wins.
                ref_map.setdefault(name, relation_name)

        self._ref_map = ref_map
        self._source_map = source_map

    def _ensure_maps_loaded(self) -> None:
        """Lazily load manifest maps on first use."""
        if self._ref_map is None:
            self._load_manifest_maps()

    def resolve_refs(self, query: str) -> str:
        """Replace ``{{ ref('name') }}`` and ``{{ source('x','y') }}`` with relation names."""
        self._ensure_maps_loaded()
        assert self._ref_map is not None
        assert self._source_map is not None

        def _replace_ref(match: re.Match) -> str:  # type: ignore[type-arg]
            name = match.group(1)
            if name not in self._ref_map:
                # Manifest may have changed (temp models/seeds); reload once.
                self._load_manifest_maps()
                assert self._ref_map is not None
                if name not in self._ref_map:
                    raise ValueError(
                        f"Cannot resolve ref('{name}'): not found in dbt manifest."
                    )
            return self._ref_map[name]

        def _replace_source(match: re.Match) -> str:  # type: ignore[type-arg]
            source_name, table_name = match.group(1), match.group(2)
            key = (source_name, table_name)
            if self._source_map is None or key not in self._source_map:
                # Same reload-once strategy as _replace_ref.
                self._load_manifest_maps()
                assert self._source_map is not None
                if key not in self._source_map:
                    raise ValueError(
                        f"Cannot resolve source('{source_name}', '{table_name}'): "
                        "not found in dbt manifest."
                    )
            return self._source_map[key]

        query = _REF_PATTERN.sub(_replace_ref, query)
        query = _SOURCE_PATTERN.sub(_replace_source, query)
        return query

    # ------------------------------------------------------------------
    # Query execution
    # ------------------------------------------------------------------

    @staticmethod
    def has_non_ref_jinja(query: str) -> bool:
        """Return True if *query* contains Jinja beyond ``{{ ref() }}`` / ``{{ source() }}``."""
        # BUG FIX: Jinja statement blocks {% ... %} and comments {# ... #}
        # previously slipped through this gate and were sent to the database
        # as raw text; now they force the run_operation fallback.
        if _JINJA_STMT_PATTERN.search(query):
            return True
        stripped = _REF_PATTERN.sub("", query)
        stripped = _SOURCE_PATTERN.sub("", stripped)
        return bool(_JINJA_EXPR_PATTERN.search(stripped))

    def run_query(self, prerendered_query: str) -> List[Dict[str, Any]]:
        """Render Jinja refs/sources and execute a query, returning rows as dicts.

        Column names are lower-cased and values are serialised to match the
        behaviour of ``elementary.agate_to_dicts``.

        Only ``{{ ref() }}`` and ``{{ source() }}`` Jinja expressions are
        supported. Raises ``UnsupportedJinjaError`` if the query contains
        other Jinja expressions.
        """
        if self.has_non_ref_jinja(prerendered_query):
            raise UnsupportedJinjaError(prerendered_query)
        sql = self.resolve_refs(prerendered_query)
        with self._adapter.connection_named("run_query"):
            _response, table = self._adapter.execute(sql, fetch=True)

        # Convert agate Table → list[dict] matching agate_to_dicts behaviour
        columns = [c.lower() for c in table.column_names]
        return [
            {col: _serialize_value(val) for col, val in zip(columns, row)}
            for row in table
        ]
# Retry settings for the run_operation fallback path. run_operation() can
# intermittently return an empty list when the MACRO_RESULT_PATTERN log line
# is not captured from dbt's output.
_RUN_QUERY_MAX_RETRIES = 3
_RUN_QUERY_RETRY_DELAY_SECONDS = 0.5


def _log_run_query_retry(retry_state: RetryCallState) -> None:
    """Tenacity ``before_sleep`` callback — logs each retry with attempt number.

    Module-level (rather than a staticmethod referenced via ``.__func__``)
    so the ``@retry`` decorator can use it directly.
    """
    logger.warning(
        "run_operation('elementary.render_run_query') returned no output; "
        "retry %d/%d in %.1fs",
        retry_state.attempt_number,
        _RUN_QUERY_MAX_RETRIES,
        _RUN_QUERY_RETRY_DELAY_SECONDS,
    )


# NOTE(review): the following are methods of the dbt project test-harness
# class; the class header is outside this hunk.

def _get_query_runner(self) -> AdapterQueryRunner:
    """Lazily initialize the direct adapter query runner."""
    if self._query_runner is None:
        self._query_runner = AdapterQueryRunner(
            str(self.project_dir_path), self.target
        )
    return self._query_runner


def run_query(self, prerendered_query: str):
    """Execute *prerendered_query* and return its rows.

    Fast path: queries that only contain {{ ref() }} / {{ source() }} are
    executed directly through the adapter, bypassing run_operation log
    parsing entirely. Anything with richer Jinja falls back to the
    run_operation path (with retry).
    """
    try:
        return self._get_query_runner().run_query(prerendered_query)
    except UnsupportedJinjaError:
        logger.debug("Query contains complex Jinja; falling back to run_operation")

    # Slow path: full Jinja rendering via run_operation (with retry).
    return self._run_query_with_run_operation(prerendered_query)


@retry(
    retry=retry_if_result(lambda r: r is None),
    stop=stop_after_attempt(_RUN_QUERY_MAX_RETRIES),
    wait=wait_fixed(_RUN_QUERY_RETRY_DELAY_SECONDS),
    before_sleep=_log_run_query_retry,
    # BUG FIX: was ``reraise=True``. On result-based exhaustion (every
    # attempt *returned* None, nothing raised) tenacity has no exception to
    # re-raise and raises RetryError instead of returning, so the caller's
    # ``result is None`` branch was unreachable. retry_error_callback makes
    # exhaustion yield None, as the caller expects. Exceptions raised by the
    # body still propagate immediately (retry_if_result does not retry them).
    retry_error_callback=lambda retry_state: None,
)
def _run_operation_with_retry(self, prerendered_query: str) -> Optional[list]:
    """Call run_operation and return the parsed result, or None to trigger retry."""
    run_operation_results = self.dbt_runner.run_operation(
        "elementary.render_run_query",
        macro_args={"prerendered_query": prerendered_query},
    )
    if run_operation_results:
        return json.loads(run_operation_results[0])
    return None


def _run_query_with_run_operation(self, prerendered_query: str):
    """Execute a query via run_operation with retry on empty output.

    run_operation() can intermittently return an empty list when the
    MACRO_RESULT_PATTERN log line is not captured from dbt's output.

    Raises RuntimeError when all attempts produced no output.
    """
    result = self._run_operation_with_retry(prerendered_query)
    if result is None:
        raise RuntimeError(
            f"run_operation('elementary.render_run_query') returned no output "
            f"after {_RUN_QUERY_MAX_RETRIES} attempts. "
            f"Query: {prerendered_query!r}"
        )
    return result