Skip to content

Commit f437a16

Browse files
authored
Fix: [dbt] materialized can be jinja and allow, but warn, on extra profile fields (#721)
* [dbt] materialized field can be jinja * [dbt] Warn on extra profile fields
1 parent ff7ba43 commit f437a16

6 files changed

Lines changed: 55 additions & 38 deletions

File tree

sqlmesh/dbt/basemodel.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -259,15 +259,27 @@ def _context_for_dependencies(
259259
) -> DbtContext:
260260
model_context = context.copy()
261261

262-
model_context.sources = {
263-
name: value for name, value in context.sources.items() if name in dependencies.sources
264-
}
265-
model_context.seeds = {
266-
name: value for name, value in context.seeds.items() if name in dependencies.refs
267-
}
268-
model_context.models = {
269-
name: value for name, value in context.models.items() if name in dependencies.refs
270-
}
262+
models = {}
263+
seeds = {}
264+
sources = {}
265+
266+
for ref in self._dependencies.refs:
267+
if ref in context.seeds:
268+
seeds[ref] = context.seeds[ref]
269+
elif ref in context.models:
270+
models[ref] = context.models[ref]
271+
else:
272+
raise ConfigError(f"Model '{ref}' was not found for model '{self.table_name}'.")
273+
274+
for source in self._dependencies.sources:
275+
if source in context.sources:
276+
sources[source] = context.sources[source]
277+
else:
278+
raise ConfigError(f"Source '{source}' was not found for model '{self.table_name}'.")
279+
280+
model_context.sources = sources
281+
model_context.seeds = seeds
282+
model_context.models = models
271283
model_context.variables = {
272284
name: value
273285
for name, value in context.variables.items()
@@ -365,16 +377,8 @@ def _extract_value(node: t.Any) -> t.Any:
365377
return config
366378

367379
def _ref(self, package_name: str, model_name: t.Optional[str] = None) -> BaseRelation:
368-
if package_name in self.context.models:
369-
relation = BaseRelation.create(**self.context.models[package_name].relation_info)
370-
elif package_name in self.context.seeds:
371-
relation = BaseRelation.create(**self.context.seeds[package_name].relation_info)
372-
else:
373-
raise ConfigError(
374-
f"Model '{package_name}' was not found for model '{self.config.table_name}'."
375-
)
376380
self._captured_dependencies.refs.add(package_name)
377-
return relation
381+
return BaseRelation.create()
378382

379383
def _var(self, name: str, default: t.Optional[str] = None) -> t.Any:
380384
if default is None and name not in self.context.variables:
@@ -386,12 +390,8 @@ def _var(self, name: str, default: t.Optional[str] = None) -> t.Any:
386390

387391
def _source(self, source_name: str, table_name: str) -> BaseRelation:
388392
full_name = ".".join([source_name, table_name])
389-
if full_name not in self.context.sources:
390-
raise ConfigError(
391-
f"Source '{full_name}' was not found for model '{self.config.table_name}'."
392-
)
393393
self._captured_dependencies.sources.add(full_name)
394-
return BaseRelation.create(**self.context.sources[full_name].relation_info)
394+
return BaseRelation.create()
395395

396396
class TrackingAdapter(ParsetimeAdapter):
397397
def __init__(self, outer_self: ModelSqlRenderer, *args: t.Any, **kwargs: t.Any):

sqlmesh/dbt/loader.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@ def _load_models(
5656
context = project.context.copy()
5757

5858
for package_name, package in project.packages.items():
59-
context.add_sources(package.sources)
60-
context.add_seeds(package.seeds)
61-
context.add_models(package.models)
6259
context.jinja_macros.add_macros(
6360
package.macros,
6461
package=package_name if package_name != context.project_name else None,

sqlmesh/dbt/model.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class ModelConfig(BaseModelConfig):
7070
start: t.Optional[str] = None
7171
cluster_by: t.Optional[t.List[str]] = None
7272
incremental_strategy: t.Optional[str] = None
73-
materialized: Materialization = Materialization.VIEW
73+
materialized: str = Materialization.VIEW.value
7474
sql_header: t.Optional[str] = None
7575
unique_key: t.Optional[t.List[str]] = None
7676

@@ -94,10 +94,6 @@ def _validate_list(cls, v: t.Union[str, t.List[str]]) -> t.List[str]:
9494
def _validate_sql(cls, v: t.Union[str, SqlStr]) -> SqlStr:
9595
return SqlStr(v)
9696

97-
@validator("materialized", pre=True)
98-
def _validate_materialization(cls, v: str) -> Materialization:
99-
return Materialization(v.lower())
100-
10197
_FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
10298
**BaseModelConfig._FIELD_UPDATE_STRATEGY,
10399
**{
@@ -112,15 +108,15 @@ def model_dialect(self) -> t.Optional[str]:
112108

113109
@property
114110
def model_materialization(self) -> Materialization:
115-
return self.materialized
111+
return Materialization(self.materialized.lower())
116112

117113
def model_kind(self, target: TargetConfig) -> ModelKind:
118114
"""
119115
Get the sqlmesh ModelKind
120116
Returns:
121117
The sqlmesh ModelKind
122118
"""
123-
materialization = self.materialized
119+
materialization = self.model_materialization
124120
if materialization == Materialization.TABLE:
125121
return ModelKind(name=ModelKindName.FULL)
126122
if materialization == Materialization.VIEW:

sqlmesh/dbt/project.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import logging
34
import typing as t
45
from pathlib import Path
56

@@ -9,6 +10,8 @@
910
from sqlmesh.dbt.profile import Profile
1011
from sqlmesh.utils.errors import ConfigError
1112

13+
logger = logging.getLogger(__name__)
14+
1215

1316
class Project:
1417
"""Configuration for a DBT project"""
@@ -64,6 +67,13 @@ def load(cls, context: DbtContext) -> Project:
6467
profile = Profile.load(context, context.target_name)
6568
context.target = profile.target
6669

70+
extra_fields = profile.target.extra
71+
if extra_fields:
72+
extra_str = ",".join(f"'{field}'" for field in extra_fields)
73+
logger.warning(
74+
f"Warning: {profile.target.type} adapter does not currently support {extra_str}."
75+
)
76+
6777
packages = {}
6878
root_loader = PackageLoader(context, ProjectConfig())
6979

sqlmesh/dbt/target.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616
SnowflakeConnectionConfig,
1717
)
1818
from sqlmesh.core.model import IncrementalByTimeRangeKind, IncrementalByUniqueKeyKind
19-
from sqlmesh.dbt.common import QuotingConfig
19+
from sqlmesh.dbt.common import DbtConfig, QuotingConfig
2020
from sqlmesh.utils import AttributeDict
2121
from sqlmesh.utils.errors import ConfigError
22-
from sqlmesh.utils.pydantic import PydanticModel
2322

2423
if sys.version_info >= (3, 9):
2524
from typing import Literal
@@ -29,7 +28,7 @@
2928
IncrementalKind = t.Union[t.Type[IncrementalByUniqueKeyKind], t.Type[IncrementalByTimeRangeKind]]
3029

3130

32-
class TargetConfig(abc.ABC, PydanticModel):
31+
class TargetConfig(abc.ABC, DbtConfig):
3332
"""
3433
Configuration for DBT profile target
3534
@@ -92,6 +91,10 @@ def target_jinja(self, profile_name: str) -> AttributeDict:
9291
def quoting(self) -> QuotingConfig:
9392
return QuotingConfig()
9493

94+
@property
95+
def extra(self) -> t.Set[str]:
96+
return self.extra_fields(set(self.dict()))
97+
9598

9699
class DuckDbConfig(TargetConfig):
97100
"""

tests/dbt/test_transformation.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,16 +218,17 @@ def test_config_containing_missing_dependency():
218218
context = DbtContext()
219219
model = ModelConfig(sql="{{ config(pre_hook=\"{{ print(ref('bar')) }}\") }} SELECT 1 FROM a")
220220
with pytest.raises(ConfigError, match="'bar' was not found"):
221-
model.render_config(context)
221+
model.render_config(context).to_sqlmesh(context)
222222

223223
model = ModelConfig(sql='{{ config(pre_hook="{{ get_table_name() }}") }} SELECT 1 FROM a')
224224
with pytest.raises(ConfigError, match="get_table_name"):
225-
model.render_config(context)
225+
model.render_config(context).to_sqlmesh(context)
226226

227227
model = ModelConfig(sql="{{ config(alias='{{ get_table_name() }}') }} SELECT 1 FROM a")
228228
rendered = model.render_config(context)
229229
assert rendered.alias == "{{ get_table_name() }}"
230230
assert "get_table_name" not in [macro.name for macro in rendered._dependencies.macros]
231+
rendered.to_sqlmesh(context)
231232

232233

233234
def test_config_containing_method():
@@ -273,6 +274,16 @@ def test_schema_jinja(sushi_test_project: Project):
273274
model_config.to_sqlmesh(context).render_query().sql() == "SELECT 1 AS one FROM sushi AS sushi"
274275

275276

277+
def test_materialized_jinja(sushi_test_project: Project):
278+
model_config = ModelConfig(
279+
materialized="{{ 'table' if target.type in ['duckdb'] else 'view' }}",
280+
sql="SELECT 1 AS one FROM {{ schema }}",
281+
)
282+
context = sushi_test_project.context
283+
rendered = model_config.render_config(context)
284+
assert rendered.materialized == "table"
285+
286+
276287
def test_this(assert_exp_eq, sushi_test_project: Project):
277288
model_config = ModelConfig(alias="test", sql="SELECT 1 AS one FROM {{ this.identifier }}")
278289
context = sushi_test_project.context

0 commit comments

Comments
 (0)