Skip to content

Commit c2536f5

Browse files
fix: use direct adapter connection for run_query (bypass log parsing) (#936)
1 parent f3ca4da commit c2536f5

3 files changed

Lines changed: 298 additions & 6 deletions

File tree

integration_tests/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ pytest-xdist
33
pytest-parametrization
44
pytest-html
55
filelock
6+
tenacity
67
# urllib3>=2.2.2 fixes CVE-2023-45803 and CVE-2024-37891
78
# Upper bound <3.0.0 prevents breaking changes from future major versions
89
urllib3>=2.2.2,<3.0.0
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
"""Direct database query execution via dbt adapter connection.
2+
3+
Bypasses ``run_operation`` log-parsing entirely so that query results are
4+
never lost due to intermittent log-capture issues in the CLI / fusion
5+
runners.
6+
"""
7+
8+
import json
9+
import multiprocessing
10+
import os
11+
import re
12+
from datetime import date, datetime, time
13+
from decimal import Decimal
14+
from pathlib import Path
15+
from typing import Any, Dict, List, Optional
16+
17+
from dbt.adapters.base import BaseAdapter
18+
from logger import get_logger
19+
20+
logger = get_logger(__name__)
21+
22+
23+
class UnsupportedJinjaError(Exception):
24+
"""Raised when a query contains Jinja expressions beyond ref()/source()."""
25+
26+
def __init__(self, query: str) -> None:
27+
self.query = query
28+
super().__init__(
29+
"Query contains Jinja expressions beyond {{ ref() }} / {{ source() }} "
30+
"which cannot be executed via the direct adapter path. "
31+
"Use the run_operation fallback instead."
32+
)
33+
34+
35+
# Pattern that matches {{ ref('name') }} or {{ ref("name") }} with optional whitespace
36+
_REF_PATTERN = re.compile(r"\{\{\s*ref\(\s*['\"]([^'\"]+)['\"]\s*\)\s*\}\}")
37+
38+
# Pattern that matches {{ source('source_name', 'table_name') }}
39+
_SOURCE_PATTERN = re.compile(
40+
r"\{\{\s*source\(\s*['\"]([^'\"]+)['\"]\s*,\s*['\"]([^'\"]+)['\"]\s*\)\s*\}\}"
41+
)
42+
43+
# Pattern that matches any Jinja expression {{ ... }}
44+
_JINJA_EXPR_PATTERN = re.compile(r"\{\{.*?\}\}")
45+
46+
47+
def _serialize_value(val: Any) -> Any:
48+
"""Mimic elementary's ``agate_to_dicts`` serialisation.
49+
50+
* ``Decimal`` → ``int`` (no fractional part) or ``float``
51+
* ``datetime`` / ``date`` / ``time`` → ISO-format string
52+
* Everything else is returned unchanged.
53+
"""
54+
if isinstance(val, Decimal):
55+
# Match the Jinja macro: normalize, then int or float
56+
normalized = val.normalize()
57+
if normalized.as_tuple().exponent >= 0:
58+
return int(normalized)
59+
return float(normalized)
60+
if isinstance(val, (datetime, date, time)):
61+
return val.isoformat()
62+
return val
63+
64+
65+
class AdapterQueryRunner:
66+
"""Execute SQL directly through a dbt adapter connection.
67+
68+
Parameters
69+
----------
70+
project_dir : str
71+
Path to the dbt project directory.
72+
target : str
73+
Name of the dbt target / profile output to use.
74+
"""
75+
76+
def __init__(self, project_dir: str, target: str) -> None:
77+
self._project_dir = project_dir
78+
self._target = target
79+
self._adapter: BaseAdapter = self._create_adapter(project_dir, target)
80+
self._ref_map: Optional[Dict[str, str]] = None
81+
self._source_map: Optional[Dict[tuple, str]] = None
82+
83+
# ------------------------------------------------------------------
84+
# Adapter bootstrap
85+
# ------------------------------------------------------------------
86+
87+
@staticmethod
88+
def _create_adapter(project_dir: str, target: str) -> BaseAdapter:
89+
from argparse import Namespace
90+
91+
from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters
92+
from dbt.config.runtime import RuntimeConfig
93+
from dbt.flags import set_from_args
94+
95+
profiles_dir = os.environ.get("DBT_PROFILES_DIR", os.path.expanduser("~/.dbt"))
96+
args = Namespace(
97+
project_dir=project_dir,
98+
profiles_dir=profiles_dir,
99+
target=target,
100+
threads=1,
101+
vars={},
102+
profile=None,
103+
PROFILES_DIR=profiles_dir,
104+
PROJECT_DIR=project_dir,
105+
)
106+
set_from_args(args, None)
107+
config = RuntimeConfig.from_args(args)
108+
109+
reset_adapters()
110+
mp_context = multiprocessing.get_context("spawn")
111+
register_adapter(config, mp_context)
112+
return get_adapter(config)
113+
114+
# ------------------------------------------------------------------
115+
# Ref resolution
116+
# ------------------------------------------------------------------
117+
118+
def _load_manifest_maps(self) -> None:
119+
"""Load ref and source maps from the dbt manifest."""
120+
manifest_path = Path(self._project_dir) / "target" / "manifest.json"
121+
if not manifest_path.exists():
122+
raise FileNotFoundError(
123+
f"Manifest not found at {manifest_path}. "
124+
"Run `dbt run` or `dbt compile` first."
125+
)
126+
with open(manifest_path) as fh:
127+
manifest = json.load(fh)
128+
129+
ref_map: Dict[str, str] = {}
130+
for node in manifest.get("nodes", {}).values():
131+
relation_name = node.get("relation_name")
132+
name = node.get("name")
133+
if relation_name and name:
134+
ref_map[name] = relation_name
135+
136+
source_map: Dict[tuple, str] = {}
137+
for source in manifest.get("sources", {}).values():
138+
relation_name = source.get("relation_name")
139+
name = source.get("name")
140+
source_name = source.get("source_name")
141+
if relation_name and source_name and name:
142+
source_map[(source_name, name)] = relation_name
143+
# Also register source tables by name for simple ref() lookups
144+
ref_map.setdefault(name, relation_name)
145+
146+
self._ref_map = ref_map
147+
self._source_map = source_map
148+
149+
def _ensure_maps_loaded(self) -> None:
150+
"""Lazily load manifest maps on first use."""
151+
if self._ref_map is None:
152+
self._load_manifest_maps()
153+
154+
def resolve_refs(self, query: str) -> str:
155+
"""Replace ``{{ ref('name') }}`` and ``{{ source('x','y') }}`` with relation names."""
156+
self._ensure_maps_loaded()
157+
assert self._ref_map is not None
158+
assert self._source_map is not None
159+
160+
def _replace_ref(match: re.Match) -> str: # type: ignore[type-arg]
161+
name = match.group(1)
162+
if name not in self._ref_map:
163+
# Manifest may have changed (temp models/seeds); reload once.
164+
self._load_manifest_maps()
165+
assert self._ref_map is not None
166+
if name not in self._ref_map:
167+
raise ValueError(
168+
f"Cannot resolve ref('{name}'): not found in dbt manifest."
169+
)
170+
return self._ref_map[name]
171+
172+
def _replace_source(match: re.Match) -> str: # type: ignore[type-arg]
173+
source_name, table_name = match.group(1), match.group(2)
174+
key = (source_name, table_name)
175+
if self._source_map is None or key not in self._source_map:
176+
self._load_manifest_maps()
177+
assert self._source_map is not None
178+
if key not in self._source_map:
179+
raise ValueError(
180+
f"Cannot resolve source('{source_name}', '{table_name}'): "
181+
"not found in dbt manifest."
182+
)
183+
return self._source_map[key]
184+
185+
query = _REF_PATTERN.sub(_replace_ref, query)
186+
query = _SOURCE_PATTERN.sub(_replace_source, query)
187+
return query
188+
189+
# ------------------------------------------------------------------
190+
# Query execution
191+
# ------------------------------------------------------------------
192+
193+
@staticmethod
194+
def has_non_ref_jinja(query: str) -> bool:
195+
"""Return True if *query* contains Jinja beyond ``{{ ref() }}`` / ``{{ source() }}``."""
196+
stripped = _REF_PATTERN.sub("", query)
197+
stripped = _SOURCE_PATTERN.sub("", stripped)
198+
return bool(_JINJA_EXPR_PATTERN.search(stripped))
199+
200+
def run_query(self, prerendered_query: str) -> List[Dict[str, Any]]:
201+
"""Render Jinja refs/sources and execute a query, returning rows as dicts.
202+
203+
Column names are lower-cased and values are serialised to match the
204+
behaviour of ``elementary.agate_to_dicts``.
205+
206+
Only ``{{ ref() }}`` and ``{{ source() }}`` Jinja expressions are
207+
supported. Raises ``UnsupportedJinjaError`` if the query contains
208+
other Jinja expressions.
209+
"""
210+
if self.has_non_ref_jinja(prerendered_query):
211+
raise UnsupportedJinjaError(prerendered_query)
212+
sql = self.resolve_refs(prerendered_query)
213+
with self._adapter.connection_named("run_query"):
214+
_response, table = self._adapter.execute(sql, fetch=True)
215+
216+
# Convert agate Table → list[dict] matching agate_to_dicts behaviour
217+
columns = [c.lower() for c in table.column_names]
218+
return [
219+
{col: _serialize_value(val) for col, val in zip(columns, row)}
220+
for row in table
221+
]

integration_tests/tests/dbt_project.py

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,30 @@
66
from typing import Any, Dict, Generator, List, Literal, Optional, Union, overload
77
from uuid import uuid4
88

9+
from adapter_query_runner import AdapterQueryRunner, UnsupportedJinjaError
910
from data_seeder import DbtDataSeeder
1011
from dbt_utils import get_database_and_schema_properties
1112
from elementary.clients.dbt.base_dbt_runner import BaseDbtRunner
1213
from elementary.clients.dbt.factory import RunnerMethod, create_dbt_runner
1314
from logger import get_logger
1415
from ruamel.yaml import YAML
16+
from tenacity import (
17+
RetryCallState,
18+
retry,
19+
retry_if_result,
20+
stop_after_attempt,
21+
wait_fixed,
22+
)
1523

1624
PYTEST_XDIST_WORKER = os.environ.get("PYTEST_XDIST_WORKER", None)
1725
SCHEMA_NAME_SUFFIX = f"_{PYTEST_XDIST_WORKER}" if PYTEST_XDIST_WORKER else ""
1826

27+
# Retry settings for the run_operation fallback path. run_operation() can
28+
# intermittently return an empty list when the MACRO_RESULT_PATTERN log line
29+
# is not captured from dbt's output.
30+
_RUN_QUERY_MAX_RETRIES = 3
31+
_RUN_QUERY_RETRY_DELAY_SECONDS = 0.5
32+
1933
_DEFAULT_VARS = {
2034
"disable_dbt_invocation_autoupload": True,
2135
"disable_dbt_artifacts_autoupload": True,
@@ -59,14 +73,70 @@ def __init__(
5973
self.tmp_models_dir_path = self.models_dir_path / "tmp"
6074
self.seeds_dir_path = self.project_dir_path / "data"
6175

76+
self._query_runner: Optional[AdapterQueryRunner] = None
77+
78+
def _get_query_runner(self) -> AdapterQueryRunner:
79+
"""Lazily initialize the direct adapter query runner."""
80+
if self._query_runner is None:
81+
self._query_runner = AdapterQueryRunner(
82+
str(self.project_dir_path), self.target
83+
)
84+
return self._query_runner
85+
6286
def run_query(self, prerendered_query: str):
63-
results = json.loads(
64-
self.dbt_runner.run_operation(
65-
"elementary.render_run_query",
66-
macro_args={"prerendered_query": prerendered_query},
67-
)[0]
87+
# Fast path: queries that only contain {{ ref() }} / {{ source() }}
88+
# can be executed directly through the adapter, bypassing
89+
# run_operation log parsing entirely.
90+
try:
91+
return self._get_query_runner().run_query(prerendered_query)
92+
except UnsupportedJinjaError:
93+
logger.debug("Query contains complex Jinja; falling back to run_operation")
94+
95+
# Slow path: full Jinja rendering via run_operation (with retry).
96+
return self._run_query_with_run_operation(prerendered_query)
97+
98+
@staticmethod
99+
def _log_retry(retry_state: RetryCallState) -> None:
100+
"""Tenacity before_sleep callback — logs each retry with attempt number."""
101+
logger.warning(
102+
"run_operation('elementary.render_run_query') returned no output; "
103+
"retry %d/%d in %.1fs",
104+
retry_state.attempt_number,
105+
_RUN_QUERY_MAX_RETRIES,
106+
_RUN_QUERY_RETRY_DELAY_SECONDS,
68107
)
69-
return results
108+
109+
@retry(
110+
retry=retry_if_result(lambda r: r is None),
111+
stop=stop_after_attempt(_RUN_QUERY_MAX_RETRIES),
112+
wait=wait_fixed(_RUN_QUERY_RETRY_DELAY_SECONDS),
113+
before_sleep=_log_retry.__func__,
114+
reraise=True,
115+
)
116+
def _run_operation_with_retry(self, prerendered_query: str) -> Optional[list]:
117+
"""Call run_operation and return the parsed result, or None to trigger retry."""
118+
run_operation_results = self.dbt_runner.run_operation(
119+
"elementary.render_run_query",
120+
macro_args={"prerendered_query": prerendered_query},
121+
)
122+
if run_operation_results:
123+
return json.loads(run_operation_results[0])
124+
return None
125+
126+
def _run_query_with_run_operation(self, prerendered_query: str):
127+
"""Execute a query via run_operation with retry on empty output.
128+
129+
run_operation() can intermittently return an empty list when the
130+
MACRO_RESULT_PATTERN log line is not captured from dbt's output.
131+
"""
132+
result = self._run_operation_with_retry(prerendered_query)
133+
if result is None:
134+
raise RuntimeError(
135+
f"run_operation('elementary.render_run_query') returned no output "
136+
f"after {_RUN_QUERY_MAX_RETRIES} attempts. "
137+
f"Query: {prerendered_query!r}"
138+
)
139+
return result
70140

71141
@staticmethod
72142
def read_table_query(

0 commit comments

Comments
 (0)