Commit b783aaa

Fix!: bigquery create table requires correct trunc method for column type (#1048)

* partitioned_by config field removed for dbt. Use partition_by field instead.

1 parent 33b0844 · commit b783aaa
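For context, a minimal before/after sketch of the migration (mirroring the ModelConfig kwargs exercised in the updated tests below; not a complete dbt project setup):

    from sqlmesh.dbt.model import ModelConfig

    # Before this commit (field no longer exists):
    #   ModelConfig(..., partitioned_by=["ds"])

    # After: partition_by takes a column list for any dialect...
    config = ModelConfig(
        dialect="spark",
        name="model",
        schema="test",
        package_name="package",
        materialized="table",
        partition_by=["ds"],
        sql="SELECT 1 AS one, ds FROM foo",
    )

    # ...or, for BigQuery, a dict of partition parameters; the validator now
    # defaults data_type to "date" and granularity to "day" when omitted.
    config.partition_by = {"field": "ds", "granularity": "month"}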

3 files changed: 52 additions & 49 deletions

sqlmesh/dbt/model.py
Lines changed: 31 additions & 22 deletions

@@ -44,8 +44,6 @@ class ModelConfig(BaseModelConfig):
     Args:
         sql: The model sql
         time_column: The name of the time column
-        partitioned_by: List of columns to partition by. time_column will automatically be
-            included, if specified.
         cron: A cron string specifying how often the model should be refreshed, leveraging the
             [croniter](https://github.com/kiorky/croniter) library.
         dialect: The SQL dialect that the model's query is written in. By default,
@@ -60,14 +58,12 @@ class ModelConfig(BaseModelConfig):
         materialized: How the model will be materialized in the database
         sql_header: SQL statement to inject above create table/view as
         unique_key: List of columns that define row uniqueness for the model
-        partition_by: Dictionary of bigquery partition by parameters ([dbt bigquery config](https://docs.getdbt.com/reference/resource-configs/bigquery-configs)).
-            If partitioned_by is set, this field will be ignored.
+        partition_by: List of partition columns or dictionary of bigquery partition by parameters ([dbt bigquery config](https://docs.getdbt.com/reference/resource-configs/bigquery-configs)).
     """

     # sqlmesh fields
     sql: SqlStr = SqlStr("")
     time_column: t.Optional[str] = None
-    partitioned_by: t.Optional[t.List[str]] = None
     cron: t.Optional[str] = None
     dialect: t.Optional[str] = None
     batch_size: t.Optional[int] = None
@@ -92,7 +88,6 @@ class ModelConfig(BaseModelConfig):
     @validator(
         "unique_key",
         "cluster_by",
-        "partitioned_by",
         pre=True,
     )
     def _validate_list(cls, v: t.Union[str, t.List[str]]) -> t.List[str]:
@@ -111,9 +106,7 @@ def _validate_partition_by(cls, v: t.Any) -> t.Union[t.List[str], t.Dict[str, t.Any]]:
         if isinstance(v, dict):
             if not v.get("field"):
                 raise ConfigError("'field' key required for partition_by.")
-            if not v.get("granularity"):
-                v["granularity"] = "day"
-            return v
+            return {"data_type": "date", "granularity": "day", **v}
         raise ConfigError(f"Invalid format for partition_by '{v}'")

     _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
@@ -204,28 +197,37 @@ def _extract_sql_config(self) -> None:
         self._sql_no_config = SqlStr(no_config)
         self._sql_embedded_config = SqlStr(embedded_config)

+    @property
+    def _big_query_partition_by_expr(self) -> exp.Expression:
+        assert isinstance(self.partition_by, dict)
+        data_type = self.partition_by["data_type"].lower()
+        if data_type == "int64" or (
+            data_type == "date" and self.partition_by["granularity"].lower() == "day"
+        ):
+            return exp.to_column(self.partition_by["field"])
+
+        return TIME_TYPE_TO_TRUNC_EXPR[data_type](
+            this=exp.to_column(self.partition_by["field"]),
+            unit=exp.var(self.partition_by["granularity"].upper()),
+        )
+
     def to_sqlmesh(self, context: DbtContext) -> Model:
         """Converts the dbt model into a SQLMesh model."""
         dialect = self.model_dialect or context.dialect
         query = d.jinja_query(self.sql_no_config)

         optional_kwargs: t.Dict[str, t.Any] = {}

-        if self.partitioned_by:
-            optional_kwargs["partitioned_by"] = [
-                d.parse_one(val, dialect=dialect) for val in self.partitioned_by
-            ]
-        elif self.partition_by and isinstance(self.partition_by, list):
-            optional_kwargs["partitioned_by"] = [exp.to_column(val) for val in self.partition_by]
-        elif self.partition_by and isinstance(self.partition_by, dict):
-            optional_kwargs["partitioned_by"] = [
-                exp.TimestampTrunc(
-                    this=exp.to_column(self.partition_by["field"]),
-                    unit=exp.var(self.partition_by["granularity"]),
-                )
-            ]
+        if self.partition_by:
+            optional_kwargs["partitioned_by"] = (
+                [exp.to_column(val) for val in self.partition_by]
+                if isinstance(self.partition_by, list)
+                else self._big_query_partition_by_expr
+            )
+
         if self.cluster_by:
             optional_kwargs["clustered_by"] = self.cluster_by
+
         for field in ["cron"]:
             field_val = getattr(self, field, None) or self.meta.get(field, None)
             if field_val:
@@ -243,3 +245,10 @@ def to_sqlmesh(self, context: DbtContext) -> Model:
             **optional_kwargs,
             **self.sqlmesh_model_kwargs(context),
         )
+
+
+TIME_TYPE_TO_TRUNC_EXPR = {
+    "date": exp.DateTrunc,
+    "datetime": exp.DatetimeTrunc,
+    "timestamp": exp.TimestampTrunc,
+}
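
To make the new behavior concrete, a standalone sketch of the type-to-trunc mapping (assuming only sqlglot; partition_expr is a hypothetical helper, not part of this commit):

    from sqlglot import exp

    # Same mapping the commit adds at module level above.
    TIME_TYPE_TO_TRUNC_EXPR = {
        "date": exp.DateTrunc,
        "datetime": exp.DatetimeTrunc,
        "timestamp": exp.TimestampTrunc,
    }

    def partition_expr(field: str, data_type: str = "date", granularity: str = "day") -> exp.Expression:
        # int64 range partitions and day-grained date columns partition on the bare column.
        data_type = data_type.lower()
        if data_type == "int64" or (data_type == "date" and granularity.lower() == "day"):
            return exp.to_column(field)
        return TIME_TYPE_TO_TRUNC_EXPR[data_type](
            this=exp.to_column(field), unit=exp.var(granularity.upper())
        )

    print(partition_expr("ds", "date", "month").sql(dialect="bigquery"))     # DATE_TRUNC(ds, MONTH)
    print(partition_expr("ds", "timestamp", "day").sql(dialect="bigquery"))  # TIMESTAMP_TRUNC(ds, DAY)
    print(partition_expr("ds", "int64").sql(dialect="bigquery"))             # ds

This mirrors ModelConfig._big_query_partition_by_expr and matches the assertions in test_partition_by below; previously every dict-form partition_by rendered TIMESTAMP_TRUNC, which BigQuery rejects for DATE and DATETIME partition columns.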

tests/dbt/test_config.py
Lines changed: 1 addition & 1 deletion

@@ -72,7 +72,7 @@ def test_model_to_sqlmesh_fields(sushi_test_project: Project):
         description="test model",
         sql="SELECT 1 AS a FROM foo",
         start="Jan 1 2023",
-        partitioned_by=["a"],
+        partition_by=["a"],
         cluster_by=["a"],
         cron="@hourly",
         batch_size=5,

tests/dbt/test_transformation.py
Lines changed: 20 additions & 26 deletions

@@ -454,55 +454,49 @@ def test_parsetime_adapter_call(
 def test_partition_by(sushi_test_project: Project):
     context = sushi_test_project.context
     model_config = ModelConfig(
-        dialect="bigquery",
+        dialect="spark",
         name="model",
         schema="test",
         package_name="package",
         materialized="table",
         unique_key="ds",
-        partitioned_by="ds",
+        partition_by="ds",
         sql="""SELECT 1 AS one, ds, ts FROM foo""",
     )
     assert model_config.to_sqlmesh(context).partitioned_by == [exp.to_column("ds")]

-    model_config.partitioned_by = "DATE_TRUNC(ds, MONTH)"  # type: ignore
-    assert model_config.to_sqlmesh(context).partitioned_by == [
-        parse_one(model_config.partitioned_by[0], read="bigquery")  # type: ignore
-    ]
+    assert model_config.partition_by == ["ds"]
+    assert model_config.to_sqlmesh(context).partitioned_by == [exp.to_column("ds")]

-    model_config.partitioned_by = ["ds", "ts"]
+    model_config.partition_by = ["ds", "ts"]
     assert model_config.to_sqlmesh(context).partitioned_by == [
         exp.to_column("ds"),
         exp.to_column("ts"),
     ]

     model_config = ModelConfig(
-        dialect="spark",
+        dialect="bigquery",
         name="model",
         schema="test",
         package_name="package",
         materialized="table",
         unique_key="ds",
-        partition_by="ds",
+        partition_by={"field": "ds", "granularity": "month"},
         sql="""SELECT 1 AS one, ds FROM foo""",
     )
-    assert model_config.partition_by == ["ds"]
-    assert model_config.to_sqlmesh(context).partitioned_by == [exp.to_column("ds")]
+    assert (
+        model_config.to_sqlmesh(context).partitioned_by[0].sql(dialect="bigquery")
+        == "DATE_TRUNC(ds, MONTH)"
+    )

-    model_config.partition_by = ["ds"]
-    assert model_config.partition_by == ["ds"]
+    model_config.partition_by = {"field": "ds", "data_type": "timestamp", "granularity": "day"}
+    assert (
+        model_config.to_sqlmesh(context).partitioned_by[0].sql(dialect="bigquery")
+        == "TIMESTAMP_TRUNC(ds, DAY)"
+    )
+
+    model_config.partition_by = {"field": "ds", "data_type": "int64", "granularity": "day"}
     assert model_config.to_sqlmesh(context).partitioned_by == [exp.to_column("ds")]

-    model_config = ModelConfig(
-        dialect="bigquery",
-        name="model",
-        schema="test",
-        package_name="package",
-        materialized="table",
-        unique_key="ds",
-        partition_by={"field": "ds", "granularity": "month"},
-        sql="""SELECT 1 AS one, ds FROM foo""",
-    )
-    assert model_config.to_sqlmesh(context).partitioned_by == [
-        parse_one("TIMESTAMP_TRUNC(ds, MONTH)", read="bigquery")
-    ]
+    model_config.partition_by = {"field": "ds", "data_type": "date", "granularity": "day"}
+    assert model_config.to_sqlmesh(context).partitioned_by == [exp.to_column("ds")]
