Skip to content

Commit 4786f58

Browse files
authored
Fix: Take the current package into account when trimming jinja macros (#1010)
1 parent 3d13952 commit 4786f58

6 files changed

Lines changed: 28 additions & 13 deletions

File tree

sqlmesh/dbt/basemodel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ def model_function(self) -> AttributeDict[str, t.Any]:
210210
def sqlmesh_model_kwargs(self, context: DbtContext) -> t.Dict[str, t.Any]:
211211
"""Get common sqlmesh model parameters"""
212212
model_context = context.context_for_dependencies(self.dependencies)
213-
jinja_macros = model_context.jinja_macros.trim(self.dependencies.macros)
213+
jinja_macros = model_context.jinja_macros.trim(
214+
self.dependencies.macros, package=self.package_name
215+
)
214216
jinja_macros.global_objs.update(
215217
{
216218
"this": self.relation_info,

sqlmesh/dbt/loader.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,7 @@ def _load_project(self) -> Project:
132132
)
133133

134134
for package_name, macro_infos in context.manifest.all_macros.items():
135-
context.jinja_macros.add_macros(
136-
macro_infos,
137-
package=package_name if package_name != context.project_name else None,
138-
)
135+
context.jinja_macros.add_macros(macro_infos, package=package_name)
139136

140137
self._macros_max_mtime = max(macros_mtimes) if macros_mtimes else None
141138

sqlmesh/dbt/test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class TestConfig(GeneralConfig):
3535
column_name: The name of the column under test.
3636
dependencies: The macros, refs, and sources the test depends upon.
3737
dialect: The sql dialect of the test.
38+
package_name: Name of the package that defines the test.
3839
alias: The alias for the materialized table where failures are stored (Not supported).
3940
schema: The schema for the materialized table where the failures are stored (Not supported).
4041
database: The database for the materilized table where the failures are stored (Not supported).
@@ -58,6 +59,7 @@ class TestConfig(GeneralConfig):
5859
dialect: str = ""
5960

6061
# dbt fields
62+
package_name: str = ""
6163
alias: t.Optional[str] = None
6264
schema_: t.Optional[str] = Field("", alias="schema")
6365
database: t.Optional[str] = None
@@ -89,7 +91,9 @@ def to_sqlmesh(self, context: DbtContext) -> Audit:
8991
"""
9092
test_context = context.context_for_dependencies(self.dependencies)
9193

92-
jinja_macros = test_context.jinja_macros.trim(self.dependencies.macros)
94+
jinja_macros = test_context.jinja_macros.trim(
95+
self.dependencies.macros, package=self.package_name
96+
)
9397
jinja_macros.global_objs.update(
9498
{
9599
"config": self.attribute_dict,

sqlmesh/utils/jinja.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,24 +251,31 @@ def build_environment(self, **kwargs: t.Any) -> Environment:
251251
env.filters.update(self._environment.filters)
252252
return env
253253

254-
def trim(self, dependencies: t.Iterable[MacroReference]) -> JinjaMacroRegistry:
254+
def trim(
255+
self, dependencies: t.Iterable[MacroReference], package: t.Optional[str] = None
256+
) -> JinjaMacroRegistry:
255257
"""Trims the registry by keeping only macros with given references and their transitive dependencies.
256258
257259
Args:
258260
dependencies: References to macros that should be kept.
261+
package: The name of the package in the context of which the trimming should be performed.
259262
260263
Returns:
261264
A new trimmed registry.
262265
"""
263266
dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
264267
for dep in dependencies:
265-
dependencies_by_package[dep.package].add(dep.name)
268+
dependencies_by_package[dep.package or package].add(dep.name)
269+
270+
top_level_packages = self.top_level_packages.copy()
271+
if package is not None:
272+
top_level_packages.append(package)
266273

267274
result = JinjaMacroRegistry(
268275
global_objs=self.global_objs.copy(),
269276
create_builtins_module=self.create_builtins_module,
270277
root_package_name=self.root_package_name,
271-
top_level_packages=self.top_level_packages.copy(),
278+
top_level_packages=top_level_packages,
272279
)
273280
for package, names in dependencies_by_package.items():
274281
result = result.merge(self._trim_macros(names, package))

tests/dbt/conftest.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@ def sushi_test_project() -> Project:
1616
delete_cache(project_root)
1717
project = Project.load(DbtContext(project_root=Path(project_root)))
1818
for package_name, package in project.packages.items():
19-
project.context.jinja_macros.add_macros(
20-
package.macro_infos,
21-
package=package_name if package_name != project.context.project_name else None,
22-
)
19+
project.context.jinja_macros.add_macros(package.macro_infos, package=package_name)
2320
return project
2421

2522

tests/utils/test_jinja.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,14 @@ def test_macro_registry_trim():
159159
)
160160
assert rendered == "macro_a_a macro_a_b"
161161

162+
trimmed_registry_for_package_b = registry.trim(
163+
[MacroReference(name="macro_b_b")], package="package_b"
164+
)
165+
assert set(trimmed_registry_for_package_b.packages) == {"package_a", "package_b"}
166+
assert set(trimmed_registry_for_package_b.packages["package_a"]) == {"macro_a_a"}
167+
assert set(trimmed_registry_for_package_b.packages["package_b"]) == {"macro_b_a", "macro_b_b"}
168+
assert not trimmed_registry_for_package_b.root_macros
169+
162170

163171
def test_macro_return():
164172
macros = "{% macro test_return() %}{{ macro_return([1, 2, 3]) }}{% endmacro %}"

0 commit comments

Comments
 (0)