Skip to content

Commit 5ade723

Browse files
authored
Fix!: depend on all attributes of dbt model when passed to a macro (#5269)
1 parent a3e7bda commit 5ade723

File tree

15 files changed: +210 additions, -120 deletions

sqlmesh/core/model/definition.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -594,7 +594,7 @@ def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer:
594594
python_env=self.python_env,
595595
only_execution_time=False,
596596
default_catalog=self.default_catalog,
597-
model_fqn=self.fqn,
597+
model=self,
598598
)
599599
return self._statement_renderer_cache[expression_key]
600600

@@ -1573,14 +1573,14 @@ def _query_renderer(self) -> QueryRenderer:
15731573
self.dialect,
15741574
self.macro_definitions,
15751575
schema=self.mapping_schema,
1576-
model_fqn=self.fqn,
15771576
path=self._path,
15781577
jinja_macro_registry=self.jinja_macros,
15791578
python_env=self.python_env,
15801579
only_execution_time=self.kind.only_execution_time,
15811580
default_catalog=self.default_catalog,
15821581
quote_identifiers=not no_quote_identifiers,
15831582
optimize_query=self.optimize_query,
1583+
model=self,
15841584
)
15851585

15861586
@property

sqlmesh/core/renderer.py

Lines changed: 29 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
from sqlglot.dialects.dialect import DialectType
3232

3333
from sqlmesh.core.linter.rule import Rule
34+
from sqlmesh.core.model.definition import _Model
3435
from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot
3536

3637

@@ -50,9 +51,9 @@ def __init__(
5051
schema: t.Optional[t.Dict[str, t.Any]] = None,
5152
default_catalog: t.Optional[str] = None,
5253
quote_identifiers: bool = True,
53-
model_fqn: t.Optional[str] = None,
5454
normalize_identifiers: bool = True,
5555
optimize_query: t.Optional[bool] = True,
56+
model: t.Optional[_Model] = None,
5657
):
5758
self._expression = expression
5859
self._dialect = dialect
@@ -66,8 +67,9 @@ def __init__(
6667
self._quote_identifiers = quote_identifiers
6768
self.update_schema({} if schema is None else schema)
6869
self._cache: t.List[t.Optional[exp.Expression]] = []
69-
self._model_fqn = model_fqn
70+
self._model_fqn = model.fqn if model else None
7071
self._optimize_query_flag = optimize_query is not False
72+
self._model = model
7173

7274
def update_schema(self, schema: t.Dict[str, t.Any]) -> None:
7375
self.schema = d.normalize_mapping_schema(schema, dialect=self._dialect)
@@ -188,30 +190,32 @@ def _resolve_table(table: str | exp.Table) -> str:
188190
}
189191

190192
variables = kwargs.pop("variables", {})
191-
jinja_env_kwargs = {
192-
**{
193-
**render_kwargs,
194-
**_prepare_python_env_for_jinja(macro_evaluator, self._python_env),
195-
**variables,
196-
},
197-
"snapshots": snapshots or {},
198-
"table_mapping": table_mapping,
199-
"deployability_index": deployability_index,
200-
"default_catalog": self._default_catalog,
201-
"runtime_stage": runtime_stage.value,
202-
"resolve_table": _resolve_table,
203-
}
204-
if this_model:
205-
render_kwargs["this_model"] = this_model
206-
jinja_env_kwargs["this_model"] = this_model.sql(
207-
dialect=self._dialect, identify=True, comments=False
208-
)
209-
210-
jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs)
211193

212194
expressions = [self._expression]
213195
if isinstance(self._expression, d.Jinja):
214196
try:
197+
jinja_env_kwargs = {
198+
**{
199+
**render_kwargs,
200+
**_prepare_python_env_for_jinja(macro_evaluator, self._python_env),
201+
**variables,
202+
},
203+
"snapshots": snapshots or {},
204+
"table_mapping": table_mapping,
205+
"deployability_index": deployability_index,
206+
"default_catalog": self._default_catalog,
207+
"runtime_stage": runtime_stage.value,
208+
"resolve_table": _resolve_table,
209+
"model_instance": self._model,
210+
}
211+
212+
if this_model:
213+
jinja_env_kwargs["this_model"] = this_model.sql(
214+
dialect=self._dialect, identify=True, comments=False
215+
)
216+
217+
jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs)
218+
215219
expressions = []
216220
rendered_expression = jinja_env.from_string(self._expression.name).render()
217221
logger.debug(
@@ -229,6 +233,9 @@ def _resolve_table(table: str | exp.Table) -> str:
229233
f"Could not render or parse jinja at '{self._path}'.\n{ex}"
230234
) from ex
231235

236+
if this_model:
237+
render_kwargs["this_model"] = this_model
238+
232239
macro_evaluator.locals.update(render_kwargs)
233240

234241
if variables:

sqlmesh/dbt/basemodel.py

Lines changed: 16 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
DbtConfig,
2323
Dependencies,
2424
GeneralConfig,
25+
RAW_CODE_KEY,
2526
SqlStr,
2627
sql_str_validator,
2728
)
@@ -167,14 +168,6 @@ def _validate_grants(cls, v: t.Dict[str, str]) -> t.Dict[str, t.List[str]]:
167168
},
168169
}
169170

170-
@property
171-
def sql_no_config(self) -> SqlStr:
172-
return SqlStr("")
173-
174-
@property
175-
def sql_embedded_config(self) -> SqlStr:
176-
return SqlStr("")
177-
178171
@property
179172
def table_schema(self) -> str:
180173
"""
@@ -375,15 +368,21 @@ def to_sqlmesh(
375368
def _model_jinja_context(
376369
self, context: DbtContext, dependencies: Dependencies
377370
) -> t.Dict[str, t.Any]:
378-
model_node: AttributeDict[str, t.Any] = AttributeDict(
379-
{
380-
k: v
381-
for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items()
382-
if k in dependencies.model_attrs
383-
}
384-
if context._manifest and self.node_name in context._manifest._manifest.nodes
385-
else {}
386-
)
371+
if context._manifest and self.node_name in context._manifest._manifest.nodes:
372+
attributes = context._manifest._manifest.nodes[self.node_name].to_dict()
373+
if dependencies.model_attrs.all_attrs:
374+
model_node: AttributeDict[str, t.Any] = AttributeDict(attributes)
375+
else:
376+
model_node = AttributeDict(
377+
filter(lambda kv: kv[0] in dependencies.model_attrs.attrs, attributes.items())
378+
)
379+
380+
# We exclude the raw SQL code to reduce the payload size. It's still accessible through
381+
# the JinjaQuery instance stored in the resulting SQLMesh model's `query` field.
382+
model_node.pop(RAW_CODE_KEY, None)
383+
else:
384+
model_node = AttributeDict({})
385+
387386
return {
388387
"this": self.relation_info,
389388
"model": model_node,

sqlmesh/dbt/builtin.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,10 @@
1616

1717
from sqlmesh.core.console import get_console
1818
from sqlmesh.core.engine_adapter import EngineAdapter
19+
from sqlmesh.core.model.definition import SqlModel
1920
from sqlmesh.core.snapshot.definition import DeployabilityIndex
2021
from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter
22+
from sqlmesh.dbt.common import RAW_CODE_KEY
2123
from sqlmesh.dbt.relation import Policy
2224
from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS
2325
from sqlmesh.dbt.util import DBT_VERSION
@@ -469,12 +471,21 @@ def create_builtin_globals(
469471
is_incremental &= snapshot_table_exists
470472
else:
471473
is_incremental = False
474+
472475
builtin_globals["is_incremental"] = lambda: is_incremental
473476

474477
builtin_globals["builtins"] = AttributeDict(
475478
{k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")}
476479
)
477480

481+
if (model := jinja_globals.pop("model", None)) is not None:
482+
if isinstance(model_instance := jinja_globals.pop("model_instance", None), SqlModel):
483+
builtin_globals["model"] = AttributeDict(
484+
{**model, RAW_CODE_KEY: model_instance.query.name}
485+
)
486+
else:
487+
builtin_globals["model"] = AttributeDict(model.copy())
488+
478489
if engine_adapter is not None:
479490
builtin_globals["flags"] = Flags(which="run")
480491
adapter: BaseAdapter = RuntimeAdapter(

sqlmesh/dbt/common.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,23 +2,26 @@
22

33
import re
44
import typing as t
5+
from dataclasses import dataclass
56
from pathlib import Path
67

78
from ruamel.yaml.constructor import DuplicateKeyError
89
from sqlglot.helper import ensure_list
910

11+
from sqlmesh.dbt.util import DBT_VERSION
1012
from sqlmesh.core.config.base import BaseConfig, UpdateStrategy
13+
from sqlmesh.core.config.common import DBT_PROJECT_FILENAME
1114
from sqlmesh.utils import AttributeDict
1215
from sqlmesh.utils.conversions import ensure_bool, try_str_to_bool
1316
from sqlmesh.utils.errors import ConfigError
1417
from sqlmesh.utils.jinja import MacroReference
1518
from sqlmesh.utils.pydantic import PydanticModel, field_validator
1619
from sqlmesh.utils.yaml import load
17-
from sqlmesh.core.config.common import DBT_PROJECT_FILENAME
1820

1921
T = t.TypeVar("T", bound="GeneralConfig")
2022

2123
PROJECT_FILENAME = DBT_PROJECT_FILENAME
24+
RAW_CODE_KEY = "raw_code" if DBT_VERSION >= (1, 3, 0) else "raw_sql" # type: ignore
2225

2326
JINJA_ONLY = {
2427
"adapter",
@@ -172,6 +175,12 @@ def sqlmesh_config_fields(self) -> t.Set[str]:
172175
return set()
173176

174177

178+
@dataclass
179+
class ModelAttrs:
180+
attrs: t.Set[str]
181+
all_attrs: bool = False
182+
183+
175184
class Dependencies(PydanticModel):
176185
"""
177186
DBT dependencies for a model, macro, etc.
@@ -186,7 +195,7 @@ class Dependencies(PydanticModel):
186195
sources: t.Set[str] = set()
187196
refs: t.Set[str] = set()
188197
variables: t.Set[str] = set()
189-
model_attrs: t.Set[str] = set()
198+
model_attrs: ModelAttrs = ModelAttrs(attrs=set())
190199

191200
has_dynamic_var_names: bool = False
192201

@@ -196,7 +205,10 @@ def union(self, other: Dependencies) -> Dependencies:
196205
sources=self.sources | other.sources,
197206
refs=self.refs | other.refs,
198207
variables=self.variables | other.variables,
199-
model_attrs=self.model_attrs | other.model_attrs,
208+
model_attrs=ModelAttrs(
209+
attrs=self.model_attrs.attrs | other.model_attrs.attrs,
210+
all_attrs=self.model_attrs.all_attrs or other.model_attrs.all_attrs,
211+
),
200212
has_dynamic_var_names=self.has_dynamic_var_names or other.has_dynamic_var_names,
201213
)
202214

sqlmesh/dbt/loader.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,7 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
138138
package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds}
139139

140140
for model in package_models.values():
141-
if isinstance(model, ModelConfig) and not model.sql_no_config:
141+
if isinstance(model, ModelConfig) and not model.sql.strip():
142142
logger.info(f"Skipping empty model '{model.name}' at path '{model.path}'.")
143143
continue
144144

sqlmesh/dbt/manifest.py

Lines changed: 38 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -44,8 +44,8 @@
4444
from sqlmesh.core import constants as c
4545
from sqlmesh.utils.errors import SQLMeshError
4646
from sqlmesh.core.config import ModelDefaultsConfig
47-
from sqlmesh.dbt.basemodel import Dependencies
4847
from sqlmesh.dbt.builtin import BUILTIN_FILTERS, BUILTIN_GLOBALS, OVERRIDDEN_MACROS
48+
from sqlmesh.dbt.common import Dependencies
4949
from sqlmesh.dbt.model import ModelConfig
5050
from sqlmesh.dbt.package import HookConfig, MacroConfig
5151
from sqlmesh.dbt.seed import SeedConfig
@@ -354,7 +354,9 @@ def _load_models_and_seeds(self) -> None:
354354
dependencies = Dependencies(
355355
macros=macro_references, refs=_refs(node), sources=_sources(node)
356356
)
357-
dependencies = dependencies.union(self._extra_dependencies(sql, node.package_name))
357+
dependencies = dependencies.union(
358+
self._extra_dependencies(sql, node.package_name, track_all_model_attrs=True)
359+
)
358360
dependencies = dependencies.union(
359361
self._flatten_dependencies_from_macros(dependencies.macros, node.package_name)
360362
)
@@ -552,17 +554,37 @@ def _flatten_dependencies_from_macros(
552554
dependencies = dependencies.union(macro_dependencies)
553555
return dependencies
554556

555-
def _extra_dependencies(self, target: str, package: str) -> Dependencies:
556-
# We sometimes observe that the manifest doesn't capture all macros, refs, and sources within a macro.
557-
# This behavior has been observed with macros like dbt.current_timestamp(), dbt_utils.slugify(), and source().
558-
# Here we apply our custom extractor to make a best effort to supplement references captured in the manifest.
557+
def _extra_dependencies(
558+
self,
559+
target: str,
560+
package: str,
561+
track_all_model_attrs: bool = False,
562+
) -> Dependencies:
563+
"""
564+
We sometimes observe that the manifest doesn't capture all macros, refs, and sources within a macro.
565+
This behavior has been observed with macros like dbt.current_timestamp(), dbt_utils.slugify(), and source().
566+
Here we apply our custom extractor to make a best effort to supplement references captured in the manifest.
567+
"""
559568
dependencies = Dependencies()
569+
570+
# Whether all `model` attributes (e.g., `model.config`) should be included in the dependencies
571+
all_model_attrs = False
572+
560573
for call_name, node in extract_call_names(target, cache=self._calls):
561574
if call_name[0] == "config":
562575
continue
563-
elif isinstance(node, jinja2.nodes.Getattr):
576+
577+
if (
578+
track_all_model_attrs
579+
and not all_model_attrs
580+
and isinstance(node, jinja2.nodes.Call)
581+
and any(isinstance(a, jinja2.nodes.Name) and a.name == "model" for a in node.args)
582+
):
583+
all_model_attrs = True
584+
585+
if isinstance(node, jinja2.nodes.Getattr):
564586
if call_name[0] == "model":
565-
dependencies.model_attrs.add(call_name[1])
587+
dependencies.model_attrs.attrs.add(call_name[1])
566588
elif call_name[0] == "source":
567589
args = [jinja_call_arg_name(arg) for arg in node.args]
568590
if args and all(arg for arg in args):
@@ -606,6 +628,14 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies:
606628
call_name[0], call_name[1], dependencies.macros.append
607629
)
608630

631+
# When `model` is referenced as-is, e.g. it's passed as an argument to a macro call like
632+
# `{{ foo(model) }}`, we can't easily track the attributes that are actually used, because
633+
# it may be aliased and hence tracking actual uses of `model` requires a proper data flow
634+
# analysis. We conservatively deal with this by including all of its supported attributes
635+
# if a standalone reference is found.
636+
if all_model_attrs:
637+
dependencies.model_attrs.all_attrs = True
638+
609639
return dependencies
610640

611641

0 commit comments

Comments
 (0)