Skip to content

Commit a750efb

Browse files
authored
Fix!: Use target dbt adapter relation and column classes. (#1056)
1 parent fb4dff6 commit a750efb

File tree

4 files changed

+97
-18
lines changed

4 files changed

+97
-18
lines changed

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ ignore_missing_imports = True
7575
[mypy-pytest_lazyfixture.*]
7676
ignore_missing_imports = True
7777

78+
[mypy-dbt.adapters.*]
79+
ignore_missing_imports = True
80+
7881
[autoflake]
7982
in-place = True
8083
expand-star-imports = True

sqlmesh/dbt/builtin.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
import agate
99
import jinja2
1010
from dbt import version
11-
from dbt.adapters.base import BaseRelation
11+
from dbt.adapters.base import BaseRelation, Column
1212
from dbt.contracts.relation import Policy
1313
from ruamel.yaml import YAMLError
1414

1515
from sqlmesh.core.engine_adapter import EngineAdapter
1616
from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter
17+
from sqlmesh.dbt.target import TargetConfig
1718
from sqlmesh.dbt.util import DBT_VERSION
1819
from sqlmesh.utils import AttributeDict, yaml
1920
from sqlmesh.utils.errors import ConfigError, MacroEvalError
@@ -37,12 +38,14 @@ def warn(self, msg: str) -> str:
3738

3839

3940
class Api:
40-
def __init__(self) -> None:
41-
from dbt.adapters.base.column import Column
42-
from dbt.adapters.base.relation import BaseRelation
43-
44-
self.Relation = BaseRelation
45-
self.Column = Column
41+
def __init__(self, target: t.Optional[AttributeDict] = None) -> None:
42+
if target:
43+
config = TargetConfig.load(target)
44+
self.Relation = config.relation_class
45+
self.Column = config.column_class
46+
else:
47+
self.Relation = BaseRelation
48+
self.Column = Column
4649

4750

4851
class Flags:
@@ -156,25 +159,25 @@ def var(name: str, default: t.Optional[t.Any] = None) -> t.Any:
156159
return var
157160

158161

159-
def generate_ref(refs: t.Dict[str, t.Any]) -> t.Callable:
162+
def generate_ref(refs: t.Dict[str, t.Any], api: Api) -> t.Callable:
160163
def ref(package: str, name: t.Optional[str] = None) -> t.Optional[BaseRelation]:
161164
ref_name = f"{package}.{name}" if name else package
162165
relation_info = refs.get(ref_name)
163166
if relation_info is None:
164167
return None
165168

166-
return BaseRelation.create(**relation_info)
169+
return api.Relation.create(**relation_info)
167170

168171
return ref
169172

170173

171-
def generate_source(sources: t.Dict[str, t.Any]) -> t.Callable:
174+
def generate_source(sources: t.Dict[str, t.Any], api: Api) -> t.Callable:
172175
def source(package: str, name: str) -> t.Optional[BaseRelation]:
173176
relation_info = sources.get(f"{package}.{name}")
174177
if relation_info is None:
175178
return None
176179

177-
return BaseRelation.create(**relation_info)
180+
return api.Relation.create(**relation_info)
178181

179182
return source
180183

@@ -251,7 +254,6 @@ def _try_literal_eval(value: str) -> t.Any:
251254

252255

253256
BUILTIN_GLOBALS = {
254-
"api": Api(),
255257
"dbt_version": version.__version__,
256258
"env_var": env_var,
257259
"exceptions": Exceptions(),
@@ -288,20 +290,25 @@ def create_builtin_globals(
288290
builtin_globals = BUILTIN_GLOBALS.copy()
289291
jinja_globals = jinja_globals.copy()
290292

293+
target: t.Optional[AttributeDict] = jinja_globals.get("target", None)
294+
api = Api(target)
295+
296+
builtin_globals["api"] = api
297+
291298
this = jinja_globals.pop("this", None)
292299
if this is not None:
293-
if not isinstance(this, BaseRelation):
294-
builtin_globals["this"] = BaseRelation.create(**this)
300+
if not isinstance(this, api.Relation):
301+
builtin_globals["this"] = api.Relation.create(**this)
295302
else:
296303
builtin_globals["this"] = this
297304

298305
sources = jinja_globals.pop("sources", None)
299306
if sources is not None:
300-
builtin_globals["source"] = generate_source(sources)
307+
builtin_globals["source"] = generate_source(sources, api)
301308

302309
refs = jinja_globals.pop("refs", None)
303310
if refs is not None:
304-
builtin_globals["ref"] = generate_ref(refs)
311+
builtin_globals["ref"] = generate_ref(refs, api)
305312

306313
variables = jinja_globals.pop("vars", None)
307314
if variables is not None:
@@ -319,11 +326,10 @@ def create_builtin_globals(
319326
)
320327
builtin_globals.update({"log": log, "print": log})
321328
else:
322-
target = jinja_globals.get("target")
323329
adapter = ParsetimeAdapter(
324330
jinja_macros,
325331
jinja_globals={**builtin_globals, **jinja_globals},
326-
dialect=target.type if target else None,
332+
dialect=target.type if target else None, # type: ignore
327333
)
328334
builtin_globals.update({"log": no_log, "print": no_log})
329335

sqlmesh/dbt/target.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys
55
import typing as t
66

7+
from dbt.adapters.base import BaseRelation, Column
78
from pydantic import Field, root_validator, validator
89

910
from sqlmesh.core.config.connection import (
@@ -98,6 +99,14 @@ def quoting(self) -> QuotingConfig:
9899
def extra(self) -> t.Set[str]:
99100
return self.extra_fields(set(self.dict()))
100101

102+
@property
103+
def relation_class(self) -> t.Type[BaseRelation]:
104+
return BaseRelation
105+
106+
@property
107+
def column_class(self) -> t.Type[Column]:
108+
return Column
109+
101110

102111
class DuckDbConfig(TargetConfig):
103112
"""
@@ -115,6 +124,12 @@ class DuckDbConfig(TargetConfig):
115124
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
116125
return "delete+insert"
117126

127+
@property
128+
def relation_class(self) -> t.Type[BaseRelation]:
129+
from dbt.adapters.duckdb.relation import DuckDBRelation
130+
131+
return DuckDBRelation
132+
118133
def to_sqlmesh(self) -> ConnectionConfig:
119134
return DuckDBConnectionConfig(database=self.path, concurrent_tasks=self.threads)
120135

@@ -154,6 +169,18 @@ class SnowflakeConfig(TargetConfig):
154169
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
155170
return "merge"
156171

172+
@property
173+
def relation_class(self) -> t.Type[BaseRelation]:
174+
from dbt.adapters.snowflake import SnowflakeRelation
175+
176+
return SnowflakeRelation
177+
178+
@property
179+
def column_class(self) -> t.Type[Column]:
180+
from dbt.adapters.snowflake import SnowflakeColumn
181+
182+
return SnowflakeColumn
183+
157184
def to_sqlmesh(self) -> ConnectionConfig:
158185
return SnowflakeConnectionConfig(
159186
user=self.user,
@@ -274,6 +301,18 @@ def validate_database(
274301
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
275302
return "append"
276303

304+
@property
305+
def relation_class(self) -> t.Type[BaseRelation]:
306+
from dbt.adapters.redshift import RedshiftRelation
307+
308+
return RedshiftRelation
309+
310+
@property
311+
def column_class(self) -> t.Type[Column]:
312+
from dbt.adapters.redshift import RedshiftColumn
313+
314+
return RedshiftColumn
315+
277316
def to_sqlmesh(self) -> ConnectionConfig:
278317
return RedshiftConnectionConfig(
279318
user=self.user,
@@ -309,6 +348,18 @@ class DatabricksConfig(TargetConfig):
309348
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
310349
return "merge"
311350

351+
@property
352+
def relation_class(self) -> t.Type[BaseRelation]:
353+
from dbt.adapters.databricks.relation import DatabricksRelation
354+
355+
return DatabricksRelation
356+
357+
@property
358+
def column_class(self) -> t.Type[Column]:
359+
from dbt.adapters.databricks.column import DatabricksColumn
360+
361+
return DatabricksColumn
362+
312363
def to_sqlmesh(self) -> ConnectionConfig:
313364
return DatabricksConnectionConfig(
314365
server_hostname=self.host,
@@ -376,6 +427,18 @@ def validate_fields(
376427
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
377428
return "merge"
378429

430+
@property
431+
def relation_class(self) -> t.Type[BaseRelation]:
432+
from dbt.adapters.bigquery.relation import BigQueryRelation
433+
434+
return BigQueryRelation
435+
436+
@property
437+
def column_class(self) -> t.Type[Column]:
438+
from dbt.adapters.bigquery import BigQueryColumn
439+
440+
return BigQueryColumn
441+
379442
def to_sqlmesh(self) -> ConnectionConfig:
380443
return BigQueryConnectionConfig(
381444
method=self.method,

tests/dbt/test_transformation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,11 @@ def test_flags(sushi_test_project: Project):
315315
def test_relation(sushi_test_project: Project):
316316
context = sushi_test_project.context
317317

318+
assert (
319+
context.render("{{ api.Relation }}")
320+
== "<class 'dbt.adapters.duckdb.relation.DuckDBRelation'>"
321+
)
322+
318323
jinja = (
319324
"{% set relation = api.Relation.create(schema='sushi', identifier='waiters') %}"
320325
"{{ relation.schema }} {{ relation.identifier}}"
@@ -326,6 +331,8 @@ def test_relation(sushi_test_project: Project):
326331
def test_column(sushi_test_project: Project):
327332
context = sushi_test_project.context
328333

334+
assert context.render("{{ api.Column }}") == "<class 'dbt.adapters.base.column.Column'>"
335+
329336
jinja = (
330337
"{% set col = api.Column('foo', 'integer') %}" "{{ col.is_integer() }} {{ col.is_string()}}"
331338
)

0 commit comments

Comments
 (0)