Skip to content

Commit 86f2278

Browse files
authored
Feat: dbt namespace macros (#690)
* Implement dbt namespace macros * PR review fixes
1 parent 2a2be40 commit 86f2278

7 files changed

Lines changed: 67 additions & 12 deletions

File tree

sqlmesh/dbt/basemodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from sqlmesh.core.config.base import UpdateStrategy
1818
from sqlmesh.core.model import Model
1919
from sqlmesh.dbt.adapter import ParsetimeAdapter
20-
from sqlmesh.dbt.builtin import create_builtin_globals
2120
from sqlmesh.dbt.column import (
2221
ColumnConfig,
2322
column_descriptions_to_sqlmesh,
@@ -276,6 +275,8 @@ def _context_for_dependencies(
276275

277276
class ModelSqlRenderer(t.Generic[BMC]):
278277
def __init__(self, context: DbtContext, config: BMC):
278+
from sqlmesh.dbt.builtin import create_builtin_globals
279+
279280
self.context = context
280281
self.config = config
281282

sqlmesh/dbt/builtin.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import json
44
import os
5+
import sys
56
import typing as t
67
from ast import literal_eval
8+
from pathlib import Path
79

810
import agate
911
import jinja2
@@ -13,6 +15,8 @@
1315

1416
from sqlmesh.core.engine_adapter import EngineAdapter
1517
from sqlmesh.dbt.adapter import ParsetimeAdapter, RuntimeAdapter
18+
from sqlmesh.dbt.common import DbtContext
19+
from sqlmesh.dbt.package import PackageLoader
1620
from sqlmesh.utils import AttributeDict, yaml
1721
from sqlmesh.utils.errors import ConfigError, MacroEvalError
1822
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReturnVal
@@ -246,6 +250,26 @@ def _try_literal_eval(value: str) -> t.Any:
246250
return value
247251

248252

253+
def _dbt_macros_registry() -> JinjaMacroRegistry:
254+
registry = JinjaMacroRegistry()
255+
256+
try:
257+
site_packages = next(
258+
p for p in sys.path if "site-packages" in p and Path(p, "dbt").exists()
259+
)
260+
except:
261+
return registry
262+
263+
for project_file in Path(site_packages).glob("dbt/include/*/dbt_project.yml"):
264+
if project_file.parent.stem == "starter_project":
265+
continue
266+
context = DbtContext(project_root=project_file.parent, jinja_macros=JinjaMacroRegistry())
267+
package = PackageLoader(context).load()
268+
registry.add_macros(package.macros, package="dbt")
269+
270+
return registry
271+
272+
249273
BUILTIN_GLOBALS = {
250274
"api": Api(),
251275
"env_var": env_var,
@@ -341,7 +365,13 @@ def create_builtin_globals(
341365
}
342366
)
343367

344-
return {**builtin_globals, **jinja_globals}
368+
builtin_globals.update(jinja_globals)
369+
if "dbt" not in builtin_globals:
370+
builtin_globals["dbt"] = (
371+
_dbt_macros_registry().build_environment(**builtin_globals).globals.get("dbt", {})
372+
)
373+
374+
return builtin_globals
345375

346376

347377
def create_builtin_filters() -> t.Dict[str, t.Callable]:

sqlmesh/dbt/common.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class DbtContext:
6868
_sources: t.Dict[str, SourceConfig] = field(default_factory=dict)
6969
_refs: t.Dict[str, str] = field(default_factory=dict)
7070

71-
_target: t.Optional[TargetConfig] = None
71+
_target: TargetConfig = field(default_factory=TargetConfig)
7272

7373
_jinja_environment: t.Optional[Environment] = None
7474

@@ -140,8 +140,6 @@ def refs(self) -> t.Dict[str, str]:
140140

141141
@property
142142
def target(self) -> TargetConfig:
143-
if not self._target:
144-
raise ConfigError(f"Target not set for {self.project_name}")
145143
return self._target
146144

147145
@target.setter

sqlmesh/dbt/package.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ class Package(PydanticModel):
4848
class PackageLoader:
4949
"""Loader for DBT packages"""
5050

51-
def __init__(self, context: DbtContext, overrides: ProjectConfig):
51+
def __init__(self, context: DbtContext, overrides: t.Optional[ProjectConfig] = None):
5252
self._context = context.copy()
53-
self._overrides = overrides
53+
self._overrides = overrides or ProjectConfig()
5454
self._config_paths: t.Set[Path] = set()
5555
self.project_config = ProjectConfig()
5656

sqlmesh/dbt/target.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ class TargetConfig(abc.ABC, PydanticModel):
4040
"""
4141

4242
# dbt
43-
type: str
43+
type: str = ""
4444
name: str = ""
45-
schema_: str = Field(alias="schema")
45+
schema_: str = Field(default="", alias="schema")
4646
threads: int = 1
4747
profile_name: t.Optional[str] = None
4848

sqlmesh/utils/jinja.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,14 +211,15 @@ def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[
211211
Returns:
212212
The macro as a Python callable or None if not found.
213213
"""
214+
global_vars = self._create_builtin_globals(kwargs)
214215
if reference.package is not None and reference.name not in self.packages.get(
215216
reference.package, {}
216217
):
217-
return None
218+
global_package = global_vars.get(reference.package, {})
219+
return global_package.get(reference.name)
218220
if reference.package is None and reference.name not in self.root_macros:
219-
return None
221+
return global_vars.get(reference.name)
220222

221-
global_vars = self._create_builtin_globals(kwargs)
222223
return self._make_callable(reference.name, reference.package, {}, global_vars)
223224

224225
def build_environment(self, **kwargs: t.Any) -> Environment:

tests/dbt/test_transformation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from dbt.exceptions import CompilationError
66
from sqlglot import exp, parse_one
77

8+
from sqlmesh.core.config.connection import DuckDBConnectionConfig
89
from sqlmesh.core.context import Context, ExecutionContext
910
from sqlmesh.core.model import (
1011
IncrementalByTimeRangeKind,
1112
IncrementalByUniqueKeyKind,
1213
ModelKind,
1314
ModelKindName,
1415
)
16+
from sqlmesh.dbt.builtin import create_builtin_globals
1517
from sqlmesh.dbt.column import (
1618
ColumnConfig,
1719
column_descriptions_to_sqlmesh,
@@ -426,3 +428,26 @@ def test_zip(sushi_test_project: Project):
426428
assert context.render("{{ zip_strict([1, 2], ['a', 'b']) }}") == "[(1, 'a'), (2, 'b')]"
427429
with pytest.raises(TypeError):
428430
context.render("{{ zip_strict(12, ['a', 'b']) }}")
431+
432+
433+
def test_dbt_namespace():
434+
context = DbtContext()
435+
jinja_globals = create_builtin_globals(
436+
jinja_macros=context.jinja_macros,
437+
jinja_globals={},
438+
engine_adapter=DuckDBConnectionConfig().create_engine_adapter(),
439+
)
440+
jinja_env = context.jinja_macros.build_environment(**jinja_globals)
441+
442+
assert (
443+
jinja_env.from_string("{{ dbt.replace('original sentence', 'original', 'updated') }}")
444+
.render()
445+
.strip()
446+
== """replace(
447+
original sentence,
448+
original,
449+
updated
450+
)"""
451+
)
452+
453+
assert jinja_env.from_string("{{ dbt.hash('col')}}").render() == "md5(cast(col as TEXT))"

0 commit comments

Comments
 (0)