Skip to content

Commit 9b0b066

Browse files
authored
Feat: Support the '_dbt_max_partition' variable for BigQuery dbt projects (#1195)
1 parent 7a9c702 commit 9b0b066

5 files changed

Lines changed: 130 additions & 10 deletions

File tree

sqlmesh/dbt/adapter.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ def execute(
7979
) -> t.Tuple[AdapterResponse, agate.Table]:
8080
"""Executes the given SQL statement and returns the results as an agate table."""
8181

82+
@abc.abstractmethod
83+
def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]:
84+
"""Resolves the relation's schema to its physical schema."""
85+
86+
@abc.abstractmethod
87+
def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]:
88+
"""Resolves the relation's identifier to its physical identifier."""
89+
8290
def quote(self, identifier: str) -> str:
8391
"""Returns a quoted identifier."""
8492
return exp.to_column(identifier).sql(dialect=self.dialect, identify=True)
@@ -138,6 +146,12 @@ def execute(
138146
self._raise_parsetime_adapter_call_error("execute SQL")
139147
raise
140148

149+
def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]:
150+
return relation.schema
151+
152+
def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]:
153+
return relation.identifier
154+
141155
@staticmethod
142156
def _raise_parsetime_adapter_call_error(action: str) -> None:
143157
raise ParsetimeAdapterCallError(f"Can't {action} at parse time.")
@@ -276,6 +290,20 @@ def execute(
276290
return AdapterResponse("Success"), pandas_to_agate(resp)
277291
return AdapterResponse("Success"), empty_table()
278292

293+
def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]:
294+
schema = self._map_table_name(relation.database, relation.schema, relation.identifier).db
295+
if not schema:
296+
return None
297+
return schema
298+
299+
def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]:
300+
identifier = self._map_table_name(
301+
relation.database, relation.schema, relation.identifier
302+
).name
303+
if not identifier:
304+
return None
305+
return identifier
306+
279307
def _map_table_name(
280308
self, database: t.Optional[str], schema: t.Optional[str], identifier: t.Optional[str]
281309
) -> exp.Table:

sqlmesh/dbt/basemodel.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ class BaseModelConfig(GeneralConfig):
7171
database: Database the model is stored in
7272
schema: Custom schema name added to the model schema name
7373
alias: Relation identifier for this model instead of the filename
74-
sql_header: SQL statement to run before table/view creation. Currently implemented as a pre-hook.
7574
pre-hook: List of SQL statements to run before the model is built.
7675
post-hook: List of SQL statements to run after the model is built.
7776
full_refresh: Forces the model to always do a full refresh or never do a full refresh
@@ -94,7 +93,6 @@ class BaseModelConfig(GeneralConfig):
9493
schema_: str = Field("", alias="schema")
9594
database: t.Optional[str] = None
9695
alias: t.Optional[str] = None
97-
sql_header: t.Optional[str] = None
9896
pre_hook: t.List[Hook] = Field([], alias="pre-hook")
9997
post_hook: t.List[Hook] = Field([], alias="post-hook")
10098
full_refresh: t.Optional[bool] = None
@@ -229,10 +227,6 @@ def sqlmesh_model_kwargs(self, context: DbtContext) -> t.Dict[str, t.Any]:
229227
if field_val:
230228
optional_kwargs[field] = field_val
231229

232-
pre_hooks = self.pre_hook
233-
if self.sql_header:
234-
pre_hooks.insert(0, Hook(sql=self.sql_header))
235-
236230
return {
237231
"audits": [(test.name, {}) for test in self.tests],
238232
"columns": column_types_to_sqlmesh(self.columns) or None,
@@ -243,7 +237,7 @@ def sqlmesh_model_kwargs(self, context: DbtContext) -> t.Dict[str, t.Any]:
243237
"jinja_macros": jinja_macros,
244238
"path": self.path,
245239
"hash_raw_query": True,
246-
"pre_statements": [d.jinja_statement(hook.sql) for hook in pre_hooks],
240+
"pre_statements": [d.jinja_statement(hook.sql) for hook in self.pre_hook],
247241
"post_statements": [d.jinja_statement(hook.sql) for hook in self.post_hook],
248242
"tags": self.tags,
249243
**optional_kwargs,

sqlmesh/dbt/model.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class ModelConfig(BaseModelConfig):
5858
cluster_by: Field(s) to use for clustering in data warehouses that support clustering
5959
incremental_strategy: Strategy used to build the incremental model
6060
materialized: How the model will be materialized in the database
61-
sql_header: SQL statement to inject above create table/view as
61+
sql_header: SQL statement to run before table/view creation. Currently implemented as a pre-hook.
6262
unique_key: List of columns that define row uniqueness for the model
6363
partition_by: List of partition columns or dictionary of bigquery partition by parameters ([dbt bigquery config](https://docs.getdbt.com/reference/resource-configs/bigquery-configs)).
6464
"""
@@ -106,6 +106,9 @@ def _validate_partition_by(cls, v: t.Any) -> t.Union[t.List[str], t.Dict[str, t.
106106
if isinstance(v, dict):
107107
if not v.get("field"):
108108
raise ConfigError("'field' key required for partition_by.")
109+
if "granularity" in v and v["granularity"] not in GRANULARITY_TO_PARTITION_FORMAT:
110+
granularity = v["granularity"]
111+
raise ConfigError(f"Unexpected granularity '{granularity}' in partition_by '{v}'.")
109112
return {"data_type": "date", "granularity": "day", **v}
110113
raise ConfigError(f"Invalid format for partition_by '{v}'")
111114

@@ -268,12 +271,59 @@ def to_sqlmesh(self, context: DbtContext) -> Model:
268271
if not context.target:
269272
raise ConfigError(f"Target required to load '{self.sql_name}' into SQLMesh.")
270273

274+
model_kwargs = self.sqlmesh_model_kwargs(context)
275+
if self.sql_header:
276+
model_kwargs["pre_statements"].insert(0, d.jinja_statement(self.sql_header))
277+
278+
if context.target.type == "bigquery":
279+
dbt_max_partition_blob = self._dbt_max_partition_blob()
280+
if dbt_max_partition_blob:
281+
model_kwargs["pre_statements"].append(d.jinja_statement(dbt_max_partition_blob))
282+
271283
return create_sql_model(
272284
self.sql_name,
273285
query,
274286
dialect=dialect,
275287
kind=self.model_kind(context.target),
276288
start=self.start,
277289
**optional_kwargs,
278-
**self.sqlmesh_model_kwargs(context),
290+
**model_kwargs,
279291
)
292+
293+
def _dbt_max_partition_blob(self) -> t.Optional[str]:
294+
"""Returns a SQL blob which declares the _dbt_max_partition variable. Only applicable to BigQuery."""
295+
if (
296+
not isinstance(self.partition_by, dict)
297+
or self.model_materialization != Materialization.INCREMENTAL
298+
):
299+
return None
300+
301+
data_type = self.partition_by["data_type"]
302+
granularity = self.partition_by["granularity"]
303+
304+
parse_fun = f"parse_{data_type}" if data_type in ("date", "datetime", "timestamp") else None
305+
if parse_fun:
306+
parse_format = GRANULARITY_TO_PARTITION_FORMAT[granularity]
307+
partition_exp = f"{parse_fun}('{parse_format}', partition_id)"
308+
else:
309+
partition_exp = "CAST(partition_id AS INT64)"
310+
311+
return f"""
312+
{{% if is_incremental() %}}
313+
DECLARE _dbt_max_partition {data_type.upper()} DEFAULT (
314+
SELECT MAX({partition_exp})
315+
FROM `{{{{ target.database }}}}`.`{{{{ adapter.resolve_schema(this) }}}}`.INFORMATION_SCHEMA.PARTITIONS
316+
WHERE table_name = '{{{{ adapter.resolve_identifier(this) }}}}'
317+
AND partition_id IS NOT NULL
318+
AND partition_id != '__NULL__'
319+
);
320+
{{% endif %}}
321+
"""
322+
323+
324+
GRANULARITY_TO_PARTITION_FORMAT = {
325+
"day": "%Y%m%d",
326+
"month": "%Y%m",
327+
"year": "%Y",
328+
"hour": "%Y%m%d%H",
329+
}

tests/dbt/test_adapter.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing as t
44

55
import pytest
6+
from dbt.adapters.base import BaseRelation
67
from dbt.adapters.base.column import Column
78
from pytest_mock.plugin import MockerFixture
89
from sqlglot import exp
@@ -80,7 +81,11 @@ def test_adapter_map_snapshot_tables(
8081
assert context.target
8182
engine_adapter = context.target.to_sqlmesh().create_engine_adapter()
8283
renderer = runtime_renderer(
83-
context, engine_adapter=engine_adapter, snapshots={"test_db.test_model": snapshot_mock}
84+
context,
85+
engine_adapter=engine_adapter,
86+
snapshots={"test_db.test_model": snapshot_mock},
87+
test_model=BaseRelation.create(schema="test_db", identifier="test_model"),
88+
foo_bar=BaseRelation.create(schema="foo", identifier="bar"),
8489
)
8590

8691
engine_adapter.create_schema("foo")
@@ -106,3 +111,9 @@ def test_adapter_map_snapshot_tables(
106111
renderer("{{ adapter.get_relation(database=none, schema='foo', identifier='bar') }}")
107112
== '"foo"."bar"'
108113
)
114+
115+
assert renderer("{{ adapter.resolve_schema(test_model) }}") == "sqlmesh"
116+
assert renderer("{{ adapter.resolve_identifier(test_model) }}") == "test_db__test_model"
117+
118+
assert renderer("{{ adapter.resolve_schema(foo_bar) }}") == "foo"
119+
assert renderer("{{ adapter.resolve_identifier(foo_bar) }}") == "bar"

tests/dbt/test_transformation.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from dbt.adapters.base import BaseRelation
77
from dbt.contracts.relation import Policy
88
from dbt.exceptions import CompilationError
9+
from pytest_mock.plugin import MockerFixture
910
from sqlglot import exp, parse_one
1011

1112
from sqlmesh.core.context import Context
@@ -570,3 +571,39 @@ def test_is_incremental(sushi_test_project: Project, assert_exp_eq):
570571
model_config.to_sqlmesh(context).render_query_or_raise(has_intervals=True).sql(),
571572
'SELECT 1 AS "one" FROM "tbl_a" AS "tbl_a" WHERE "ds" > (SELECT MAX("ds") FROM "model" AS "model")',
572573
)
574+
575+
576+
def test_dbt_max_partition(sushi_test_project: Project, assert_exp_eq, mocker: MockerFixture):
577+
model_config = ModelConfig(
578+
name="model",
579+
package_name="package",
580+
schema="sushi",
581+
partition_by={"field": "`ds`", "data_type": "datetime", "granularity": "month"},
582+
materialized=Materialization.INCREMENTAL,
583+
sql="""
584+
SELECT 1 AS one FROM tbl_a
585+
{% if is_incremental() %}
586+
WHERE ds > _dbt_max_partition
587+
{% endif %}
588+
""",
589+
)
590+
context = sushi_test_project.context
591+
context.target = BigQueryConfig(
592+
name="test_target", schema="test_schema", database="test-project"
593+
)
594+
595+
assert (
596+
model_config.to_sqlmesh(context).pre_statements[-1].sql().strip() # type: ignore
597+
== """
598+
JINJA_STATEMENT_BEGIN;
599+
{% if is_incremental() %}
600+
DECLARE _dbt_max_partition DATETIME DEFAULT (
601+
SELECT MAX(parse_datetime('%Y%m', partition_id))
602+
FROM `{{ target.database }}`.`{{ adapter.resolve_schema(this) }}`.INFORMATION_SCHEMA.PARTITIONS
603+
WHERE table_name = '{{ adapter.resolve_identifier(this) }}'
604+
AND partition_id IS NOT NULL
605+
AND partition_id != '__NULL__'
606+
);
607+
{% endif %}
608+
JINJA_END;""".strip()
609+
)

0 commit comments

Comments
 (0)