|
2 | 2 |
|
3 | 3 | import json |
4 | 4 | import os |
| 5 | +import sys |
5 | 6 | import typing as t |
6 | 7 | from ast import literal_eval |
| 8 | +from pathlib import Path |
7 | 9 |
|
8 | 10 | import agate |
9 | 11 | import jinja2 |
|
13 | 15 |
|
14 | 16 | from sqlmesh.core.engine_adapter import EngineAdapter |
15 | 17 | from sqlmesh.dbt.adapter import ParsetimeAdapter, RuntimeAdapter |
| 18 | +from sqlmesh.dbt.common import DbtContext |
| 19 | +from sqlmesh.dbt.package import PackageLoader |
16 | 20 | from sqlmesh.utils import AttributeDict, yaml |
17 | 21 | from sqlmesh.utils.errors import ConfigError, MacroEvalError |
18 | 22 | from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReturnVal |
@@ -246,6 +250,26 @@ def _try_literal_eval(value: str) -> t.Any: |
246 | 250 | return value |
247 | 251 |
|
248 | 252 |
|
| 253 | +def _dbt_macros_registry() -> JinjaMacroRegistry: |
| 254 | + registry = JinjaMacroRegistry() |
| 255 | + |
| 256 | + try: |
| 257 | + site_packages = next( |
| 258 | + p for p in sys.path if "site-packages" in p and Path(p, "dbt").exists() |
| 259 | + ) |
| 260 | + except: |
| 261 | + return registry |
| 262 | + |
| 263 | + for project_file in Path(site_packages).glob("dbt/include/*/dbt_project.yml"): |
| 264 | + if project_file.parent.stem == "starter_project": |
| 265 | + continue |
| 266 | + context = DbtContext(project_root=project_file.parent, jinja_macros=JinjaMacroRegistry()) |
| 267 | + package = PackageLoader(context).load() |
| 268 | + registry.add_macros(package.macros, package="dbt") |
| 269 | + |
| 270 | + return registry |
| 271 | + |
| 272 | + |
249 | 273 | BUILTIN_GLOBALS = { |
250 | 274 | "api": Api(), |
251 | 275 | "env_var": env_var, |
@@ -341,7 +365,13 @@ def create_builtin_globals( |
341 | 365 | } |
342 | 366 | ) |
343 | 367 |
|
344 | | - return {**builtin_globals, **jinja_globals} |
| 368 | + builtin_globals.update(jinja_globals) |
| 369 | + if "dbt" not in builtin_globals: |
| 370 | + builtin_globals["dbt"] = ( |
| 371 | + _dbt_macros_registry().build_environment(**builtin_globals).globals.get("dbt", {}) |
| 372 | + ) |
| 373 | + |
| 374 | + return builtin_globals |
345 | 375 |
|
346 | 376 |
|
347 | 377 | def create_builtin_filters() -> t.Dict[str, t.Callable]: |
|
0 commit comments