-
Notifications
You must be signed in to change notification settings - Fork 137
Expand file tree
/
Copy pathadapter_query_runner.py
More file actions
221 lines (182 loc) · 8.31 KB
/
adapter_query_runner.py
File metadata and controls
221 lines (182 loc) · 8.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""Direct database query execution via dbt adapter connection.
Bypasses ``run_operation`` log-parsing entirely so that query results are
never lost due to intermittent log-capture issues in the CLI / fusion
runners.
"""
import json
import multiprocessing
import os
import re
from datetime import date, datetime, time
from decimal import Decimal
from pathlib import Path
from typing import Any, Dict, List, Optional
from dbt.adapters.base import BaseAdapter
from logger import get_logger
# Module-level logger, obtained from the project's ``get_logger`` helper and
# keyed by this module's ``__name__``.
logger = get_logger(__name__)
class UnsupportedJinjaError(Exception):
    """Signal that a query uses Jinja other than ``ref()`` / ``source()``.

    The offending query text is kept on the ``query`` attribute; the
    exception message itself is fixed and points callers at the
    ``run_operation`` fallback.
    """

    def __init__(self, query: str) -> None:
        message = "".join(
            (
                "Query contains Jinja expressions beyond {{ ref() }} / {{ source() }} ",
                "which cannot be executed via the direct adapter path. ",
                "Use the run_operation fallback instead.",
            )
        )
        super().__init__(message)
        self.query = query
# Matches {{ ref('name') }} or {{ ref("name") }} with optional whitespace
# around the call and its argument; group 1 captures the referenced name.
_REF_PATTERN = re.compile(r"\{\{\s*ref\(\s*['\"]([^'\"]+)['\"]\s*\)\s*\}\}")

# Matches {{ source('source_name', 'table_name') }}; group 1 captures the
# source name and group 2 the table name.
_SOURCE_PATTERN = re.compile(
    r"\{\{\s*source\(\s*['\"]([^'\"]+)['\"]\s*,\s*['\"]([^'\"]+)['\"]\s*\)\s*\}\}"
)

# Matches any Jinja expression {{ ... }} or statement tag {% ... %}.
# DOTALL lets "." cross newlines so an expression that is split over several
# lines is still detected instead of being passed to the database verbatim.
_JINJA_EXPR_PATTERN = re.compile(r"\{\{.*?\}\}|\{%.*?%\}", re.DOTALL)
def _serialize_value(val: Any) -> Any:
    """Mimic elementary's ``agate_to_dicts`` serialisation.

    * finite ``Decimal`` → ``int`` (no fractional part) or ``float``
    * non-finite ``Decimal`` (NaN / ±Infinity) → the corresponding ``float``
    * ``datetime`` / ``date`` / ``time`` → ISO-format string
    * Everything else is returned unchanged.
    """
    if isinstance(val, Decimal):
        # Special values carry a *string* exponent ('n', 'N', 'F'), so the
        # numeric comparison below would raise TypeError; signalling NaN
        # additionally fails inside normalize(). Handle them up front.
        if val.is_nan():
            return float("nan")
        if val.is_infinite():
            return float(val)
        # Match the Jinja macro: normalize, then int or float
        normalized = val.normalize()
        if normalized.as_tuple().exponent >= 0:
            return int(normalized)
        return float(normalized)
    if isinstance(val, (datetime, date, time)):
        return val.isoformat()
    return val
class AdapterQueryRunner:
    """Execute SQL directly through a dbt adapter connection.

    Parameters
    ----------
    project_dir : str
        Path to the dbt project directory.
    target : str
        Name of the dbt target / profile output to use.
    """

    def __init__(self, project_dir: str, target: str) -> None:
        self._project_dir = project_dir
        self._target = target
        self._adapter: BaseAdapter = self._create_adapter(project_dir, target)
        # Both maps stay ``None`` until the manifest is first needed.
        self._ref_map: Optional[Dict[str, str]] = None
        self._source_map: Optional[Dict[tuple, str]] = None

    # ------------------------------------------------------------------
    # Adapter bootstrap
    # ------------------------------------------------------------------

    @staticmethod
    def _create_adapter(project_dir: str, target: str) -> "BaseAdapter":
        """Build and register a dbt adapter for *project_dir* / *target*.

        The profiles directory is taken from ``DBT_PROFILES_DIR`` and falls
        back to ``~/.dbt``.
        """
        # dbt imports are deferred to call time rather than module import.
        from argparse import Namespace

        from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters
        from dbt.config.runtime import RuntimeConfig
        from dbt.flags import set_from_args

        profiles_dir = os.environ.get("DBT_PROFILES_DIR", os.path.expanduser("~/.dbt"))
        # NOTE(review): the directories are supplied under both lower- and
        # upper-case attribute names — presumably for compatibility across
        # dbt versions; confirm before removing either spelling.
        args = Namespace(
            project_dir=project_dir,
            profiles_dir=profiles_dir,
            target=target,
            threads=1,
            vars={},
            profile=None,
            PROFILES_DIR=profiles_dir,
            PROJECT_DIR=project_dir,
        )
        set_from_args(args, None)
        config = RuntimeConfig.from_args(args)

        # Clear any previously registered adapters before registering ours.
        reset_adapters()
        mp_context = multiprocessing.get_context("spawn")
        register_adapter(config, mp_context)
        return get_adapter(config)

    # ------------------------------------------------------------------
    # Ref resolution
    # ------------------------------------------------------------------

    def _load_manifest_maps(self) -> None:
        """Load ref and source maps from the dbt manifest.

        Reads ``<project_dir>/target/manifest.json`` and rebuilds both
        ``_ref_map`` (node name → relation name) and ``_source_map``
        (``(source_name, table_name)`` → relation name) from scratch.

        Raises
        ------
        FileNotFoundError
            If the manifest file does not exist.
        """
        manifest_path = Path(self._project_dir) / "target" / "manifest.json"
        if not manifest_path.exists():
            raise FileNotFoundError(
                f"Manifest not found at {manifest_path}. "
                "Run `dbt run` or `dbt compile` first."
            )
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # default text encoding, so non-ASCII relation names load everywhere.
        with open(manifest_path, encoding="utf-8") as fh:
            manifest = json.load(fh)

        ref_map: Dict[str, str] = {}
        for node in manifest.get("nodes", {}).values():
            relation_name = node.get("relation_name")
            name = node.get("name")
            # Nodes without a relation name or a name are skipped.
            if relation_name and name:
                ref_map[name] = relation_name

        source_map: Dict[tuple, str] = {}
        for source in manifest.get("sources", {}).values():
            relation_name = source.get("relation_name")
            name = source.get("name")
            source_name = source.get("source_name")
            if relation_name and source_name and name:
                source_map[(source_name, name)] = relation_name
                # Also register source tables by name for simple ref() lookups;
                # setdefault keeps an existing node entry of the same name.
                ref_map.setdefault(name, relation_name)

        self._ref_map = ref_map
        self._source_map = source_map

    def _ensure_maps_loaded(self) -> None:
        """Lazily load manifest maps on first use."""
        if self._ref_map is None:
            self._load_manifest_maps()

    def resolve_refs(self, query: str) -> str:
        """Replace ``{{ ref('name') }}`` and ``{{ source('x','y') }}`` with relation names.

        A name that is missing from the cached maps triggers one manifest
        reload before the lookup is retried.

        Raises
        ------
        ValueError
            If a ref or source is still not found after reloading the manifest.
        FileNotFoundError
            If the manifest file does not exist when it has to be (re)loaded.
        """
        self._ensure_maps_loaded()
        assert self._ref_map is not None
        assert self._source_map is not None

        def _replace_ref(match: re.Match) -> str:  # type: ignore[type-arg]
            name = match.group(1)
            if name not in self._ref_map:
                # Manifest may have changed (temp models/seeds); reload once.
                self._load_manifest_maps()
                assert self._ref_map is not None
                if name not in self._ref_map:
                    raise ValueError(
                        f"Cannot resolve ref('{name}'): not found in dbt manifest."
                    )
            return self._ref_map[name]

        def _replace_source(match: re.Match) -> str:  # type: ignore[type-arg]
            source_name, table_name = match.group(1), match.group(2)
            key = (source_name, table_name)
            if self._source_map is None or key not in self._source_map:
                # Same reload-once policy as for refs.
                self._load_manifest_maps()
                assert self._source_map is not None
                if key not in self._source_map:
                    raise ValueError(
                        f"Cannot resolve source('{source_name}', '{table_name}'): "
                        "not found in dbt manifest."
                    )
            return self._source_map[key]

        query = _REF_PATTERN.sub(_replace_ref, query)
        query = _SOURCE_PATTERN.sub(_replace_source, query)
        return query

    # ------------------------------------------------------------------
    # Query execution
    # ------------------------------------------------------------------

    @staticmethod
    def has_non_ref_jinja(query: str) -> bool:
        """Return True if *query* contains Jinja beyond ``{{ ref() }}`` / ``{{ source() }}``."""
        # Remove the supported constructs first; any remaining match of the
        # generic Jinja pattern is something this runner cannot render.
        stripped = _REF_PATTERN.sub("", query)
        stripped = _SOURCE_PATTERN.sub("", stripped)
        return bool(_JINJA_EXPR_PATTERN.search(stripped))

    def run_query(self, prerendered_query: str) -> List[Dict[str, Any]]:
        """Render Jinja refs/sources and execute a query, returning rows as dicts.

        Column names are lower-cased and values are serialised to match the
        behaviour of ``elementary.agate_to_dicts``.

        Only ``{{ ref() }}`` and ``{{ source() }}`` Jinja expressions are
        supported. Raises ``UnsupportedJinjaError`` if the query contains
        other Jinja expressions.
        """
        if self.has_non_ref_jinja(prerendered_query):
            raise UnsupportedJinjaError(prerendered_query)

        sql = self.resolve_refs(prerendered_query)
        with self._adapter.connection_named("run_query"):
            _response, table = self._adapter.execute(sql, fetch=True)

        # Convert agate Table → list[dict] matching agate_to_dicts behaviour
        columns = [c.lower() for c in table.column_names]
        return [
            {col: _serialize_value(val) for col, val in zip(columns, row)}
            for row in table
        ]