Skip to content

Commit 6b95e3e

Browse files
authored
Fix: Register source and ref dependencies from macros (#1588)
* Fix: Register source and ref dependencies from macros * Cannot reuses call_name to get arg names * Resolve PR feedback * cleanup * restore openapi.json * stringval is a literal
1 parent e8f41bf commit 6b95e3e

4 files changed

Lines changed: 79 additions & 20 deletions

File tree

sqlmesh/dbt/manifest.py

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sqlmesh.dbt.test import TestConfig
2525
from sqlmesh.dbt.util import DBT_VERSION
2626
from sqlmesh.utils.errors import ConfigError
27-
from sqlmesh.utils.jinja import MacroInfo, MacroReference, extract_macro_references
27+
from sqlmesh.utils.jinja import MacroInfo, MacroReference, extract_call_names, nodes
2828

2929
if t.TYPE_CHECKING:
3030
from dbt.contracts.graph.manifest import Macro, Manifest
@@ -96,10 +96,10 @@ def all_macros(self) -> t.Dict[str, t.Dict[str, MacroInfo]]:
9696
def _load_all(self) -> None:
9797
if self._is_loaded:
9898
return
99+
self._load_macros()
100+
self._load_sources()
99101
self._load_tests()
100102
self._load_models_and_seeds()
101-
self._load_sources()
102-
self._load_macros()
103103
self._is_loaded = True
104104

105105
def _load_sources(self) -> None:
@@ -121,15 +121,16 @@ def _load_macros(self) -> None:
121121
if macro.name.startswith("test_"):
122122
macro.macro_sql = _convert_jinja_test_to_macro(macro.macro_sql)
123123

124-
macro_references = _macro_references(self._manifest, macro)
124+
dependencies = Dependencies(macros=_macro_references(self._manifest, macro))
125125
if not macro.name.startswith("materialization_") and not macro.name.startswith("test_"):
126-
macro_references |= _extra_macro_references(macro.macro_sql)
126+
dependencies = dependencies.union(_extra_dependencies(macro.macro_sql))
127127

128128
self._macros_per_package[macro.package_name][macro.name] = MacroConfig(
129129
info=MacroInfo(
130130
definition=macro.macro_sql,
131-
depends_on=list(macro_references),
131+
depends_on=dependencies.macros,
132132
),
133+
dependencies=dependencies,
133134
path=Path(macro.original_file_path),
134135
)
135136

@@ -167,10 +168,16 @@ def _load_tests(self) -> None:
167168
dependencies.macros.append(MacroReference(package="dbt", name="get_where_subquery"))
168169
dependencies.macros.append(MacroReference(package="dbt", name="should_store_failures"))
169170

171+
sql = node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql # type: ignore
172+
dependencies = dependencies.union(_extra_dependencies(sql))
173+
dependencies = dependencies.union(
174+
self._macro_source_ref_dependencies(dependencies.macros, package_name)
175+
)
176+
170177
test_model = _test_model(node)
171178

172179
test = TestConfig(
173-
sql=node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql, # type: ignore
180+
sql=sql,
174181
model_name=test_model,
175182
test_kwargs=node.test_metadata.kwargs if hasattr(node, "test_metadata") else {},
176183
dependencies=dependencies,
@@ -193,14 +200,17 @@ def _load_models_and_seeds(self) -> None:
193200

194201
if node.resource_type == "model":
195202
sql = node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql # type: ignore
196-
macro_references |= _extra_macro_references(sql)
203+
dependencies = Dependencies(
204+
macros=macro_references, refs=_refs(node), sources=_sources(node)
205+
)
206+
dependencies = dependencies.union(_extra_dependencies(sql))
207+
dependencies = dependencies.union(
208+
self._macro_source_ref_dependencies(dependencies.macros, node.package_name)
209+
)
210+
197211
self._models_per_package[node.package_name][node.name] = ModelConfig(
198212
sql=sql,
199-
dependencies=Dependencies(
200-
macros=macro_references,
201-
refs=_refs(node),
202-
sources=_sources(node),
203-
),
213+
dependencies=dependencies,
204214
tests=tests,
205215
**_node_base_config(node),
206216
)
@@ -260,6 +270,19 @@ def _load_profile(self) -> Profile:
260270
target_override=self.target.name,
261271
)
262272

273+
def _macro_source_ref_dependencies(
    self, macros: t.List[MacroReference], default_package: str
) -> Dependencies:
    """Collect the dependencies that the given macros bring in.

    Args:
        macros: References to the macros whose dependencies should be gathered.
        default_package: Package to look in when a reference does not name one.

    Returns:
        The union of the ``dependencies`` of every referenced macro found in
        ``self._macros_per_package``, with its ``macros`` list emptied.
    """
    dependencies = Dependencies()
    for macro in macros:
        package = macro.package or default_package
        # Look the package up with .get() so that a reference to a package that
        # has no registered macros neither raises KeyError nor (if the mapping
        # is a defaultdict) inserts an empty entry as a side effect of reading.
        macro_config = self._macros_per_package.get(package, {}).get(macro.name)
        if macro_config:
            dependencies = dependencies.union(macro_config.dependencies)
    # The macro references themselves are dropped from the result; only what
    # the macros depend on besides other macros is reported to the caller.
    dependencies.macros = []
    return dependencies
285+
263286

264287
def _config(node: t.Union[ManifestNode, SourceDefinition]) -> t.Dict[str, t.Any]:
265288
return node.config.to_dict()
@@ -330,8 +353,29 @@ def _convert_jinja_test_to_macro(test_jinja: str) -> str:
330353
return re.sub(ENDTEST_REGEX, "{% endmacro %}", macro)
331354

332355

333-
def _extra_macro_references(target: str) -> t.Set[MacroReference]:
356+
def _extra_dependencies(target: str) -> Dependencies:
    """Extract dependencies from raw Jinja that the manifest may have missed.

    Args:
        target: The Jinja source (model, test or macro SQL) to scan.

    Returns:
        A ``Dependencies`` holding ``dbt`` / ``dbt_utils`` macro references and
        the ``source(...)`` / ``ref(...)`` calls found in ``target``.
    """
    # We sometimes observe that the manifest doesn't capture certain macros referenced in the model.
    # This behavior has been observed with macros like dbt.current_timestamp() and dbt_utils.slugify().
    # Here we apply our custom extractor in addition to references extracted from the manifest to mitigate this.
    dependencies = Dependencies()
    for call_name, node in extract_call_names(target):
        if len(call_name) == 2 and call_name[0] in ("dbt", "dbt_utils"):
            dependencies.macros.append(MacroReference(package=call_name[0], name=call_name[1]))
        elif call_name[0] in ("source", "ref"):
            arg_names = [_jinja_call_arg_name(arg) for arg in node.args]
            # Only register the call when every positional argument resolved to a
            # non-empty string. Otherwise the joined name would be a bogus partial
            # such as "pkg." or ".table" that can never match a real source/ref.
            if arg_names and all(isinstance(name, str) and name for name in arg_names):
                registry = dependencies.sources if call_name[0] == "source" else dependencies.refs
                registry.append(".".join(arg_names))

    return dependencies
374+
375+
376+
def _jinja_call_arg_name(node: nodes.Node) -> str:
    """Return the textual name carried by a Jinja call argument.

    Args:
        node: An argument node taken from a Jinja ``Call`` node.

    Returns:
        The identifier for a ``Name`` node, the literal for a string ``Const``
        node, and an empty string for anything else.
    """
    if isinstance(node, nodes.Name):
        return node.name
    # A Const can hold any literal (int, bool, None, ...). Only string literals
    # are returned, so the declared ``str`` return type holds and callers can
    # safely pass the result to ``str.join``.
    if isinstance(node, nodes.Const) and isinstance(node.value, str):
        return node.value
    return ""

sqlmesh/dbt/package.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import typing as t
55
from pathlib import Path
66

7-
from sqlmesh.dbt.common import PROJECT_FILENAME, load_yaml
7+
from sqlmesh.dbt.common import PROJECT_FILENAME, Dependencies, load_yaml
88
from sqlmesh.dbt.model import ModelConfig
99
from sqlmesh.dbt.seed import SeedConfig
1010
from sqlmesh.dbt.source import SourceConfig
@@ -24,6 +24,7 @@ class MacroConfig(PydanticModel):
2424
"""Class to contain macro configuration"""
2525

2626
info: MacroInfo
27+
dependencies: Dependencies
2728
path: Path
2829

2930

sqlmesh/utils/jinja.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) ->
115115
return ENVIRONMENT.from_string(query).render(methods or {})
116116

117117

118-
def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[t.Tuple[str, ...]]:
118+
def find_call_names(
119+
node: nodes.Node, vars_in_scope: t.Set[str]
120+
) -> t.Iterator[t.Tuple[t.Tuple[str, ...], nodes.Call]]:
119121
vars_in_scope = vars_in_scope.copy()
120122
for child_node in node.iter_child_nodes():
121123
if "target" in child_node.fields:
@@ -132,17 +134,17 @@ def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[t
132134
elif isinstance(child_node, nodes.Call):
133135
name = call_name(child_node)
134136
if name[0][0] != "'" and name[0] not in vars_in_scope:
135-
yield name
137+
yield (name, child_node)
136138
yield from find_call_names(child_node, vars_in_scope)
137139

138140

139-
def extract_call_names(jinja_str: str) -> t.List[t.Tuple[str, ...]]:
141+
def extract_call_names(jinja_str: str) -> t.List[t.Tuple[t.Tuple[str, ...], nodes.Call]]:
140142
return list(find_call_names(ENVIRONMENT.parse(jinja_str), set()))
141143

142144

143145
def extract_macro_references(jinja_str: str) -> t.Set[MacroReference]:
144146
result = set()
145-
for call_name in extract_call_names(jinja_str):
147+
for call_name, _ in extract_call_names(jinja_str):
146148
if len(call_name) == 1:
147149
result.add(MacroReference(name=call_name[0]))
148150
elif len(call_name) == 2:

tests/utils/test_jinja.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
from sqlmesh.utils import AttributeDict
44
from sqlmesh.utils.jinja import (
5+
ENVIRONMENT,
56
JinjaMacroRegistry,
67
MacroExtractor,
78
MacroReference,
89
MacroReturnVal,
10+
call_name,
11+
nodes,
912
)
1013

1114

@@ -268,3 +271,12 @@ def test_macro_registry_top_level_packages():
268271
"macro_a_a",
269272
"macro_a_a",
270273
]
274+
275+
276+
def test_find_call_names():
    """Check the call names produced for local, packaged and string-literal calls."""
    jinja_str = "{{ local_macro() }}{{ package.package_macro() }}{{ 'stringval'.function() }}"
    # The comparison must be asserted; as a bare expression its result is
    # discarded and the test passes no matter what call_name returns.
    assert [call_name(node) for node in ENVIRONMENT.parse(jinja_str).find_all(nodes.Call)] == [
        ("local_macro",),
        ("package", "package_macro"),
        ("'stringval'", "function"),
    ]

0 commit comments

Comments
 (0)