Skip to content

Commit 4209672

Browse files
authored
Fix!: Improve tracking of var dependencies in dbt models (#5204)
1 parent d3eae73 commit 4209672

File tree

12 files changed

+136
-59
lines changed

12 files changed

+136
-59
lines changed

sqlmesh/dbt/basemodel.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -313,33 +313,21 @@ def sqlmesh_model_kwargs(
313313
"""Get common sqlmesh model parameters"""
314314
self.remove_tests_with_invalid_refs(context)
315315
self.check_for_circular_test_refs(context)
316+
317+
dependencies = self.dependencies.copy()
318+
if dependencies.has_dynamic_var_names:
319+
# Include ALL variables as dependencies since we couldn't determine
320+
# precisely which variables are referenced in the model
321+
dependencies.variables |= set(context.variables)
322+
316323
model_dialect = self.dialect(context)
317324
model_context = context.context_for_dependencies(
318-
self.dependencies.union(self.tests_ref_source_dependencies)
325+
dependencies.union(self.tests_ref_source_dependencies)
319326
)
320327
jinja_macros = model_context.jinja_macros.trim(
321-
self.dependencies.macros, package=self.package_name
322-
)
323-
324-
model_node: AttributeDict[str, t.Any] = AttributeDict(
325-
{
326-
k: v
327-
for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items()
328-
if k in self.dependencies.model_attrs
329-
}
330-
if context._manifest and self.node_name in context._manifest._manifest.nodes
331-
else {}
332-
)
333-
334-
jinja_macros.add_globals(
335-
{
336-
"this": self.relation_info,
337-
"model": model_node,
338-
"schema": self.table_schema,
339-
"config": self.config_attribute_dict,
340-
**model_context.jinja_globals, # type: ignore
341-
}
328+
dependencies.macros, package=self.package_name
342329
)
330+
jinja_macros.add_globals(self._model_jinja_context(model_context, dependencies))
343331
return {
344332
"audits": [(test.name, {}) for test in self.tests],
345333
"columns": column_types_to_sqlmesh(
@@ -369,3 +357,23 @@ def to_sqlmesh(
369357
virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default,
370358
) -> Model:
371359
"""Convert DBT model into sqlmesh Model"""
360+
361+
def _model_jinja_context(
362+
self, context: DbtContext, dependencies: Dependencies
363+
) -> t.Dict[str, t.Any]:
364+
model_node: AttributeDict[str, t.Any] = AttributeDict(
365+
{
366+
k: v
367+
for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items()
368+
if k in dependencies.model_attrs
369+
}
370+
if context._manifest and self.node_name in context._manifest._manifest.nodes
371+
else {}
372+
)
373+
return {
374+
"this": self.relation_info,
375+
"model": model_node,
376+
"schema": self.table_schema,
377+
"config": self.config_attribute_dict,
378+
**context.jinja_globals,
379+
}

sqlmesh/dbt/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,16 @@ class Dependencies(PydanticModel):
184184
variables: t.Set[str] = set()
185185
model_attrs: t.Set[str] = set()
186186

187+
has_dynamic_var_names: bool = False
188+
187189
def union(self, other: Dependencies) -> Dependencies:
188190
return Dependencies(
189191
macros=list(set(self.macros) | set(other.macros)),
190192
sources=self.sources | other.sources,
191193
refs=self.refs | other.refs,
192194
variables=self.variables | other.variables,
193195
model_attrs=self.model_attrs | other.model_attrs,
196+
has_dynamic_var_names=self.has_dynamic_var_names or other.has_dynamic_var_names,
194197
)
195198

196199
@field_validator("macros", mode="after")

sqlmesh/dbt/context.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from sqlmesh.core.config import Config as SQLMeshConfig
1010
from sqlmesh.dbt.builtin import _relation_info_to_relation
11+
from sqlmesh.dbt.common import Dependencies
1112
from sqlmesh.dbt.manifest import ManifestHelper
1213
from sqlmesh.dbt.target import TargetConfig
1314
from sqlmesh.utils import AttributeDict
@@ -22,7 +23,6 @@
2223
if t.TYPE_CHECKING:
2324
from jinja2 import Environment
2425

25-
from sqlmesh.dbt.basemodel import Dependencies
2626
from sqlmesh.dbt.model import ModelConfig
2727
from sqlmesh.dbt.relation import Policy
2828
from sqlmesh.dbt.seed import SeedConfig
@@ -101,8 +101,6 @@ def add_variables(self, variables: t.Dict[str, t.Any]) -> None:
101101
self._jinja_environment = None
102102

103103
def set_and_render_variables(self, variables: t.Dict[str, t.Any], package: str) -> None:
104-
self.variables = variables
105-
106104
jinja_environment = self.jinja_macros.build_environment(**self.jinja_globals)
107105

108106
def _render_var(value: t.Any) -> t.Any:

sqlmesh/dbt/loader.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,6 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
124124
)
125125

126126
for project in self._load_projects():
127-
context = project.context.copy()
128-
129127
macros_max_mtime = self._macros_max_mtime
130128
yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder(
131129
project.context.project_root
@@ -135,12 +133,13 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
135133
logger.debug("Converting models to sqlmesh")
136134
# Now that config is rendered, create the sqlmesh models
137135
for package in project.packages.values():
138-
context.set_and_render_variables(package.variables, package.name)
136+
package_context = project.context.copy()
137+
package_context.set_and_render_variables(package.variables, package.name)
139138
package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds}
140139

141140
for model in package_models.values():
142141
sqlmesh_model = cache.get_or_load_models(
143-
model.path, loader=lambda: [_to_sqlmesh(model, context)]
142+
model.path, loader=lambda: [_to_sqlmesh(model, package_context)]
144143
)[0]
145144

146145
models[sqlmesh_model.fqn] = sqlmesh_model
@@ -155,15 +154,14 @@ def _load_audits(
155154
audits: UniqueKeyDict = UniqueKeyDict("audits")
156155

157156
for project in self._load_projects():
158-
context = project.context
159-
160157
logger.debug("Converting audits to sqlmesh")
161158
for package in project.packages.values():
162-
context.set_and_render_variables(package.variables, package.name)
159+
package_context = project.context.copy()
160+
package_context.set_and_render_variables(package.variables, package.name)
163161
for test in package.tests.values():
164162
logger.debug("Converting '%s' to sqlmesh format", test.name)
165163
try:
166-
audits[test.name] = test.to_sqlmesh(context)
164+
audits[test.name] = test.to_sqlmesh(package_context)
167165
except MissingModelError as e:
168166
logger.warning(
169167
"Skipping audit '%s' because model '%s' is not a valid ref",
@@ -244,9 +242,9 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
244242
project_names: t.Set[str] = set()
245243
dialect = self.config.dialect
246244
for project in self._load_projects():
247-
context = project.context
248245
for package_name, package in project.packages.items():
249-
context.set_and_render_variables(package.variables, package_name)
246+
package_context = project.context.copy()
247+
package_context.set_and_render_variables(package.variables, package_name)
250248
on_run_start: t.List[str] = [
251249
on_run_hook.sql
252250
for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index)
@@ -261,7 +259,7 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
261259
for hook in [*package.on_run_start.values(), *package.on_run_end.values()]:
262260
dependencies = dependencies.union(hook.dependencies)
263261

264-
statements_context = context.context_for_dependencies(dependencies)
262+
statements_context = package_context.context_for_dependencies(dependencies)
265263
jinja_registry = make_jinja_registry(
266264
statements_context.jinja_macros, package_name, set(dependencies.macros)
267265
)

sqlmesh/dbt/manifest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,9 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies:
554554
args = [jinja_call_arg_name(arg) for arg in node.args]
555555
if args and args[0]:
556556
dependencies.variables.add(args[0])
557+
else:
558+
# We couldn't determine the var name statically
559+
dependencies.has_dynamic_var_names = True
557560
dependencies.macros.append(MacroReference(name="var"))
558561
elif len(call_name) == 1:
559562
macro_name = call_name[0]

sqlmesh/dbt/project.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N
5555
raise ConfigError(f"Could not find {PROJECT_FILENAME} in {context.project_root}")
5656
project_yaml = load_yaml(project_file_path)
5757

58-
variable_overrides = variables
59-
variables = {**project_yaml.get("vars", {}), **(variables or {})}
60-
6158
project_name = context.render(project_yaml.get("name", ""))
6259
context.project_name = project_name
6360
if not context.project_name:
@@ -69,6 +66,7 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N
6966
profile = Profile.load(context, context.target_name)
7067
context.target = profile.target
7168

69+
variable_overrides = variables or {}
7270
context.manifest = ManifestHelper(
7371
project_file_path.parent,
7472
profile.path.parent,
@@ -101,13 +99,17 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N
10199
package = package_loader.load(path.parent)
102100
packages[package.name] = package
103101

102+
all_project_variables = {**project_yaml.get("vars", {}), **(variable_overrides or {})}
104103
for name, package in packages.items():
105-
package_vars = variables.get(name)
104+
package_vars = all_project_variables.get(name)
106105

107106
if isinstance(package_vars, dict):
108107
package.variables.update(package_vars)
109108

110-
package.variables.update(variables)
109+
if name == context.project_name:
110+
package.variables.update(all_project_variables)
111+
else:
112+
package.variables.update(variable_overrides)
111113

112114
return Project(context, profile, packages)
113115

tests/dbt/test_config.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ def test_variables(assert_exp_eq, sushi_test_project):
362362
"nested_vars": {
363363
"some_nested_var": 2,
364364
},
365+
"dynamic_test_var": 3,
365366
"list_var": [
366367
{"name": "item1", "value": 1},
367368
{"name": "item2", "value": 2},
@@ -375,25 +376,10 @@ def test_variables(assert_exp_eq, sushi_test_project):
375376
expected_customer_variables = {
376377
"some_var": ["foo", "bar"],
377378
"some_other_var": 5,
378-
"yet_another_var": 1,
379+
"yet_another_var": 5,
379380
"customers:bla": False,
380381
"customers:customer_id": "customer_id",
381382
"start": "Jan 1 2022",
382-
"top_waiters:limit": 10,
383-
"top_waiters:revenue": "revenue",
384-
"customers:boo": ["a", "b"],
385-
"nested_vars": {
386-
"some_nested_var": 2,
387-
},
388-
"list_var": [
389-
{"name": "item1", "value": 1},
390-
{"name": "item2", "value": 2},
391-
],
392-
"customers": {
393-
"customers:bla": False,
394-
"customers:customer_id": "customer_id",
395-
"some_var": ["foo", "bar"],
396-
},
397383
}
398384

399385
assert sushi_test_project.packages["sushi"].variables == expected_sushi_variables
@@ -406,7 +392,9 @@ def test_nested_variables(sushi_test_project):
406392
sql="SELECT {{ var('nested_vars')['some_nested_var'] }}",
407393
dependencies=Dependencies(variables=["nested_vars"]),
408394
)
409-
sqlmesh_model = model_config.to_sqlmesh(sushi_test_project.context)
395+
context = sushi_test_project.context.copy()
396+
context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi")
397+
sqlmesh_model = model_config.to_sqlmesh(context)
410398
assert sqlmesh_model.jinja_macros.global_objs["vars"]["nested_vars"] == {"some_nested_var": 2}
411399

412400

tests/dbt/test_manifest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def test_manifest_helper(caplog):
7979
waiter_revenue_by_day_config = models["waiter_revenue_by_day_v2"]
8080
assert waiter_revenue_by_day_config.dependencies == Dependencies(
8181
macros={
82+
MacroReference(name="dynamic_var_name_dependency"),
8283
MacroReference(name="log_value"),
8384
MacroReference(name="test_dependencies"),
8485
MacroReference(package="customers", name="duckdb__current_engine"),
@@ -87,6 +88,7 @@ def test_manifest_helper(caplog):
8788
},
8889
sources={"streaming.items", "streaming.orders", "streaming.order_items"},
8990
variables={"yet_another_var", "nested_vars"},
91+
has_dynamic_var_names=True,
9092
)
9193
assert waiter_revenue_by_day_config.materialized == "incremental"
9294
assert waiter_revenue_by_day_config.incremental_strategy == "delete+insert"

tests/dbt/test_transformation.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838
from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json
3939
from sqlmesh.dbt.builtin import _relation_info_to_relation
40+
from sqlmesh.dbt.common import Dependencies
4041
from sqlmesh.dbt.column import (
4142
ColumnConfig,
4243
column_descriptions_to_sqlmesh,
@@ -50,6 +51,7 @@
5051
from sqlmesh.dbt.target import BigQueryConfig, DuckDbConfig, SnowflakeConfig, ClickhouseConfig
5152
from sqlmesh.dbt.test import TestConfig
5253
from sqlmesh.utils.errors import ConfigError, MacroEvalError, SQLMeshError
54+
from sqlmesh.utils.jinja import MacroReference
5355

5456
pytestmark = [pytest.mark.dbt, pytest.mark.slow]
5557

@@ -1530,6 +1532,9 @@ def test_dbt_package_macros(sushi_test_project: Project):
15301532
@pytest.mark.xdist_group("dbt_manifest")
15311533
def test_dbt_vars(sushi_test_project: Project):
15321534
context = sushi_test_project.context
1535+
context.set_and_render_variables(
1536+
sushi_test_project.packages["customers"].variables, "customers"
1537+
)
15331538

15341539
assert context.render("{{ var('some_other_var') }}") == "5"
15351540
assert context.render("{{ var('some_other_var', 0) }}") == "5"
@@ -1854,3 +1859,65 @@ def test_on_run_start_end():
18541859
"CREATE OR REPLACE TABLE schema_table_sushi__dev_nested_package AS SELECT 'sushi__dev' AS schema",
18551860
]
18561861
)
1862+
1863+
1864+
@pytest.mark.xdist_group("dbt_manifest")
1865+
def test_dynamic_var_names(sushi_test_project: Project, sushi_test_dbt_context: Context):
1866+
context = sushi_test_project.context
1867+
context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi")
1868+
context.target = BigQueryConfig(name="production", database="main", schema="sushi")
1869+
model_config = ModelConfig(
1870+
name="model",
1871+
alias="model",
1872+
schema="test",
1873+
package_name="package",
1874+
materialized="table",
1875+
unique_key="ds",
1876+
partition_by={"field": "ds", "granularity": "month"},
1877+
sql="""
1878+
{% set var_name = "yet_" + "another_" + "var" %}
1879+
{% set results = run_query('select 1 as one') %}
1880+
{% if results %}
1881+
SELECT {{ results.columns[0].values()[0] }} AS one {{ var(var_name) }} AS var FROM {{ this.identifier }}
1882+
{% else %}
1883+
SELECT NULL AS one {{ var(var_name) }} AS var FROM {{ this.identifier }}
1884+
{% endif %}
1885+
""",
1886+
dependencies=Dependencies(has_dynamic_var_names=True),
1887+
)
1888+
converted_model = model_config.to_sqlmesh(context)
1889+
assert "yet_another_var" in converted_model.jinja_macros.global_objs["vars"] # type: ignore
1890+
1891+
# Test the existing model in the sushi project
1892+
assert (
1893+
"dynamic_test_var" # type: ignore
1894+
in sushi_test_dbt_context.get_model(
1895+
"sushi.waiter_revenue_by_day_v2"
1896+
).jinja_macros.global_objs["vars"]
1897+
)
1898+
1899+
1900+
@pytest.mark.xdist_group("dbt_manifest")
1901+
def test_dynamic_var_names_in_macro(sushi_test_project: Project):
1902+
context = sushi_test_project.context
1903+
context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi")
1904+
context.target = BigQueryConfig(name="production", database="main", schema="sushi")
1905+
model_config = ModelConfig(
1906+
name="model",
1907+
alias="model",
1908+
schema="test",
1909+
package_name="package",
1910+
materialized="table",
1911+
unique_key="ds",
1912+
partition_by={"field": "ds", "granularity": "month"},
1913+
sql="""
1914+
{% set var_name = "dynamic_" + "test_" + "var" %}
1915+
SELECT {{ sushi.dynamic_var_name_dependency(var_name) }} AS var
1916+
""",
1917+
dependencies=Dependencies(
1918+
macros=[MacroReference(package="sushi", name="dynamic_var_name_dependency")],
1919+
has_dynamic_var_names=True,
1920+
),
1921+
)
1922+
converted_model = model_config.to_sqlmesh(context)
1923+
assert "dynamic_test_var" in converted_model.jinja_macros.global_objs["vars"] # type: ignore

tests/fixtures/dbt/sushi_test/dbt_project.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ vars:
4747
customers:boo: ["a", "b"]
4848

4949
yet_another_var: 1
50+
dynamic_test_var: 3
5051

5152
customers:
5253
some_var: ["foo", "bar"]

0 commit comments

Comments
 (0)