Skip to content

Commit 33b0844

Browse files
authored
Feat: cluster by closes #570 (#1046)
1 parent 74fdcb6 commit 33b0844

18 files changed

Lines changed: 143 additions & 78 deletions

File tree

docs/concepts/models/overview.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ Name is ***required*** and must be ***unique***.
9595
### partitioned_by
9696
- Partitioned by is an optional property for engines such as Spark or Hive that support partitioning. Use this to add additional columns to the time column partition key.
9797

98+
### clustered_by
99+
- Clustered by is an optional property for engines such as BigQuery that support clustering.
100+
98101
### tags
99102
- Tags are one or more labels used to organize your models.
100103

sqlmesh/core/engine_adapter/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,9 @@ def _create_table_properties(
857857
storage_format: t.Optional[str] = None,
858858
partitioned_by: t.Optional[t.List[exp.Expression]] = None,
859859
partition_interval_unit: t.Optional[IntervalUnit] = None,
860+
clustered_by: t.Optional[t.List[str]] = None,
860861
) -> t.Optional[exp.Properties]:
862+
"""Creates a SQLGlot table properties expression for ddl."""
861863
return None
862864

863865
def _to_sql(self, e: exp.Expression, **kwargs: t.Any) -> str:

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -337,39 +337,45 @@ def _create_table_properties(
337337
storage_format: t.Optional[str] = None,
338338
partitioned_by: t.Optional[t.List[exp.Expression]] = None,
339339
partition_interval_unit: t.Optional[IntervalUnit] = None,
340+
clustered_by: t.Optional[t.List[str]] = None,
340341
) -> t.Optional[exp.Properties]:
341-
if not partitioned_by:
342-
return None
343-
if partition_interval_unit is None:
344-
raise SQLMeshError("partition_interval_unit is required when partitioning a table")
345-
if len(partitioned_by) > 1:
346-
raise SQLMeshError("BigQuery only supports partitioning by a single column")
347-
348-
this: exp.Expression
349-
if isinstance(partitioned_by[0], exp.Column):
350-
if partition_interval_unit == IntervalUnit.MINUTE:
351-
raise SQLMeshError("BigQuery does not support partitioning by minute")
352-
353-
trunc_func: t.Optional[str] = None
354-
if partition_interval_unit == IntervalUnit.HOUR:
355-
trunc_func = "TIMESTAMP_TRUNC"
356-
elif partition_interval_unit in (IntervalUnit.MONTH, IntervalUnit.YEAR):
357-
trunc_func = "DATE_TRUNC"
358-
359-
if trunc_func:
360-
this = exp.func(
361-
trunc_func,
362-
partitioned_by[0],
363-
exp.var(partition_interval_unit.value.upper()),
364-
dialect=self.dialect,
365-
)
366-
else:
367-
this = partitioned_by[0]
368-
else:
342+
properties: t.List[exp.Expression] = []
343+
344+
if partitioned_by:
345+
if partition_interval_unit is None:
346+
raise SQLMeshError("partition_interval_unit is required when partitioning a table")
347+
if len(partitioned_by) > 1:
348+
raise SQLMeshError("BigQuery only supports partitioning by a single column")
349+
369350
this = partitioned_by[0]
370351

371-
partition_columns_property = exp.PartitionedByProperty(this=this)
372-
return exp.Properties(expressions=[partition_columns_property])
352+
if isinstance(this, exp.Column):
353+
if partition_interval_unit == IntervalUnit.MINUTE:
354+
raise SQLMeshError("BigQuery does not support partitioning by minute")
355+
356+
if partition_interval_unit == IntervalUnit.HOUR:
357+
trunc_func = "TIMESTAMP_TRUNC"
358+
elif partition_interval_unit in (IntervalUnit.MONTH, IntervalUnit.YEAR):
359+
trunc_func = "DATE_TRUNC"
360+
else:
361+
trunc_func = ""
362+
363+
if trunc_func:
364+
this = exp.func(
365+
trunc_func,
366+
this,
367+
exp.var(partition_interval_unit.value.upper()),
368+
dialect=self.dialect,
369+
)
370+
371+
properties.append(exp.PartitionedByProperty(this=this))
372+
373+
if clustered_by:
374+
properties.append(exp.Cluster(expressions=[exp.column(col) for col in clustered_by]))
375+
376+
if properties:
377+
return exp.Properties(expressions=properties)
378+
return None
373379

374380
def create_state_table(
375381
self,

sqlmesh/core/engine_adapter/spark.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -239,27 +239,26 @@ def _create_table_properties(
239239
storage_format: t.Optional[str] = None,
240240
partitioned_by: t.Optional[t.List[exp.Expression]] = None,
241241
partition_interval_unit: t.Optional[IntervalUnit] = None,
242+
clustered_by: t.Optional[t.List[str]] = None,
242243
) -> t.Optional[exp.Properties]:
243-
format_property = None
244-
partition_columns_property = None
244+
properties: t.List[exp.Expression] = []
245+
245246
if storage_format:
246-
format_property = exp.FileFormatProperty(this=exp.Var(this=storage_format))
247+
properties.append(exp.FileFormatProperty(this=exp.Var(this=storage_format)))
247248
if partitioned_by:
248249
for expr in partitioned_by:
249250
if not isinstance(expr, exp.Column):
250251
raise SQLMeshError(
251252
f"PARTITIONED BY contains non-column value '{expr.sql(dialect='spark')}'."
252253
)
253-
partition_columns_property = exp.PartitionedByProperty(
254-
this=exp.Schema(expressions=partitioned_by),
254+
properties.append(
255+
exp.PartitionedByProperty(
256+
this=exp.Schema(expressions=partitioned_by),
257+
)
255258
)
256-
return exp.Properties(
257-
expressions=[
258-
table_property
259-
for table_property in [format_property, partition_columns_property]
260-
if table_property
261-
]
262-
)
259+
if properties:
260+
return exp.Properties(expressions=properties)
261+
return None
263262

264263
def supports_transactions(self, transaction_type: TransactionType) -> bool:
265264
return False

sqlmesh/core/model/definition.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class _Model(ModelMeta, frozen=True):
101101
storage_format: The storage format used to store the physical table, only applicable in certain engines.
102102
(eg. 'parquet')
103103
partitioned_by: The partition columns or engine specific expressions, only applicable in certain engines. (eg. (ds, hour))
104+
clustered_by: The cluster columns, only applicable in certain engines. (eg. (ds, hour))
104105
python_env: Dictionary containing all global variables needed to render the model's macros.
105106
mapping_schema: The schema of table names to column and types.
106107
"""
@@ -556,29 +557,38 @@ def validate_definition(self) -> None:
556557
Raises:
557558
ConfigError
558559
"""
559-
if self.partitioned_by:
560-
unique_partition_keys = {
561-
col.name.strip().lower()
562-
for expr in self.partitioned_by
563-
for col in expr.find_all(exp.Column)
564-
}
565-
if len(self.partitioned_by) != len(unique_partition_keys):
566-
raise_config_error(
567-
"All partition keys must be unique in the model definition",
568-
self._path,
569-
)
570560

571-
columns_to_types = self.columns_to_types
572-
if columns_to_types is not None:
573-
column_names = {c.lower() for c in columns_to_types}
574-
missing_keys = unique_partition_keys - column_names
575-
if missing_keys:
576-
missing_keys_str = ", ".join(f"'{k}'" for k in sorted(missing_keys))
561+
for field in ("partitioned_by", "clustered_by"):
562+
values = getattr(self, field)
563+
564+
if values:
565+
values = [
566+
col.name.lower()
567+
for expr in values
568+
for col in t.cast(
569+
exp.Expression, exp.maybe_parse(expr, dialect=self.dialect)
570+
).find_all(exp.Column)
571+
]
572+
573+
unique_keys = set(values)
574+
575+
if len(values) != len(unique_keys):
577576
raise_config_error(
578-
f"Partition keys [{missing_keys_str}] are missing in the model definition",
577+
"All keys in '{field}' must be unique in the model definition",
579578
self._path,
580579
)
581580

581+
columns_to_types = self.columns_to_types
582+
if columns_to_types is not None:
583+
column_names = {c.lower() for c in columns_to_types}
584+
missing_keys = unique_keys - column_names
585+
if missing_keys:
586+
missing_keys_str = ", ".join(f"'{k}'" for k in sorted(missing_keys))
587+
raise_config_error(
588+
f"{field} keys [{missing_keys_str}] are missing in the model definition",
589+
self._path,
590+
)
591+
582592
if self.kind.is_incremental_by_time_range and not self.time_column:
583593
raise_config_error(
584594
"Incremental by time range models must have a time_column field.",
@@ -1658,6 +1668,7 @@ def _single_expr_or_tuple(values: t.Sequence[exp.Expression]) -> exp.Expression
16581668
"cron": lambda value: exp.Literal.string(value),
16591669
"batch_size": lambda value: exp.Literal.number(value),
16601670
"partitioned_by_": _single_expr_or_tuple,
1671+
"clustered_by": _single_value_or_tuple,
16611672
"depends_on_": lambda value: exp.Tuple(expressions=value),
16621673
"pre": _list_of_calls_to_exp,
16631674
"post": _list_of_calls_to_exp,

sqlmesh/core/model/meta.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class ModelMeta(PydanticModel):
5757
retention: t.Optional[int] # not implemented yet
5858
storage_format: t.Optional[str]
5959
partitioned_by_: t.List[exp.Expression] = Field(default=[], alias="partitioned_by")
60+
clustered_by: t.List[str] = []
6061
depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on")
6162
columns_to_types_: t.Optional[t.Dict[str, exp.DataType]] = Field(default=None, alias="columns")
6263
column_descriptions_: t.Optional[t.Dict[str, str]]
@@ -115,7 +116,7 @@ def extract(v: exp.Expression) -> t.Tuple[str, t.Dict[str, str]]:
115116
]
116117
return v
117118

118-
@validator("tags", "grain", pre=True)
119+
@validator("clustered_by", "tags", "grain", pre=True)
119120
def _value_or_tuple_validator(cls, v: t.Any) -> t.Any:
120121
if isinstance(v, (exp.Tuple, exp.Array)):
121122
return [e.name for e in v.expressions]
@@ -215,8 +216,9 @@ def _date_validator(cls, v: t.Any) -> t.Optional[TimeLike]:
215216
def _kind_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
216217
kind = values.get("kind")
217218
if kind:
218-
if values.get("partitioned_by_") and not kind.is_materialized:
219-
raise ValueError(f"partitioned_by field cannot be set for {kind} models")
219+
for field in ("partitioned_by_", "clustered_by"):
220+
if values.get(field) and not kind.is_materialized:
221+
raise ValueError(f"{field} field cannot be set for {kind} models")
220222

221223
return values
222224

sqlmesh/core/snapshot/definition.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,7 @@ def _model_data_hash(model: Model) -> str:
915915
model.storage_format,
916916
str(model.lookback),
917917
*(expr.sql() for expr in (model.partitioned_by or [])),
918+
*(model.clustered_by or []),
918919
model.stamp,
919920
]
920921

sqlmesh/core/snapshot/evaluator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,7 @@ def create(
646646
storage_format=model.storage_format,
647647
partitioned_by=model.partitioned_by,
648648
partition_interval_unit=model.interval_unit(),
649+
clustered_by=model.clustered_by,
649650
)
650651
else:
651652
self.adapter.ctas(
@@ -655,6 +656,7 @@ def create(
655656
storage_format=model.storage_format,
656657
partitioned_by=model.partitioned_by,
657658
partition_interval_unit=model.interval_unit(),
659+
clustered_by=model.clustered_by,
658660
)
659661

660662
def migrate(self, target_table_name: str, source_table_name: str) -> None:

sqlmesh/dbt/model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,13 @@ def to_sqlmesh(self, context: DbtContext) -> Model:
219219
optional_kwargs["partitioned_by"] = [exp.to_column(val) for val in self.partition_by]
220220
elif self.partition_by and isinstance(self.partition_by, dict):
221221
optional_kwargs["partitioned_by"] = [
222-
d.parse_one(
223-
f"TIMESTAMP_TRUNC(`{self.partition_by['field']}`, {self.partition_by['granularity']})",
224-
dialect=dialect,
222+
exp.TimestampTrunc(
223+
this=exp.to_column(self.partition_by["field"]),
224+
unit=exp.var(self.partition_by["granularity"]),
225225
)
226226
]
227-
227+
if self.cluster_by:
228+
optional_kwargs["clustered_by"] = self.cluster_by
228229
for field in ["cron"]:
229230
field_val = getattr(self, field, None) or self.meta.get(field, None)
230231
if field_val:

tests/core/engine_adapter/test_bigquery.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def test_create_table_date_partition(
169169
{"a": "int", "b": "int"},
170170
partitioned_by=partition_by_cols,
171171
partition_interval_unit=IntervalUnit.DAY,
172+
clustered_by=["b"],
172173
)
173174

174175
sql_calls = [
@@ -179,7 +180,7 @@ def test_create_table_date_partition(
179180
for call in execute_mock.call_args_list
180181
]
181182
assert sql_calls == [
182-
f"CREATE TABLE IF NOT EXISTS `test_table` (`a` int, `b` int) PARTITION BY {partition_by_statement}"
183+
f"CREATE TABLE IF NOT EXISTS `test_table` (`a` int, `b` int) PARTITION BY {partition_by_statement} CLUSTER BY `b`"
183184
]
184185

185186

0 commit comments

Comments
 (0)