Skip to content

Commit 42ed0a3

Browse files
authored
Fix: Get dbt classes without all config fields present (#1605)
1 parent 6b95e3e commit 42ed0a3

3 files changed

Lines changed: 59 additions & 34 deletions

File tree

sqlmesh/dbt/builtin.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from sqlmesh.core.engine_adapter import EngineAdapter
1818
from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter
19-
from sqlmesh.dbt.target import TargetConfig
19+
from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS
2020
from sqlmesh.dbt.util import DBT_VERSION
2121
from sqlmesh.utils import AttributeDict, yaml
2222
from sqlmesh.utils.errors import ConfigError, MacroEvalError
@@ -44,10 +44,10 @@ def warn(self, msg: str) -> str:
4444
class Api:
4545
def __init__(self, target: t.Optional[AttributeDict] = None) -> None:
4646
if target:
47-
config = TargetConfig.load(target)
48-
self.Relation = config.relation_class
49-
self.Column = config.column_class
50-
self.quote_policy = config.quote_policy
47+
config_class = TARGET_TYPE_TO_CONFIG_CLASS[target["type"]]
48+
self.Relation = config_class.relation_class
49+
self.Column = config_class.column_class
50+
self.quote_policy = config_class.quote_policy
5151
else:
5252
self.Relation = BaseRelation
5353
self.Column = Column

sqlmesh/dbt/target.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828
from sqlmesh.dbt.common import DbtConfig
2929
from sqlmesh.dbt.util import DBT_VERSION
30-
from sqlmesh.utils import AttributeDict
30+
from sqlmesh.utils import AttributeDict, classproperty
3131
from sqlmesh.utils.errors import ConfigError
3232
from sqlmesh.utils.pydantic import field_validator, model_validator
3333

@@ -111,20 +111,20 @@ def attribute_dict(self) -> AttributeDict:
111111
fields["target_name"] = self.name
112112
return AttributeDict(fields)
113113

114-
@property
115-
def quote_policy(self) -> Policy:
114+
@classproperty
115+
def quote_policy(cls) -> Policy:
116116
return Policy()
117117

118118
@property
119119
def extra(self) -> t.Set[str]:
120120
return self.extra_fields(set(self.dict()))
121121

122-
@property
123-
def relation_class(self) -> t.Type[BaseRelation]:
122+
@classproperty
123+
def relation_class(cls) -> t.Type[BaseRelation]:
124124
return BaseRelation
125125

126-
@property
127-
def column_class(self) -> t.Type[Column]:
126+
@classproperty
127+
def column_class(cls) -> t.Type[Column]:
128128
return Column
129129

130130

@@ -161,8 +161,8 @@ def validate_authentication(
161161
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
162162
return "delete+insert"
163163

164-
@property
165-
def relation_class(self) -> t.Type[BaseRelation]:
164+
@classproperty
165+
def relation_class(cls) -> t.Type[BaseRelation]:
166166
from dbt.adapters.duckdb.relation import DuckDBRelation
167167

168168
return DuckDBRelation
@@ -223,14 +223,14 @@ def validate_authentication(
223223
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
224224
return "merge"
225225

226-
@property
227-
def relation_class(self) -> t.Type[BaseRelation]:
226+
@classproperty
227+
def relation_class(cls) -> t.Type[BaseRelation]:
228228
from dbt.adapters.snowflake import SnowflakeRelation
229229

230230
return SnowflakeRelation
231231

232-
@property
233-
def column_class(self) -> t.Type[Column]:
232+
@classproperty
233+
def column_class(cls) -> t.Type[Column]:
234234
from dbt.adapters.snowflake import SnowflakeColumn
235235

236236
return SnowflakeColumn
@@ -247,8 +247,8 @@ def to_sqlmesh(self) -> ConnectionConfig:
247247
concurrent_tasks=self.threads,
248248
)
249249

250-
@property
251-
def quote_policy(self) -> Policy:
250+
@classproperty
251+
def quote_policy(cls) -> Policy:
252252
return Policy(database=False, schema=False, identifier=False)
253253

254254

@@ -359,20 +359,20 @@ def validate_database(
359359
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
360360
return "append"
361361

362-
@property
363-
def relation_class(self) -> t.Type[BaseRelation]:
362+
@classproperty
363+
def relation_class(cls) -> t.Type[BaseRelation]:
364364
from dbt.adapters.redshift import RedshiftRelation
365365

366366
return RedshiftRelation
367367

368-
@property
369-
def column_class(self) -> t.Type[Column]:
368+
@classproperty
369+
def column_class(cls) -> t.Type[Column]:
370370
if DBT_VERSION < (1, 6):
371371
from dbt.adapters.redshift import RedshiftColumn # type: ignore
372372

373373
return RedshiftColumn
374374
else:
375-
return super().column_class
375+
return super(RedshiftConfig, cls).column_class()
376376

377377
def to_sqlmesh(self) -> ConnectionConfig:
378378
return RedshiftConnectionConfig(
@@ -409,14 +409,14 @@ class DatabricksConfig(TargetConfig):
409409
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
410410
return "merge"
411411

412-
@property
413-
def relation_class(self) -> t.Type[BaseRelation]:
412+
@classproperty
413+
def relation_class(cls) -> t.Type[BaseRelation]:
414414
from dbt.adapters.databricks.relation import DatabricksRelation
415415

416416
return DatabricksRelation
417417

418-
@property
419-
def column_class(self) -> t.Type[Column]:
418+
@classproperty
419+
def column_class(cls) -> t.Type[Column]:
420420
from dbt.adapters.databricks.column import DatabricksColumn
421421

422422
return DatabricksColumn
@@ -491,14 +491,14 @@ def validate_fields(
491491
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
492492
return "merge"
493493

494-
@property
495-
def relation_class(self) -> t.Type[BaseRelation]:
494+
@classproperty
495+
def relation_class(cls) -> t.Type[BaseRelation]:
496496
from dbt.adapters.bigquery.relation import BigQueryRelation
497497

498498
return BigQueryRelation
499499

500-
@property
501-
def column_class(self) -> t.Type[Column]:
500+
@classproperty
501+
def column_class(cls) -> t.Type[Column]:
502502
from dbt.adapters.bigquery import BigQueryColumn
503503

504504
return BigQueryColumn
@@ -529,3 +529,13 @@ def to_sqlmesh(self) -> ConnectionConfig:
529529
priority=self.priority,
530530
maximum_bytes_billed=self.maximum_bytes_billed,
531531
)
532+
533+
534+
TARGET_TYPE_TO_CONFIG_CLASS: t.Dict[str, t.Type[TargetConfig]] = {
535+
"databricks": DatabricksConfig,
536+
"duckdb": DuckDbConfig,
537+
"postgres": PostgresConfig,
538+
"redshift": RedshiftConfig,
539+
"snowflake": SnowflakeConfig,
540+
"bigquery": BigQueryConfig,
541+
}

tests/dbt/test_config.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
from pathlib import Path
33

44
import pytest
5-
from dbt.adapters.base import BaseRelation
5+
from dbt.adapters.base import BaseRelation, Column
6+
from dbt.adapters.duckdb.relation import DuckDBRelation
7+
from dbt.contracts.relation import Policy
68

79
from sqlmesh.core.dialect import jinja_query
810
from sqlmesh.core.model import SqlModel
@@ -12,6 +14,7 @@
1214
from sqlmesh.dbt.project import Project
1315
from sqlmesh.dbt.source import SourceConfig
1416
from sqlmesh.dbt.target import (
17+
TARGET_TYPE_TO_CONFIG_CLASS,
1518
BigQueryConfig,
1619
DatabricksConfig,
1720
DuckDbConfig,
@@ -512,3 +515,15 @@ def test_quoting_config():
512515
assert QuotingConfig.parse_obj(
513516
{"database": True, "identifier": True, "schema": True}
514517
) == QuotingConfig(database=True, identifier=True, schema=True)
518+
519+
520+
def test_db_type_to_relation_class():
521+
assert (TARGET_TYPE_TO_CONFIG_CLASS["duckdb"].relation_class) == DuckDBRelation
522+
523+
524+
def test_db_type_to_column_class():
525+
assert (TARGET_TYPE_TO_CONFIG_CLASS["duckdb"].column_class) == Column
526+
527+
528+
def test_db_type_to_quote_policy():
529+
assert isinstance(TARGET_TYPE_TO_CONFIG_CLASS["duckdb"].quote_policy, Policy)

0 commit comments

Comments
 (0)