Skip to content

Commit f07bb06

Browse files
authored
Feat: Support partitioned_by expressions for bigquery (#1041)
1 parent 73c3a32 commit f07bb06

19 files changed

Lines changed: 405 additions & 125 deletions

File tree

sqlmesh/core/dialect.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from sqlglot.tokens import Token
1515

1616
from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE
17+
from sqlmesh.utils.errors import SQLMeshError
1718
from sqlmesh.utils.pandas import columns_to_types_from_df
1819

1920

@@ -463,6 +464,15 @@ class ChunkType(Enum):
463464
SQL = auto()
464465

465466

467+
def parse_one(sql: str, dialect: t.Optional[str] = None) -> exp.Expression:
468+
expressions = parse(sql, default_dialect=dialect)
469+
if not expressions:
470+
raise SQLMeshError(f"No expressions found in '{sql}'")
471+
elif len(expressions) > 1:
472+
raise SQLMeshError(f"Multiple expressions found in '{sql}'")
473+
return expressions[0]
474+
475+
466476
def parse(sql: str, default_dialect: t.Optional[str] = None) -> t.List[exp.Expression]:
467477
"""Parse a sql string.
468478

sqlmesh/core/engine_adapter/base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -855,7 +855,7 @@ def temp_table(self, query_or_df: QueryOrDF, name: str = "diff") -> t.Iterator[e
855855
def _create_table_properties(
856856
self,
857857
storage_format: t.Optional[str] = None,
858-
partitioned_by: t.Optional[t.List[str]] = None,
858+
partitioned_by: t.Optional[t.List[exp.Expression]] = None,
859859
partition_interval_unit: t.Optional[IntervalUnit] = None,
860860
) -> t.Optional[exp.Properties]:
861861
return None

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 23 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -335,28 +335,38 @@ def _fetch_native_df(self, query: t.Union[exp.Expression, str]) -> DF:
335335
def _create_table_properties(
336336
self,
337337
storage_format: t.Optional[str] = None,
338-
partitioned_by: t.Optional[t.List[str]] = None,
338+
partitioned_by: t.Optional[t.List[exp.Expression]] = None,
339339
partition_interval_unit: t.Optional[IntervalUnit] = None,
340340
) -> t.Optional[exp.Properties]:
341341
if not partitioned_by:
342342
return None
343343
if partition_interval_unit is None:
344344
raise SQLMeshError("partition_interval_unit is required when partitioning a table")
345-
if partition_interval_unit == IntervalUnit.MINUTE:
346-
raise SQLMeshError("BigQuery does not support partitioning by minute")
347345
if len(partitioned_by) > 1:
348346
raise SQLMeshError("BigQuery only supports partitioning by a single column")
349-
partition_col = exp.to_column(partitioned_by[0])
350-
this: t.Union[exp.Func, exp.Column]
351-
if partition_interval_unit == IntervalUnit.HOUR:
352-
this = exp.func(
353-
"TIMESTAMP_TRUNC",
354-
partition_col,
355-
exp.var(IntervalUnit.HOUR.value.upper()),
356-
dialect=self.dialect,
357-
)
347+
348+
this: exp.Expression
349+
if isinstance(partitioned_by[0], exp.Column):
350+
if partition_interval_unit == IntervalUnit.MINUTE:
351+
raise SQLMeshError("BigQuery does not support partitioning by minute")
352+
353+
trunc_func: t.Optional[str] = None
354+
if partition_interval_unit == IntervalUnit.HOUR:
355+
trunc_func = "TIMESTAMP_TRUNC"
356+
elif partition_interval_unit in (IntervalUnit.MONTH, IntervalUnit.YEAR):
357+
trunc_func = "DATE_TRUNC"
358+
359+
if trunc_func:
360+
this = exp.func(
361+
trunc_func,
362+
partitioned_by[0],
363+
exp.var(partition_interval_unit.value.upper()),
364+
dialect=self.dialect,
365+
)
366+
else:
367+
this = partitioned_by[0]
358368
else:
359-
this = partition_col
369+
this = partitioned_by[0]
360370

361371
partition_columns_property = exp.PartitionedByProperty(this=this)
362372
return exp.Properties(expressions=[partition_columns_property])

sqlmesh/core/engine_adapter/spark.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -237,18 +237,21 @@ def create_view(
237237
def _create_table_properties(
238238
self,
239239
storage_format: t.Optional[str] = None,
240-
partitioned_by: t.Optional[t.List[str]] = None,
240+
partitioned_by: t.Optional[t.List[exp.Expression]] = None,
241241
partition_interval_unit: t.Optional[IntervalUnit] = None,
242242
) -> t.Optional[exp.Properties]:
243243
format_property = None
244244
partition_columns_property = None
245245
if storage_format:
246246
format_property = exp.FileFormatProperty(this=exp.Var(this=storage_format))
247247
if partitioned_by:
248+
for expr in partitioned_by:
249+
if not isinstance(expr, exp.Column):
250+
raise SQLMeshError(
251+
f"PARTITIONED BY contains non-column value '{expr.sql(dialect='spark')}'."
252+
)
248253
partition_columns_property = exp.PartitionedByProperty(
249-
this=exp.Schema(
250-
expressions=[exp.to_identifier(column) for column in partitioned_by]
251-
),
254+
this=exp.Schema(expressions=partitioned_by),
252255
)
253256
return exp.Properties(
254257
expressions=[

sqlmesh/core/model/definition.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ class _Model(ModelMeta, frozen=True):
100100
lookback: The number of previous incremental intervals in the lookback window.
101101
storage_format: The storage format used to store the physical table, only applicable in certain engines.
102102
(eg. 'parquet')
103-
partitioned_by: The partition columns, only applicable in certain engines. (eg. (ds, hour))
103+
partitioned_by: The partition columns or engine specific expressions, only applicable in certain engines. (eg. (ds, hour))
104104
python_env: Dictionary containing all global variables needed to render the model's macros.
105105
mapping_schema: The schema of table names to column and types.
106106
"""
@@ -557,7 +557,11 @@ def validate_definition(self) -> None:
557557
ConfigError
558558
"""
559559
if self.partitioned_by:
560-
unique_partition_keys = {k.strip().lower() for k in self.partitioned_by}
560+
unique_partition_keys = {
561+
col.name.strip().lower()
562+
for expr in self.partitioned_by
563+
for col in expr.find_all(exp.Column)
564+
}
561565
if len(self.partitioned_by) != len(unique_partition_keys):
562566
raise_config_error(
563567
"All partition keys must be unique in the model definition",
@@ -1644,12 +1648,16 @@ def _single_value_or_tuple(values: t.Sequence) -> exp.Identifier | exp.Tuple:
16441648
)
16451649

16461650

1651+
def _single_expr_or_tuple(values: t.Sequence[exp.Expression]) -> exp.Expression | exp.Tuple:
1652+
return values[0] if len(values) == 1 else exp.Tuple(expressions=values)
1653+
1654+
16471655
META_FIELD_CONVERTER: t.Dict[str, t.Callable] = {
16481656
"name": lambda value: exp.to_table(value),
16491657
"start": lambda value: exp.Literal.string(value),
16501658
"cron": lambda value: exp.Literal.string(value),
16511659
"batch_size": lambda value: exp.Literal.number(value),
1652-
"partitioned_by_": _single_value_or_tuple,
1660+
"partitioned_by_": _single_expr_or_tuple,
16531661
"depends_on_": lambda value: exp.Tuple(expressions=value),
16541662
"pre": _list_of_calls_to_exp,
16551663
"post": _list_of_calls_to_exp,

sqlmesh/core/model/meta.py

Lines changed: 71 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
11
from __future__ import annotations
22

33
import typing as t
4-
from datetime import timedelta
54
from enum import Enum
65

76
from pydantic import Field, root_validator, validator
87
from sqlglot import exp
8+
from sqlglot.helper import ensure_list
99

1010
from sqlmesh.core import dialect as d
1111
from sqlmesh.core.model.kind import (
@@ -15,7 +15,6 @@
1515
ViewKind,
1616
_Incremental,
1717
)
18-
from sqlmesh.utils import unique
1918
from sqlmesh.utils.cron import CroniterCache
2019
from sqlmesh.utils.date import TimeLike, to_datetime
2120
from sqlmesh.utils.errors import ConfigError
@@ -25,15 +24,21 @@
2524
class IntervalUnit(str, Enum):
2625
"""IntervalUnit is the inferred granularity of an incremental model.
2726
28-
IntervalUnit can be one of 4 types, DAY, HOUR, MINUTE. The unit is inferred
27+
IntervalUnit can be one of 5 types, YEAR, MONTH, DAY, HOUR, MINUTE. The unit is inferred
2928
based on the cron schedule of a model. The minimum time delta between a sample set of dates
3029
is used to determine which unit a model's schedule is.
3130
"""
3231

32+
YEAR = "year"
33+
MONTH = "month"
3334
DAY = "day"
3435
HOUR = "hour"
3536
MINUTE = "minute"
3637

38+
@property
39+
def is_date_granularity(self) -> bool:
40+
return self in (IntervalUnit.YEAR, IntervalUnit.MONTH, IntervalUnit.DAY)
41+
3742

3843
AuditReference = t.Tuple[str, t.Dict[str, exp.Expression]]
3944

@@ -51,7 +56,7 @@ class ModelMeta(PydanticModel):
5156
start: t.Optional[TimeLike]
5257
retention: t.Optional[int] # not implemented yet
5358
storage_format: t.Optional[str]
54-
partitioned_by_: t.List[str] = Field(default=[], alias="partitioned_by")
59+
partitioned_by_: t.List[exp.Expression] = Field(default=[], alias="partitioned_by")
5560
depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on")
5661
columns_to_types_: t.Optional[t.Dict[str, exp.DataType]] = Field(default=None, alias="columns")
5762
column_descriptions_: t.Optional[t.Dict[str, str]]
@@ -110,7 +115,7 @@ def extract(v: exp.Expression) -> t.Tuple[str, t.Dict[str, str]]:
110115
]
111116
return v
112117

113-
@validator("partitioned_by_", "tags", "grain", pre=True)
118+
@validator("tags", "grain", pre=True)
114119
def _value_or_tuple_validator(cls, v: t.Any) -> t.Any:
115120
if isinstance(v, (exp.Tuple, exp.Array)):
116121
return [e.name for e in v.expressions]
@@ -136,6 +141,39 @@ def _cron_validator(cls, v: t.Any) -> t.Optional[str]:
136141
raise ConfigError(f"Invalid cron expression '{cron}'")
137142
return cron
138143

144+
@validator("partitioned_by_", pre=True)
145+
def _partition_by_validator(
146+
cls, v: t.Any, values: t.Dict[str, t.Any]
147+
) -> t.List[exp.Expression]:
148+
partitions: t.List[exp.Expression]
149+
if isinstance(v, (exp.Tuple, exp.Array)):
150+
partitions = v.expressions
151+
elif isinstance(v, exp.Expression):
152+
partitions = [v]
153+
else:
154+
dialect = values.get("dialect")
155+
partitions = [
156+
d.parse_one(entry, dialect=dialect) if isinstance(entry, str) else entry
157+
for entry in ensure_list(v)
158+
]
159+
partitions = [
160+
exp.to_column(expr.name) if isinstance(expr, exp.Identifier) else expr
161+
for expr in partitions
162+
]
163+
164+
for partition in partitions:
165+
num_cols = len(list(partition.find_all(exp.Column)))
166+
error_msg: t.Optional[str] = None
167+
if num_cols == 0:
168+
error_msg = "does not contain a column"
169+
elif num_cols > 1:
170+
error_msg = "contains multiple columns"
171+
172+
if error_msg:
173+
raise ConfigError(f"partitioned_by field '{partition}' {error_msg}")
174+
175+
return partitions
176+
139177
@validator("columns_to_types_", pre=True)
140178
def _columns_validator(
141179
cls, v: t.Any, values: t.Dict[str, t.Any]
@@ -194,9 +232,12 @@ def unique_key(self) -> t.List[str]:
194232
return []
195233

196234
@property
197-
def partitioned_by(self) -> t.List[str]:
198-
time_column = [self.time_column.column] if self.time_column else []
199-
return unique([*time_column, *self.partitioned_by_])
235+
def partitioned_by(self) -> t.List[exp.Expression]:
236+
if self.time_column and self.time_column.column not in [
237+
col.name for col in self._partition_by_columns
238+
]:
239+
return [*[exp.to_column(self.time_column.column)], *self.partitioned_by_]
240+
return self.partitioned_by_
200241

201242
@property
202243
def column_descriptions(self) -> t.Dict[str, str]:
@@ -208,18 +249,13 @@ def lookback(self) -> int:
208249
"""The incremental lookback window."""
209250
return (self.kind.lookback if isinstance(self.kind, _Incremental) else 0) or 0
210251

211-
@property
212-
def lookback_delta(self) -> timedelta:
213-
"""The incremental lookback time delta."""
214-
if isinstance(self.kind, _Incremental):
215-
interval_unit = self.interval_unit()
216-
if interval_unit == IntervalUnit.DAY:
217-
return timedelta(days=self.lookback)
218-
if interval_unit == IntervalUnit.HOUR:
219-
return timedelta(hours=self.lookback)
220-
if interval_unit == IntervalUnit.MINUTE:
221-
return timedelta(minutes=self.lookback)
222-
return timedelta()
252+
def lookback_start(self, start: TimeLike) -> TimeLike:
253+
if self.lookback == 0:
254+
return start
255+
256+
for _ in range(self.lookback):
257+
start = self.cron_prev(start)
258+
return start
223259

224260
@property
225261
def batch_size(self) -> t.Optional[int]:
@@ -241,7 +277,11 @@ def interval_unit(self, sample_size: int = 10) -> IntervalUnit:
241277
croniter = CroniterCache(self.cron)
242278
samples = [croniter.get_next() for _ in range(sample_size)]
243279
min_interval = min(b - a for a, b in zip(samples, samples[1:]))
244-
if min_interval >= 86400:
280+
if min_interval >= 31536000:
281+
self._interval_unit = IntervalUnit.YEAR
282+
elif min_interval >= 2419200:
283+
self._interval_unit = IntervalUnit.MONTH
284+
elif min_interval >= 86400:
245285
self._interval_unit = IntervalUnit.DAY
246286
elif min_interval >= 3600:
247287
self._interval_unit = IntervalUnit.HOUR
@@ -252,8 +292,8 @@ def interval_unit(self, sample_size: int = 10) -> IntervalUnit:
252292
def normalized_cron(self) -> str:
253293
"""Returns the UTC normalized cron based on sampling heuristics.
254294
255-
SQLMesh supports 3 interval units, daily, hourly, and minutes. If a job is scheduled
256-
daily at 1PM, the actual intervals are shifted back to midnight UTC.
295+
SQLMesh supports 5 interval units, yearly, monthly, daily, hourly, and minutes. If a
296+
job is scheduled daily at 1PM, the actual intervals are shifted back to midnight UTC.
257297
258298
Returns:
259299
The cron string representing either daily, hourly, or minutes.
@@ -265,6 +305,10 @@ def normalized_cron(self) -> str:
265305
return "0 * * * *"
266306
if unit == IntervalUnit.DAY:
267307
return "0 0 * * *"
308+
if unit == IntervalUnit.MONTH:
309+
return "0 0 1 * *"
310+
if unit == IntervalUnit.YEAR:
311+
return "0 0 1 1 *"
268312
return ""
269313

270314
def croniter(self, value: TimeLike) -> CroniterCache:
@@ -309,3 +353,7 @@ def cron_floor(self, value: TimeLike) -> TimeLike:
309353
The timestamp floor.
310354
"""
311355
return self.croniter(self.cron_next(value)).get_prev()
356+
357+
@property
358+
def _partition_by_columns(self) -> t.List[exp.Column]:
359+
return [col for expr in self.partitioned_by_ for col in expr.find_all(exp.Column)]

sqlmesh/core/plan/definition.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -181,7 +181,7 @@ def _get_end_date(self, end_and_units: t.List[t.Tuple[int, IntervalUnit]]) -> Ti
181181
if end_and_units:
182182
end, unit = max(end_and_units)
183183

184-
if unit == IntervalUnit.DAY:
184+
if unit.is_date_granularity:
185185
return to_date(make_inclusive_end(end))
186186
return end
187187
return now()

sqlmesh/core/snapshot/definition.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -914,7 +914,7 @@ def _model_data_hash(model: Model) -> str:
914914
model.cron,
915915
model.storage_format,
916916
str(model.lookback),
917-
*(model.partitioned_by or []),
917+
*(expr.sql() for expr in (model.partitioned_by or [])),
918918
model.stamp,
919919
]
920920

@@ -1062,7 +1062,7 @@ def merge_intervals(intervals: Intervals) -> Intervals:
10621062

10631063

10641064
def _format_date_time(time_like: TimeLike, unit: t.Optional[IntervalUnit]) -> str:
1065-
if unit is None or unit == IntervalUnit.DAY:
1065+
if unit is None or unit.is_date_granularity:
10661066
return to_ds(time_like)
10671067
return to_datetime(time_like).isoformat()[:19]
10681068

0 commit comments

Comments (0)