Skip to content

Commit 257ef63

Browse files
Feat(dbt): Add support for dbt custom materializations
1 parent 92e4a32 commit 257ef63

File tree

14 files changed

+1293
-9
lines changed

14 files changed

+1293
-9
lines changed

sqlmesh/core/model/kind.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ def is_custom(self) -> bool:
119119
def is_managed(self) -> bool:
120120
return self.model_kind_name == ModelKindName.MANAGED
121121

122+
@property
123+
def is_dbt_custom(self) -> bool:
124+
return self.model_kind_name == ModelKindName.DBT_CUSTOM
125+
122126
@property
123127
def is_symbolic(self) -> bool:
124128
"""A symbolic model is one that doesn't execute at all."""
@@ -170,6 +174,7 @@ class ModelKindName(str, ModelKindMixin, Enum):
170174
EXTERNAL = "EXTERNAL"
171175
CUSTOM = "CUSTOM"
172176
MANAGED = "MANAGED"
177+
DBT_CUSTOM = "DBT_CUSTOM"
173178

174179
@property
175180
def model_kind_name(self) -> t.Optional[ModelKindName]:
@@ -887,6 +892,52 @@ def supports_python_models(self) -> bool:
887892
return False
888893

889894

895+
class DbtCustomKind(_ModelKind):
896+
name: t.Literal[ModelKindName.DBT_CUSTOM] = ModelKindName.DBT_CUSTOM
897+
materialization: str
898+
adapter: str = "default"
899+
definition: str
900+
dialect: t.Optional[str] = Field(None, validate_default=True)
901+
902+
_dialect_validator = kind_dialect_validator
903+
904+
@field_validator("materialization", "adapter", "definition", mode="before")
905+
@classmethod
906+
def _validate_fields(cls, v: t.Any) -> str:
907+
return validate_string(v)
908+
909+
@property
910+
def data_hash_values(self) -> t.List[t.Optional[str]]:
911+
return [
912+
*super().data_hash_values,
913+
self.materialization,
914+
self.definition,
915+
self.adapter,
916+
self.dialect,
917+
]
918+
919+
@property
920+
def metadata_hash_values(self) -> t.List[t.Optional[str]]:
921+
return [
922+
*super().metadata_hash_values,
923+
]
924+
925+
def to_expression(
926+
self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
927+
) -> d.ModelKind:
928+
return super().to_expression(
929+
expressions=[
930+
*(expressions or []),
931+
*_properties(
932+
{
933+
"materialization": exp.Literal.string(self.materialization),
934+
"adapter": exp.Literal.string(self.adapter),
935+
}
936+
),
937+
],
938+
)
939+
940+
890941
class EmbeddedKind(_ModelKind):
891942
name: t.Literal[ModelKindName.EMBEDDED] = ModelKindName.EMBEDDED
892943

@@ -992,6 +1043,7 @@ def to_expression(
9921043
SCDType2ByColumnKind,
9931044
CustomKind,
9941045
ManagedKind,
1046+
DbtCustomKind,
9951047
],
9961048
Field(discriminator="name"),
9971049
]
@@ -1011,6 +1063,7 @@ def to_expression(
10111063
ModelKindName.SCD_TYPE_2_BY_COLUMN: SCDType2ByColumnKind,
10121064
ModelKindName.CUSTOM: CustomKind,
10131065
ModelKindName.MANAGED: ManagedKind,
1066+
ModelKindName.DBT_CUSTOM: DbtCustomKind,
10141067
}
10151068

10161069

sqlmesh/core/snapshot/evaluator.py

Lines changed: 155 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
ViewKind,
5151
CustomKind,
5252
)
53-
from sqlmesh.core.model.kind import _Incremental
53+
from sqlmesh.core.model.kind import _Incremental, DbtCustomKind
5454
from sqlmesh.utils import CompletionStatus, columns_to_types_all_known
5555
from sqlmesh.core.schema_diff import (
5656
has_drop_alteration,
@@ -83,6 +83,7 @@
8383
format_additive_change_msg,
8484
AdditiveChangeError,
8585
)
86+
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReturnVal
8687

8788
if sys.version_info >= (3, 12):
8889
from importlib import metadata
@@ -747,7 +748,8 @@ def _evaluate_snapshot(
747748
adapter.transaction(),
748749
adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)),
749750
):
750-
adapter.execute(model.render_pre_statements(**render_statements_kwargs))
751+
if not snapshot.is_dbt_custom:
752+
adapter.execute(model.render_pre_statements(**render_statements_kwargs))
751753

752754
if not target_table_exists or (model.is_seed and not snapshot.intervals):
753755
# Only create the empty table if the columns were provided explicitly by the user
@@ -817,7 +819,8 @@ def _evaluate_snapshot(
817819
batch_index=batch_index,
818820
)
819821

820-
adapter.execute(model.render_post_statements(**render_statements_kwargs))
822+
if not snapshot.is_dbt_custom:
823+
adapter.execute(model.render_post_statements(**render_statements_kwargs))
821824

822825
return wap_id
823826

@@ -1432,7 +1435,7 @@ def _execute_create(
14321435
**create_render_kwargs,
14331436
"table_mapping": {snapshot.name: table_name},
14341437
}
1435-
if run_pre_post_statements:
1438+
if run_pre_post_statements and not snapshot.is_dbt_custom:
14361439
adapter.execute(snapshot.model.render_pre_statements(**create_render_kwargs))
14371440
evaluation_strategy.create(
14381441
table_name=table_name,
@@ -1444,7 +1447,7 @@ def _execute_create(
14441447
dry_run=dry_run,
14451448
physical_properties=rendered_physical_properties,
14461449
)
1447-
if run_pre_post_statements:
1450+
if run_pre_post_statements and not snapshot.is_dbt_custom:
14481451
adapter.execute(snapshot.model.render_post_statements(**create_render_kwargs))
14491452

14501453
def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex) -> bool:
@@ -1456,6 +1459,7 @@ def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex
14561459
and adapter.SUPPORTS_CLONING
14571460
# managed models cannot have their schema mutated because theyre based on queries, so clone + alter wont work
14581461
and not snapshot.is_managed
1462+
and not snapshot.is_dbt_custom
14591463
and not deployability_index.is_deployable(snapshot)
14601464
# If the deployable table is missing we can't clone it
14611465
and adapter.table_exists(snapshot.table_name())
@@ -1540,6 +1544,19 @@ def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) ->
15401544
klass = ViewStrategy
15411545
elif snapshot.is_scd_type_2:
15421546
klass = SCDType2Strategy
1547+
elif snapshot.is_dbt_custom:
1548+
if hasattr(snapshot, "model") and isinstance(
1549+
(model_kind := snapshot.model.kind), DbtCustomKind
1550+
):
1551+
return DbtCustomMaterialization(
1552+
adapter=adapter,
1553+
materialization_name=model_kind.materialization,
1554+
materialization_template=model_kind.definition,
1555+
)
1556+
1557+
raise SQLMeshError(
1558+
f"Expected DbtCustomKind for dbt custom materialization in model '{snapshot.name}'"
1559+
)
15431560
elif snapshot.is_custom:
15441561
if snapshot.custom_materialization is None:
15451562
raise SQLMeshError(
@@ -2593,6 +2610,139 @@ def get_custom_materialization_type_or_raise(
25932610
raise SQLMeshError(f"Custom materialization '{name}' not present in the Python environment")
25942611

25952612

2613+
class DbtCustomMaterialization(MaterializableStrategy):
2614+
def __init__(
2615+
self,
2616+
adapter: EngineAdapter,
2617+
materialization_name: str,
2618+
materialization_template: str,
2619+
):
2620+
super().__init__(adapter)
2621+
self.materialization_name = materialization_name
2622+
self.materialization_template = materialization_template
2623+
2624+
def create(
2625+
self,
2626+
table_name: str,
2627+
model: Model,
2628+
is_table_deployable: bool,
2629+
render_kwargs: t.Dict[str, t.Any],
2630+
**kwargs: t.Any,
2631+
) -> None:
2632+
original_query = model.render_query_or_raise(**render_kwargs)
2633+
self._execute_materialization(
2634+
table_name=table_name,
2635+
query_or_df=original_query.limit(0),
2636+
model=model,
2637+
is_first_insert=True,
2638+
render_kwargs=render_kwargs,
2639+
create_only=True,
2640+
**kwargs,
2641+
)
2642+
2643+
def insert(
2644+
self,
2645+
table_name: str,
2646+
query_or_df: QueryOrDF,
2647+
model: Model,
2648+
is_first_insert: bool,
2649+
render_kwargs: t.Dict[str, t.Any],
2650+
**kwargs: t.Any,
2651+
) -> None:
2652+
self._execute_materialization(
2653+
table_name=table_name,
2654+
query_or_df=query_or_df,
2655+
model=model,
2656+
is_first_insert=is_first_insert,
2657+
render_kwargs=render_kwargs,
2658+
**kwargs,
2659+
)
2660+
2661+
def append(
2662+
self,
2663+
table_name: str,
2664+
query_or_df: QueryOrDF,
2665+
model: Model,
2666+
render_kwargs: t.Dict[str, t.Any],
2667+
**kwargs: t.Any,
2668+
) -> None:
2669+
return self.insert(
2670+
table_name,
2671+
query_or_df,
2672+
model,
2673+
is_first_insert=False,
2674+
render_kwargs=render_kwargs,
2675+
**kwargs,
2676+
)
2677+
2678+
def _execute_materialization(
2679+
self,
2680+
table_name: str,
2681+
query_or_df: QueryOrDF,
2682+
model: Model,
2683+
is_first_insert: bool,
2684+
render_kwargs: t.Dict[str, t.Any],
2685+
create_only: bool = False,
2686+
**kwargs: t.Any,
2687+
) -> None:
2688+
from sqlmesh.dbt.builtin import create_builtin_globals
2689+
2690+
jinja_macros = getattr(model, "jinja_macros", JinjaMacroRegistry())
2691+
existing_globals = jinja_macros.global_objs.copy()
2692+
2693+
# For vdes we need to use the table, since we don't know the schema/table at parse time
2694+
parts = exp.to_table(table_name, dialect=self.adapter.dialect)
2695+
2696+
relation_info = existing_globals.pop("this")
2697+
if isinstance(relation_info, dict):
2698+
relation_info["database"] = parts.catalog
2699+
relation_info["identifier"] = parts.name
2700+
relation_info["name"] = parts.name
2701+
2702+
jinja_globals = {
2703+
**existing_globals,
2704+
"this": relation_info,
2705+
"database": parts.catalog,
2706+
"schema": parts.db,
2707+
"identifier": parts.name,
2708+
"target": existing_globals.get("target", {"type": self.adapter.dialect}),
2709+
"execution_dt": kwargs.get("execution_time"),
2710+
}
2711+
2712+
context = create_builtin_globals(
2713+
jinja_macros=jinja_macros, jinja_globals=jinja_globals, engine_adapter=self.adapter
2714+
)
2715+
2716+
context.update(
2717+
{
2718+
"sql": str(query_or_df),
2719+
"is_first_insert": is_first_insert,
2720+
"create_only": create_only,
2721+
"pre_hooks": model.render_pre_statements(**render_kwargs),
2722+
"post_hooks": model.render_post_statements(**render_kwargs),
2723+
**kwargs,
2724+
}
2725+
)
2726+
2727+
try:
2728+
jinja_env = jinja_macros.build_environment(**context)
2729+
template = jinja_env.from_string(self.materialization_template)
2730+
2731+
try:
2732+
template.render(**context)
2733+
except MacroReturnVal as ret:
2734+
# this is a succesful return from a macro call (dbt uses this list of Relations to update their relation cache)
2735+
returned_relations = ret.value.get("relations", [])
2736+
logger.info(
2737+
f"Materialization {self.materialization_name} returned relations: {returned_relations}"
2738+
)
2739+
2740+
except Exception as e:
2741+
raise SQLMeshError(
2742+
f"Failed to execute dbt materialization '{self.materialization_name}': {e}"
2743+
) from e
2744+
2745+
25962746
class EngineManagedStrategy(MaterializableStrategy):
25972747
def create(
25982748
self,

sqlmesh/dbt/adapter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ def execute(
9999
) -> t.Tuple[AdapterResponse, agate.Table]:
100100
"""Executes the given SQL statement and returns the results as an agate table."""
101101

102+
@abc.abstractmethod
103+
def run_hooks(
104+
self, hooks: t.List[str | exp.Expression], inside_transaction: bool = True
105+
) -> None:
106+
"""Executes the given hooks."""
107+
102108
@abc.abstractmethod
103109
def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]:
104110
"""Resolves the relation's schema to its physical schema."""
@@ -241,6 +247,12 @@ def execute(
241247
self._raise_parsetime_adapter_call_error("execute SQL")
242248
raise
243249

250+
def run_hooks(
251+
self, hooks: t.List[str | exp.Expression], inside_transaction: bool = True
252+
) -> None:
253+
self._raise_parsetime_adapter_call_error("run hooks")
254+
raise
255+
244256
def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]:
245257
return relation.schema
246258

@@ -451,6 +463,12 @@ def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]:
451463
identifier = self._map_table_name(self._normalize(self._relation_to_table(relation))).name
452464
return identifier if identifier else None
453465

466+
def run_hooks(
467+
self, hooks: t.List[str | exp.Expression], inside_transaction: bool = True
468+
) -> None:
469+
# inside_transaction not yet supported similarly to transaction
470+
self.engine_adapter.execute([exp.maybe_parse(hook) for hook in hooks])
471+
454472
def _map_table_name(self, table: exp.Table) -> exp.Table:
455473
# Use the default dialect since this is the dialect used to normalize and quote keys in the
456474
# mapping table.

sqlmesh/dbt/basemodel.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ class Materialization(str, Enum):
5757
# Snowflake, https://docs.getdbt.com/reference/resource-configs/snowflake-configs#dynamic-tables
5858
DYNAMIC_TABLE = "dynamic_table"
5959

60+
CUSTOM = "custom"
61+
62+
@classmethod
63+
def _missing_(cls, value): # type: ignore
64+
return cls.CUSTOM
65+
6066

6167
class SnapshotStrategy(str, Enum):
6268
"""DBT snapshot strategies"""
@@ -295,6 +301,14 @@ def sqlmesh_model_kwargs(
295301
# precisely which variables are referenced in the model
296302
dependencies.variables |= set(context.variables)
297303

304+
if (
305+
getattr(self, "model_materialization", None) == Materialization.CUSTOM
306+
and hasattr(self, "_get_custom_materialization")
307+
and (custom_mat := self._get_custom_materialization(context))
308+
):
309+
# include custom materialization dependencies as they might use macros
310+
dependencies = dependencies.union(custom_mat.dependencies)
311+
298312
model_dialect = self.dialect(context)
299313
model_context = context.context_for_dependencies(
300314
dependencies.union(self.tests_ref_source_dependencies)

sqlmesh/dbt/builtin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@ def create_builtin_globals(
544544
"load_result": sql_execution.load_result,
545545
"run_query": sql_execution.run_query,
546546
"statement": sql_execution.statement,
547+
"run_hooks": adapter.run_hooks,
547548
"graph": adapter.graph,
548549
"selected_resources": list(jinja_globals.get("selected_models") or []),
549550
}

0 commit comments

Comments
 (0)