55from typing import Any , Dict , List , Optional
66
77import yaml
8+ from tenacity import (
9+ RetryCallState ,
10+ retry ,
11+ retry_if_exception ,
12+ stop_after_attempt ,
13+ wait_exponential ,
14+ )
815
916from elementary .clients .dbt .base_dbt_runner import BaseDbtRunner
1017from elementary .clients .dbt .dbt_log import parse_dbt_output
18+ from elementary .clients .dbt .transient_errors import is_transient_error
1119from elementary .exceptions .exceptions import DbtCommandError , DbtLsCommandError
1220from elementary .monitor .dbt_project_utils import is_dbt_package_up_to_date
1321from elementary .utils .env_vars import is_debug
1422from elementary .utils .log import get_logger
1523
1624logger = get_logger (__name__ )
1725
26+ # Retry configuration for transient errors.
27+ _TRANSIENT_MAX_RETRIES = 3
28+ _TRANSIENT_WAIT_MULTIPLIER = 10 # seconds
29+ _TRANSIENT_WAIT_MAX = 60 # seconds
30+
31+
32+ class DbtTransientError (Exception ):
33+ """Raised internally to signal a transient dbt failure that should be retried."""
34+
35+ def __init__ (self , result : "DbtCommandResult" , message : str ) -> None :
36+ super ().__init__ (message )
37+ self .result = result
38+
39+
40+ def _before_retry_log (retry_state : RetryCallState ) -> None :
41+ """Log before each retry. Reads log_command_args from the retried call."""
42+ log_command_args = retry_state .kwargs .get ("log_command_args" , [])
43+ attempt = retry_state .attempt_number
44+ logger .warning (
45+ "Transient error detected for dbt command '%s' (attempt %d/%d). Retrying..." ,
46+ " " .join (log_command_args ),
47+ attempt ,
48+ _TRANSIENT_MAX_RETRIES ,
49+ )
50+
51+
1852MACRO_RESULT_PATTERN = re .compile (
1953 "Elementary: --ELEMENTARY-MACRO-OUTPUT-START--(.*)--ELEMENTARY-MACRO-OUTPUT-END--"
2054)
@@ -50,17 +84,78 @@ def __init__(
5084 secret_vars ,
5185 allow_macros_without_package_prefix ,
5286 )
87+ self .adapter_type = self ._get_adapter_type ()
5388 self .raise_on_failure = raise_on_failure
5489 self .env_vars = env_vars
5590 if force_dbt_deps :
5691 self .deps ()
5792 elif run_deps_if_needed :
5893 self ._run_deps_if_needed ()
5994
95+ def _get_adapter_type (self ) -> Optional [str ]:
96+ """Resolve the adapter type from ``profiles.yml``.
97+
98+ Reads the profile name from ``dbt_project.yml``, then looks up the
99+ selected target in ``profiles.yml`` to extract its ``type`` field
100+ (e.g. ``"bigquery"``, ``"snowflake"``).
101+
102+ Returns ``None`` when profiles.yml or the expected keys are missing.
103+ """
104+ profiles_dir = (
105+ self .profiles_dir
106+ if self .profiles_dir
107+ else os .path .join (os .path .expanduser ("~" ), ".dbt" )
108+ )
109+ profiles_path = os .path .join (profiles_dir , "profiles.yml" )
110+ if not os .path .exists (profiles_path ):
111+ logger .debug ("profiles.yml not found at %s" , profiles_path )
112+ return None
113+
114+ with open (profiles_path ) as f :
115+ profiles = yaml .safe_load (f )
116+
117+ # Read dbt_project.yml to get the profile name.
118+ dbt_project_path = os .path .join (self .project_dir , "dbt_project.yml" )
119+ if not os .path .exists (dbt_project_path ):
120+ logger .debug ("dbt_project.yml not found at %s" , dbt_project_path )
121+ return None
122+
123+ with open (dbt_project_path ) as f :
124+ dbt_project = yaml .safe_load (f )
125+
126+ profile_name = dbt_project .get ("profile" )
127+ if not profile_name :
128+ logger .debug ("No profile name found in dbt_project.yml" )
129+ return None
130+
131+ profile = profiles .get (profile_name ) if profiles else None
132+ if not profile :
133+ logger .debug ("Profile '%s' not found in profiles.yml" , profile_name )
134+ return None
135+
136+ # Determine which target to use.
137+ target_name = self .target or profile .get ("target" )
138+ if not target_name :
139+ logger .debug ("No target specified and no default target in profile" )
140+ return None
141+
142+ target_config = profile .get ("outputs" , {}).get (target_name )
143+ if not target_config :
144+ logger .debug ("Target '%s' not found in profile outputs" , target_name )
145+ return None
146+
147+ adapter_type = target_config .get ("type" )
148+ if adapter_type :
149+ logger .debug (
150+ "Resolved adapter type '%s' for target '%s'" ,
151+ adapter_type ,
152+ target_name ,
153+ )
154+ return adapter_type
155+
60156 def _inner_run_command (
61157 self ,
62158 dbt_command_args : List [str ],
63- capture_output : bool ,
64159 quiet : bool ,
65160 log_output : bool ,
66161 log_format : str ,
@@ -75,15 +170,13 @@ def _parse_ls_command_result(
75170 def _run_command (
76171 self ,
77172 command_args : List [str ],
78- capture_output : bool = False ,
79173 log_format : str = "json" ,
80174 vars : Optional [dict ] = None ,
81175 quiet : bool = False ,
82176 log_output : bool = True ,
83177 ) -> DbtCommandResult :
84178 dbt_command_args = []
85- if capture_output :
86- dbt_command_args .extend (["--log-format" , log_format ])
179+ dbt_command_args .extend (["--log-format" , log_format ])
87180 dbt_command_args .extend (command_args )
88181 dbt_command_args .extend (["--project-dir" , os .path .abspath (self .project_dir )])
89182 if self .profiles_dir :
@@ -112,28 +205,108 @@ def _run_command(
112205 else :
113206 logger .debug (log_msg )
114207
115- result = self ._inner_run_command (
116- dbt_command_args ,
117- capture_output = capture_output ,
118- quiet = quiet ,
119- log_output = log_output ,
120- log_format = log_format ,
121- )
122-
123- if capture_output and result .output :
208+ try :
209+ return self ._inner_run_command_with_retries (
210+ dbt_command_args = dbt_command_args ,
211+ log_command_args = log_command_args ,
212+ quiet = quiet ,
213+ log_output = log_output ,
214+ log_format = log_format ,
215+ )
216+ except DbtTransientError as exc :
217+ logger .exception (
218+ "dbt command '%s' failed after %d attempts due to transient errors." ,
219+ " " .join (log_command_args ),
220+ _TRANSIENT_MAX_RETRIES ,
221+ )
222+ if isinstance (exc .__cause__ , DbtCommandError ):
223+ raise exc .__cause__ from exc
224+ return exc .result
225+
226+ @retry (
227+ retry = retry_if_exception (lambda exc : isinstance (exc , DbtTransientError )),
228+ stop = stop_after_attempt (_TRANSIENT_MAX_RETRIES ),
229+ wait = wait_exponential (
230+ multiplier = _TRANSIENT_WAIT_MULTIPLIER ,
231+ max = _TRANSIENT_WAIT_MAX ,
232+ ),
233+ before_sleep = _before_retry_log ,
234+ reraise = True ,
235+ )
236+ def _inner_run_command_with_retries (
237+ self ,
238+ dbt_command_args : List [str ],
239+ log_command_args : List [str ],
240+ quiet : bool ,
241+ log_output : bool ,
242+ log_format : str ,
243+ ) -> DbtCommandResult :
244+ """Run one dbt command attempt. Raises DbtTransientError for transient failures so tenacity can retry."""
245+ try :
246+ result = self ._inner_run_command (
247+ dbt_command_args ,
248+ quiet = quiet ,
249+ log_output = log_output ,
250+ log_format = log_format ,
251+ )
252+ except DbtCommandError as exc :
253+ output_text = str (exc )
254+ stderr_text : Optional [str ] = None
255+ if exc .proc_err is not None :
256+ if exc .proc_err .output :
257+ output_text = (
258+ exc .proc_err .output .decode ()
259+ if isinstance (exc .proc_err .output , bytes )
260+ else str (exc .proc_err .output )
261+ )
262+ if exc .proc_err .stderr :
263+ stderr_text = (
264+ exc .proc_err .stderr .decode ()
265+ if isinstance (exc .proc_err .stderr , bytes )
266+ else str (exc .proc_err .stderr )
267+ )
268+ if is_transient_error (
269+ self .adapter_type , output = output_text , stderr = stderr_text
270+ ):
271+ raise DbtTransientError (
272+ result = DbtCommandResult (
273+ success = False ,
274+ output = output_text ,
275+ stderr = stderr_text ,
276+ ),
277+ message = f"Transient error during dbt command: { exc } " ,
278+ ) from exc
279+ raise
280+
281+ if result .output :
124282 logger .debug (
125- f"Result bytes size for command '{ log_command_args } ' is { len (result .output )} "
283+ "Result bytes size for command '%s' is %d" ,
284+ " " .join (log_command_args ),
285+ len (result .output ),
126286 )
127287 if log_output or is_debug ():
128288 for log in parse_dbt_output (result .output , log_format ):
129289 logger .info (log .msg )
130290
291+ if not result .success and is_transient_error (
292+ self .adapter_type , output = result .output , stderr = result .stderr
293+ ):
294+ raise DbtTransientError (
295+ result = result ,
296+ message = (
297+ f"Transient error during dbt command: "
298+ f"{ ' ' .join (log_command_args )} "
299+ ),
300+ )
301+
131302 return result
132303
133- def deps (self , quiet : bool = False , capture_output : bool = True ) -> bool :
134- result = self ._run_command (
135- command_args = ["deps" ], quiet = quiet , capture_output = capture_output
136- )
304+ def deps (
305+ self ,
306+ quiet : bool = False ,
307+ capture_output : bool = True , # Deprecated: no-op, kept for backward compatibility.
308+ ) -> bool :
309+ result = self ._run_command (command_args = ["deps" ], quiet = quiet )
137310 return result .success
138311
139312 def seed (self , select : Optional [str ] = None , full_refresh : bool = False ) -> bool :
@@ -152,7 +325,7 @@ def snapshot(self) -> bool:
152325 def run_operation (
153326 self ,
154327 macro_name : str ,
155- capture_output : bool = True ,
328+ capture_output : bool = True , # Deprecated: no-op, kept for backward compatibility.
156329 macro_args : Optional [dict ] = None ,
157330 log_errors : bool = True ,
158331 vars : Optional [dict ] = None ,
@@ -177,7 +350,6 @@ def run_operation(
177350 command_args .extend (["--args" , json_args ])
178351 result = self ._run_command (
179352 command_args = command_args ,
180- capture_output = capture_output ,
181353 vars = vars ,
182354 quiet = quiet ,
183355 log_output = log_output ,
@@ -191,23 +363,22 @@ def run_operation(
191363 log_pattern = (
192364 RAW_EDR_LOGS_PATTERN if return_raw_edr_logs else MACRO_RESULT_PATTERN
193365 )
194- if capture_output :
195- if result .output is not None :
196- for log in parse_dbt_output (result .output ):
197- if log_errors and log .level == "error" :
198- logger .error (log .msg )
199- continue
200-
201- if log .msg :
202- match = log_pattern .match (log .msg )
203- if match :
204- run_operation_results .append (match .group (1 ))
205-
206- if result .stderr is not None and log_errors :
207- for log in parse_dbt_output (result .stderr ):
208- if log .level == "error" :
209- logger .error (log .msg )
210- continue
366+ if result .output is not None :
367+ for log in parse_dbt_output (result .output ):
368+ if log_errors and log .level == "error" :
369+ logger .error (log .msg )
370+ continue
371+
372+ if log .msg :
373+ match = log_pattern .match (log .msg )
374+ if match :
375+ run_operation_results .append (match .group (1 ))
376+
377+ if result .stderr is not None and log_errors :
378+ for log in parse_dbt_output (result .stderr ):
379+ if log .level == "error" :
380+ logger .error (log .msg )
381+ continue
211382
212383 return run_operation_results
213384
@@ -218,7 +389,7 @@ def run(
218389 full_refresh : bool = False ,
219390 vars : Optional [dict ] = None ,
220391 quiet : bool = False ,
221- capture_output : bool = False ,
392+ capture_output : bool = False , # Deprecated: no-op, kept for backward compatibility.
222393 ) -> bool :
223394 command_args = ["run" ]
224395 if full_refresh :
@@ -231,7 +402,6 @@ def run(
231402 command_args = command_args ,
232403 vars = vars ,
233404 quiet = quiet ,
234- capture_output = capture_output ,
235405 )
236406 return result .success
237407
@@ -240,7 +410,7 @@ def test(
240410 select : Optional [str ] = None ,
241411 vars : Optional [dict ] = None ,
242412 quiet : bool = False ,
243- capture_output : bool = False ,
413+ capture_output : bool = False , # Deprecated: no-op, kept for backward compatibility.
244414 ) -> bool :
245415 command_args = ["test" ]
246416 if select :
@@ -249,7 +419,6 @@ def test(
249419 command_args = command_args ,
250420 vars = vars ,
251421 quiet = quiet ,
252- capture_output = capture_output ,
253422 )
254423 return result .success
255424
@@ -266,9 +435,7 @@ def ls(self, select: Optional[str] = None) -> list:
266435 if select :
267436 command_args .extend (["-s" , select ])
268437 try :
269- result = self ._run_command (
270- command_args = command_args , capture_output = True , log_format = "text"
271- )
438+ result = self ._run_command (command_args = command_args , log_format = "text" )
272439 return self ._parse_ls_command_result (select , result )
273440 except DbtCommandError :
274441 raise DbtLsCommandError (select )
0 commit comments