Skip to content

Commit 1629b72

Browse files
authored
Merge branch 'master' into devin/1772277728-optimize-athena-ci
2 parents d398702 + 02174cb commit 1629b72

7 files changed

Lines changed: 586 additions & 56 deletions

File tree

elementary/clients/dbt/api_dbt_runner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class APIDbtRunner(CommandLineDbtRunner):
2727
def _inner_run_command(
2828
self,
2929
dbt_command_args: List[str],
30-
capture_output: bool,
3130
quiet: bool,
3231
log_output: bool,
3332
log_format: str,

elementary/clients/dbt/command_line_dbt_runner.py

Lines changed: 211 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,50 @@
55
from typing import Any, Dict, List, Optional
66

77
import yaml
8+
from tenacity import (
9+
RetryCallState,
10+
retry,
11+
retry_if_exception,
12+
stop_after_attempt,
13+
wait_exponential,
14+
)
815

916
from elementary.clients.dbt.base_dbt_runner import BaseDbtRunner
1017
from elementary.clients.dbt.dbt_log import parse_dbt_output
18+
from elementary.clients.dbt.transient_errors import is_transient_error
1119
from elementary.exceptions.exceptions import DbtCommandError, DbtLsCommandError
1220
from elementary.monitor.dbt_project_utils import is_dbt_package_up_to_date
1321
from elementary.utils.env_vars import is_debug
1422
from elementary.utils.log import get_logger
1523

1624
logger = get_logger(__name__)
1725

26+
# Retry configuration for transient dbt errors. These values feed the tenacity
# retry policy applied to `_inner_run_command_with_retries`.
27+
_TRANSIENT_MAX_RETRIES = 3  # passed to stop_after_attempt; also shown as the attempt total in log messages
28+
_TRANSIENT_WAIT_MULTIPLIER = 10  # seconds; `multiplier` argument of wait_exponential
29+
_TRANSIENT_WAIT_MAX = 60  # seconds; `max` argument of wait_exponential
30+
31+
32+
class DbtTransientError(Exception):
    """Internal signal that a dbt command failed transiently and should be retried.

    Attributes:
        result: The command result describing the failed attempt, kept so the
            caller can still return it once retries are exhausted.
    """

    def __init__(self, result: "DbtCommandResult", message: str) -> None:
        # Only the message goes to the base class, so ``str(exc)`` and
        # ``exc.args`` carry the message alone.
        self.result = result
        super().__init__(message)
38+
39+
40+
def _before_retry_log(retry_state: RetryCallState) -> None:
    """Warn that a dbt command is about to be retried after a transient error.

    The command text comes from the ``log_command_args`` keyword argument of
    the retried call, so that call must pass it by keyword; when the keyword
    is absent an empty command string is logged.
    """
    command_text = " ".join(retry_state.kwargs.get("log_command_args", []))
    logger.warning(
        "Transient error detected for dbt command '%s' (attempt %d/%d). Retrying...",
        command_text,
        retry_state.attempt_number,
        _TRANSIENT_MAX_RETRIES,
    )
50+
51+
1852
MACRO_RESULT_PATTERN = re.compile(
1953
"Elementary: --ELEMENTARY-MACRO-OUTPUT-START--(.*)--ELEMENTARY-MACRO-OUTPUT-END--"
2054
)
@@ -50,17 +84,78 @@ def __init__(
5084
secret_vars,
5185
allow_macros_without_package_prefix,
5286
)
87+
self.adapter_type = self._get_adapter_type()
5388
self.raise_on_failure = raise_on_failure
5489
self.env_vars = env_vars
5590
if force_dbt_deps:
5691
self.deps()
5792
elif run_deps_if_needed:
5893
self._run_deps_if_needed()
5994

95+
def _get_adapter_type(self) -> Optional[str]:
    """Resolve the adapter type from ``profiles.yml``.

    Reads the profile name from ``dbt_project.yml``, then looks up the
    selected target in ``profiles.yml`` to extract its ``type`` field
    (e.g. ``"bigquery"``, ``"snowflake"``).

    This is a best-effort lookup that runs during construction, so it never
    raises for a missing, unreadable, empty or unexpectedly shaped file.

    Returns:
        The target's ``type`` value, or ``None`` when either file is missing
        or cannot be parsed, or when the expected keys are absent.
    """
    profiles_dir = self.profiles_dir or os.path.join(
        os.path.expanduser("~"), ".dbt"
    )
    yaml_sources = (
        ("profiles.yml", os.path.join(profiles_dir, "profiles.yml")),
        ("dbt_project.yml", os.path.join(self.project_dir, "dbt_project.yml")),
    )

    loaded = []
    for label, path in yaml_sources:
        if not os.path.exists(path):
            logger.debug("%s not found at %s", label, path)
            return None
        try:
            with open(path, encoding="utf-8") as f:
                loaded.append(yaml.safe_load(f))
        except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
            # A broken file must not prevent the runner from being built;
            # the adapter type is only an optional hint.
            logger.debug("Could not read %s at %s: %s", label, path, exc)
            return None
    profiles, dbt_project = loaded

    # ``yaml.safe_load`` returns None for an empty document and may return a
    # scalar or list for malformed files, so check the shape before ``.get``.
    profile_name = (
        dbt_project.get("profile") if isinstance(dbt_project, dict) else None
    )
    if not profile_name:
        logger.debug("No profile name found in dbt_project.yml")
        return None

    profile = profiles.get(profile_name) if isinstance(profiles, dict) else None
    if not profile or not isinstance(profile, dict):
        logger.debug("Profile '%s' not found in profiles.yml", profile_name)
        return None

    # An explicitly requested target wins over the profile's default target.
    target_name = self.target or profile.get("target")
    if not target_name:
        logger.debug("No target specified and no default target in profile")
        return None

    outputs = profile.get("outputs")
    target_config = (
        outputs.get(target_name) if isinstance(outputs, dict) else None
    )
    if not target_config or not isinstance(target_config, dict):
        logger.debug("Target '%s' not found in profile outputs", target_name)
        return None

    adapter_type = target_config.get("type")
    if adapter_type:
        logger.debug(
            "Resolved adapter type '%s' for target '%s'",
            adapter_type,
            target_name,
        )
    return adapter_type
155+
60156
def _inner_run_command(
61157
self,
62158
dbt_command_args: List[str],
63-
capture_output: bool,
64159
quiet: bool,
65160
log_output: bool,
66161
log_format: str,
@@ -75,15 +170,13 @@ def _parse_ls_command_result(
75170
def _run_command(
76171
self,
77172
command_args: List[str],
78-
capture_output: bool = False,
79173
log_format: str = "json",
80174
vars: Optional[dict] = None,
81175
quiet: bool = False,
82176
log_output: bool = True,
83177
) -> DbtCommandResult:
84178
dbt_command_args = []
85-
if capture_output:
86-
dbt_command_args.extend(["--log-format", log_format])
179+
dbt_command_args.extend(["--log-format", log_format])
87180
dbt_command_args.extend(command_args)
88181
dbt_command_args.extend(["--project-dir", os.path.abspath(self.project_dir)])
89182
if self.profiles_dir:
@@ -112,28 +205,108 @@ def _run_command(
112205
else:
113206
logger.debug(log_msg)
114207

115-
result = self._inner_run_command(
116-
dbt_command_args,
117-
capture_output=capture_output,
118-
quiet=quiet,
119-
log_output=log_output,
120-
log_format=log_format,
121-
)
122-
123-
if capture_output and result.output:
208+
try:
209+
return self._inner_run_command_with_retries(
210+
dbt_command_args=dbt_command_args,
211+
log_command_args=log_command_args,
212+
quiet=quiet,
213+
log_output=log_output,
214+
log_format=log_format,
215+
)
216+
except DbtTransientError as exc:
217+
logger.exception(
218+
"dbt command '%s' failed after %d attempts due to transient errors.",
219+
" ".join(log_command_args),
220+
_TRANSIENT_MAX_RETRIES,
221+
)
222+
if isinstance(exc.__cause__, DbtCommandError):
223+
raise exc.__cause__ from exc
224+
return exc.result
225+
226+
@retry(
    retry=retry_if_exception(lambda exc: isinstance(exc, DbtTransientError)),
    stop=stop_after_attempt(_TRANSIENT_MAX_RETRIES),
    wait=wait_exponential(
        multiplier=_TRANSIENT_WAIT_MULTIPLIER,
        max=_TRANSIENT_WAIT_MAX,
    ),
    before_sleep=_before_retry_log,
    reraise=True,
)
def _inner_run_command_with_retries(
    self,
    dbt_command_args: List[str],
    log_command_args: List[str],
    quiet: bool,
    log_output: bool,
    log_format: str,
) -> DbtCommandResult:
    """Run one dbt command attempt, raising ``DbtTransientError`` so tenacity retries.

    Call this with keyword arguments: the ``before_sleep`` hook reads
    ``log_command_args`` from the retried call's kwargs.

    Args:
        dbt_command_args: Full argument list handed to ``_inner_run_command``.
        log_command_args: Arguments used only to build log and error messages.
        quiet: Forwarded to ``_inner_run_command``.
        log_output: When true (or in debug mode), parsed dbt log lines are
            re-emitted through the module logger.
        log_format: dbt log format, forwarded and used to parse the output.

    Returns:
        The command result, which may be unsuccessful if the failure is not
        classified as transient.

    Raises:
        DbtTransientError: The attempt failed and ``is_transient_error``
            classified the output as transient. It carries the failed result
            and, when caused by a ``DbtCommandError``, chains it as
            ``__cause__``.
        DbtCommandError: ``_inner_run_command`` raised and the failure is not
            transient; re-raised unchanged.
    """

    def _as_text(stream: Any) -> str:
        # Decode leniently: a strict decode of non-UTF-8 process output would
        # raise UnicodeDecodeError from inside the handler below and hide the
        # DbtCommandError that callers expect.
        if isinstance(stream, bytes):
            return stream.decode(errors="replace")
        return str(stream)

    try:
        result = self._inner_run_command(
            dbt_command_args,
            quiet=quiet,
            log_output=log_output,
            log_format=log_format,
        )
    except DbtCommandError as exc:
        # Prefer the captured process streams; fall back to the exception text.
        output_text = str(exc)
        stderr_text: Optional[str] = None
        proc_err = exc.proc_err
        if proc_err is not None:
            if proc_err.output:
                output_text = _as_text(proc_err.output)
            if proc_err.stderr:
                stderr_text = _as_text(proc_err.stderr)
        if is_transient_error(
            self.adapter_type, output=output_text, stderr=stderr_text
        ):
            raise DbtTransientError(
                result=DbtCommandResult(
                    success=False,
                    output=output_text,
                    stderr=stderr_text,
                ),
                message=f"Transient error during dbt command: {exc}",
            ) from exc
        raise

    if result.output:
        logger.debug(
            "Result bytes size for command '%s' is %d",
            " ".join(log_command_args),
            len(result.output),
        )
        if log_output or is_debug():
            for log in parse_dbt_output(result.output, log_format):
                logger.info(log.msg)

    # A command can also fail without raising (unsuccessful result); treat
    # that the same way when its output looks transient.
    if not result.success and is_transient_error(
        self.adapter_type, output=result.output, stderr=result.stderr
    ):
        raise DbtTransientError(
            result=result,
            message=(
                f"Transient error during dbt command: "
                f"{' '.join(log_command_args)}"
            ),
        )

    return result
132303

133-
def deps(
    self,
    quiet: bool = False,
    capture_output: bool = True,
) -> bool:
    """Run ``dbt deps`` and report whether it succeeded.

    Args:
        quiet: Forwarded to ``_run_command``.
        capture_output: Deprecated: no-op, kept for backward compatibility.

    Returns:
        The ``success`` flag of the command result.
    """
    return self._run_command(command_args=["deps"], quiet=quiet).success
138311

139312
def seed(self, select: Optional[str] = None, full_refresh: bool = False) -> bool:
@@ -152,7 +325,7 @@ def snapshot(self) -> bool:
152325
def run_operation(
153326
self,
154327
macro_name: str,
155-
capture_output: bool = True,
328+
capture_output: bool = True, # Deprecated: no-op, kept for backward compatibility.
156329
macro_args: Optional[dict] = None,
157330
log_errors: bool = True,
158331
vars: Optional[dict] = None,
@@ -177,7 +350,6 @@ def run_operation(
177350
command_args.extend(["--args", json_args])
178351
result = self._run_command(
179352
command_args=command_args,
180-
capture_output=capture_output,
181353
vars=vars,
182354
quiet=quiet,
183355
log_output=log_output,
@@ -191,23 +363,22 @@ def run_operation(
191363
log_pattern = (
192364
RAW_EDR_LOGS_PATTERN if return_raw_edr_logs else MACRO_RESULT_PATTERN
193365
)
194-
if capture_output:
195-
if result.output is not None:
196-
for log in parse_dbt_output(result.output):
197-
if log_errors and log.level == "error":
198-
logger.error(log.msg)
199-
continue
200-
201-
if log.msg:
202-
match = log_pattern.match(log.msg)
203-
if match:
204-
run_operation_results.append(match.group(1))
205-
206-
if result.stderr is not None and log_errors:
207-
for log in parse_dbt_output(result.stderr):
208-
if log.level == "error":
209-
logger.error(log.msg)
210-
continue
366+
if result.output is not None:
367+
for log in parse_dbt_output(result.output):
368+
if log_errors and log.level == "error":
369+
logger.error(log.msg)
370+
continue
371+
372+
if log.msg:
373+
match = log_pattern.match(log.msg)
374+
if match:
375+
run_operation_results.append(match.group(1))
376+
377+
if result.stderr is not None and log_errors:
378+
for log in parse_dbt_output(result.stderr):
379+
if log.level == "error":
380+
logger.error(log.msg)
381+
continue
211382

212383
return run_operation_results
213384

@@ -218,7 +389,7 @@ def run(
218389
full_refresh: bool = False,
219390
vars: Optional[dict] = None,
220391
quiet: bool = False,
221-
capture_output: bool = False,
392+
capture_output: bool = False, # Deprecated: no-op, kept for backward compatibility.
222393
) -> bool:
223394
command_args = ["run"]
224395
if full_refresh:
@@ -231,7 +402,6 @@ def run(
231402
command_args=command_args,
232403
vars=vars,
233404
quiet=quiet,
234-
capture_output=capture_output,
235405
)
236406
return result.success
237407

@@ -240,7 +410,7 @@ def test(
240410
select: Optional[str] = None,
241411
vars: Optional[dict] = None,
242412
quiet: bool = False,
243-
capture_output: bool = False,
413+
capture_output: bool = False, # Deprecated: no-op, kept for backward compatibility.
244414
) -> bool:
245415
command_args = ["test"]
246416
if select:
@@ -249,7 +419,6 @@ def test(
249419
command_args=command_args,
250420
vars=vars,
251421
quiet=quiet,
252-
capture_output=capture_output,
253422
)
254423
return result.success
255424

@@ -266,9 +435,7 @@ def ls(self, select: Optional[str] = None) -> list:
266435
if select:
267436
command_args.extend(["-s", select])
268437
try:
269-
result = self._run_command(
270-
command_args=command_args, capture_output=True, log_format="text"
271-
)
438+
result = self._run_command(command_args=command_args, log_format="text")
272439
return self._parse_ls_command_result(select, result)
273440
except DbtCommandError:
274441
raise DbtLsCommandError(select)

0 commit comments

Comments
 (0)