Skip to content

Commit 6f39087

Browse files
authored
Feat: accept expressions in unique key (#1528)
closes #1516
1 parent dbe16e3 commit 6f39087

14 files changed

Lines changed: 204 additions & 163 deletions

File tree

sqlmesh/core/dialect.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,19 @@ def find_tables(expression: exp.Expression, dialect: DialectType = None) -> t.Se
814814
}
815815

816816

817+
def add_table(node: exp.Expression, table: str) -> exp.Expression:
818+
"""Add a table to all columns in an expression."""
819+
820+
def _transform(node: exp.Expression) -> exp.Expression:
821+
if isinstance(node, exp.Column) and not node.table:
822+
return exp.column(node.this, table=table)
823+
if isinstance(node, exp.Identifier):
824+
return exp.column(node, table=table)
825+
return node
826+
827+
return node.transform(_transform)
828+
829+
817830
def transform_values(
818831
values: t.Tuple[t.Any, ...], columns_to_types: t.Dict[str, exp.DataType]
819832
) -> t.Iterator[t.Any]:

sqlmesh/core/engine_adapter/base.py

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from sqlglot.helper import ensure_list
2626
from sqlglot.optimizer.qualify_columns import quote_identifiers
2727

28-
from sqlmesh.core.dialect import select_from_values_for_batch_range
28+
from sqlmesh.core.dialect import add_table, select_from_values_for_batch_range
2929
from sqlmesh.core.engine_adapter.shared import DataObject, TransactionType
3030
from sqlmesh.core.model.kind import TimeColumn
3131
from sqlmesh.core.schema_diff import SchemaDiffer
@@ -911,7 +911,7 @@ def scd_type_2(
911911
self,
912912
target_table: TableName,
913913
source_table: QueryOrDF,
914-
unique_key: t.Sequence[str],
914+
unique_key: t.Sequence[exp.Expression],
915915
valid_from_name: str,
916916
valid_to_name: str,
917917
updated_at_name: str,
@@ -937,7 +937,7 @@ def scd_type_2(
937937
exp.Select() # type: ignore
938938
.with_(
939939
"source",
940-
exp.select(*unmanaged_columns)
940+
exp.select(exp.true().as_("_exists"), *unmanaged_columns)
941941
.distinct(*unique_key)
942942
.from_(source_query.subquery("raw_source")), # type: ignore
943943
)
@@ -964,8 +964,8 @@ def scd_type_2(
964964
"latest",
965965
on=exp.and_(
966966
*[
967-
exp.column(col, table="static").eq(exp.column(col, table="latest"))
968-
for col in unique_key
967+
add_table(key, "static").eq(add_table(key, "latest"))
968+
for key in unique_key
969969
]
970970
),
971971
join_type="left",
@@ -976,7 +976,8 @@ def scd_type_2(
976976
.with_(
977977
"latest_deleted",
978978
exp.select(
979-
*unique_key,
979+
exp.true().as_("_exists"),
980+
*(part.as_(f"_key{i}") for i, part in enumerate(unique_key)),
980981
f"MAX({valid_to_name}) AS {valid_to_name}",
981982
)
982983
.from_("deleted")
@@ -987,34 +988,43 @@ def scd_type_2(
987988
.with_(
988989
"joined",
989990
exp.select(
990-
*(f"latest.{col} AS t_{col}" for col in columns_to_types),
991-
*(f"source.{col} AS s_{col}" for col in unmanaged_columns),
991+
exp.column("_exists", table="source"),
992+
*(
993+
exp.column(col, table="latest").as_(f"t_{col}")
994+
for col in columns_to_types
995+
),
996+
*(exp.column(col, table="source").as_(col) for col in unmanaged_columns),
992997
)
993998
.from_("latest")
994999
.join(
9951000
"source",
9961001
on=exp.and_(
9971002
*[
998-
exp.column(col, table="latest").eq(exp.column(col, table="source"))
999-
for col in unique_key
1003+
add_table(key, "latest").eq(add_table(key, "source"))
1004+
for key in unique_key
10001005
]
10011006
),
10021007
join_type="left",
10031008
)
10041009
.union(
10051010
exp.select(
1006-
*(f"latest.{col} AS t_{col}" for col in columns_to_types),
1007-
*(f"source.{col} AS s_{col}" for col in unmanaged_columns),
1011+
exp.column("_exists", table="source"),
1012+
*(
1013+
exp.column(col, table="latest").as_(f"t_{col}")
1014+
for col in columns_to_types
1015+
),
1016+
*(
1017+
exp.column(col, table="source").as_(col)
1018+
for col in unmanaged_columns
1019+
),
10081020
)
10091021
.from_("latest")
10101022
.join(
10111023
"source",
10121024
on=exp.and_(
10131025
*[
1014-
exp.column(col, table="latest").eq(
1015-
exp.column(col, table="source")
1016-
)
1017-
for col in unique_key
1026+
add_table(key, "latest").eq(add_table(key, "source"))
1027+
for key in unique_key
10181028
]
10191029
),
10201030
join_type="right",
@@ -1025,25 +1035,32 @@ def scd_type_2(
10251035
.with_(
10261036
"updated_rows",
10271037
exp.select(
1028-
*(f"COALESCE(t_{col}, s_{col}) as {col}" for col in unmanaged_columns),
1038+
*(
1039+
exp.func(
1040+
"COALESCE",
1041+
exp.column(f"t_{col}", table="joined"),
1042+
exp.column(col, table="joined"),
1043+
).as_(col)
1044+
for col in unmanaged_columns
1045+
),
10291046
f"""
10301047
CASE
10311048
WHEN t_{valid_from_name} IS NULL
1032-
AND latest_deleted.{unique_key[0]} IS NOT NULL
1049+
AND latest_deleted._exists IS NOT NULL
10331050
THEN CASE
1034-
WHEN latest_deleted.{valid_to_name} > s_{updated_at_name}
1051+
WHEN latest_deleted.{valid_to_name} > {updated_at_name}
10351052
THEN latest_deleted.{valid_to_name}
1036-
ELSE s_{updated_at_name}
1053+
ELSE {updated_at_name}
10371054
END
10381055
WHEN t_{valid_from_name} IS NULL
10391056
THEN {self._to_utc_timestamp('1970-01-01 00:00:00+00:00')}
10401057
ELSE t_{valid_from_name}
10411058
END AS {valid_from_name}""",
10421059
f"""
10431060
CASE
1044-
WHEN s_{updated_at_name} > t_{updated_at_name}
1045-
THEN s_{updated_at_name}
1046-
WHEN s_{unique_key[0]} IS NULL
1061+
WHEN {updated_at_name} > t_{updated_at_name}
1062+
THEN {updated_at_name}
1063+
WHEN joined._exists IS NULL
10471064
THEN {self._to_utc_timestamp(to_ts(execution_time))}
10481065
ELSE t_{valid_to_name}
10491066
END AS {valid_to_name}""",
@@ -1053,10 +1070,10 @@ def scd_type_2(
10531070
"latest_deleted",
10541071
on=exp.and_(
10551072
*[
1056-
exp.column(f"s_{col}", table="joined").eq(
1057-
exp.column(col, table="latest_deleted")
1073+
add_table(part, "joined").eq(
1074+
exp.column(f"_key{i}", "latest_deleted")
10581075
)
1059-
for col in unique_key
1076+
for i, part in enumerate(unique_key)
10601077
]
10611078
),
10621079
join_type="left",
@@ -1066,14 +1083,12 @@ def scd_type_2(
10661083
.with_(
10671084
"inserted_rows",
10681085
exp.select(
1069-
*(f"s_{col} as {col}" for col in unmanaged_columns),
1070-
f"s_{updated_at_name} as {valid_from_name}",
1086+
*unmanaged_columns,
1087+
f"{updated_at_name} as {valid_from_name}",
10711088
f"{self._to_utc_timestamp(exp.null())} as {valid_to_name}",
10721089
)
10731090
.from_("joined")
1074-
.where(
1075-
f"t_{unique_key[0]} IS NOT NULL AND s_{unique_key[0]} IS NOT NULL AND s_{updated_at_name} > t_{updated_at_name}"
1076-
),
1091+
.where(f"{updated_at_name} > t_{updated_at_name}"),
10771092
)
10781093
.select("*")
10791094
.from_("static")
@@ -1097,18 +1112,15 @@ def merge(
10971112
target_table: TableName,
10981113
source_table: QueryOrDF,
10991114
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
1100-
unique_key: t.Sequence[str],
1115+
unique_key: t.Sequence[exp.Expression],
11011116
) -> None:
11021117
source_queries, columns_to_types = self._get_source_queries_and_columns_to_types(
11031118
source_table, columns_to_types, target_table=target_table
11041119
)
11051120
columns_to_types = columns_to_types or self.columns(target_table)
11061121
on = exp.and_(
11071122
*(
1108-
exp.EQ(
1109-
this=exp.column(part, MERGE_TARGET_ALIAS),
1110-
expression=exp.column(part, MERGE_SOURCE_ALIAS),
1111-
)
1123+
add_table(part, MERGE_TARGET_ALIAS).eq(add_table(part, MERGE_SOURCE_ALIAS))
11121124
for part in unique_key
11131125
)
11141126
)

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def merge(
2323
target_table: TableName,
2424
source_table: QueryOrDF,
2525
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
26-
unique_key: t.Sequence[str],
26+
unique_key: t.Sequence[exp.Expression],
2727
) -> None:
2828
"""
2929
Merge implementation for engine adapters that do not support merge natively.

sqlmesh/core/model/common.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import typing as t
44

55
from sqlglot import exp
6-
from sqlglot.helper import seq_get
6+
from sqlglot.helper import ensure_list, seq_get
7+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
78

89
from sqlmesh.core.dialect import parse, parse_one
910
from sqlmesh.utils import str_to_bool
@@ -45,6 +46,32 @@ def parse_expression(
4546
return v
4647

4748

49+
@field_validator_v1_args
50+
def parse_expressions(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.List[exp.Expression]:
51+
dialect = values.get("dialect")
52+
53+
if isinstance(v, (exp.Tuple, exp.Array)):
54+
expressions: t.List[exp.Expression] = v.expressions
55+
elif isinstance(v, exp.Expression):
56+
expressions = [v]
57+
else:
58+
expressions = [
59+
parse_one(entry, dialect=dialect) if isinstance(entry, str) else entry
60+
for entry in ensure_list(v)
61+
]
62+
63+
results = []
64+
65+
for expr in expressions:
66+
expr = normalize_identifiers(
67+
exp.to_column(expr.name) if isinstance(expr, exp.Identifier) else expr
68+
)
69+
expr.meta["dialect"] = dialect
70+
results.append(expr)
71+
72+
return results
73+
74+
4875
def parse_bool(v: t.Any) -> bool:
4976
if isinstance(v, exp.Boolean):
5077
return v.this
@@ -92,6 +119,7 @@ def parse_properties(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.Opt
92119
"expressions_",
93120
"pre_statements_",
94121
"post_statements_",
122+
"unique_key",
95123
mode="before",
96124
check_fields=False,
97125
)(parse_expression)

sqlmesh/core/model/definition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ def _data_hash_values(self) -> t.List[str]:
721721
data.append(self.kind.time_column.column)
722722
data.append(self.kind.time_column.format)
723723
elif isinstance(self.kind, IncrementalByUniqueKeyKind):
724-
data.extend(self.kind.unique_key)
724+
data.extend((k.sql() for k in self.kind.unique_key))
725725

726726
return data # type: ignore
727727

sqlmesh/core/model/kind.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
from sqlglot.time import format_time
1010

1111
from sqlmesh.core import dialect as d
12-
from sqlmesh.core.model.common import parse_properties
12+
from sqlmesh.core.model.common import parse_expressions, parse_properties
1313
from sqlmesh.core.model.seed import CsvSettings
1414
from sqlmesh.utils.errors import ConfigError
1515
from sqlmesh.utils.pydantic import (
1616
PydanticModel,
1717
SQLGlotBool,
18-
SQLGlotListOfStrings,
1918
SQLGlotPositiveInt,
2019
SQLGlotString,
2120
field_validator,
@@ -104,15 +103,7 @@ def model_kind_name(self) -> t.Optional[ModelKindName]:
104103
return self
105104

106105

107-
def _unique_key_validator(v: t.Any) -> t.List[str]:
108-
if isinstance(v, exp.Identifier):
109-
return [v.name]
110-
if isinstance(v, (exp.Tuple, exp.Array)):
111-
return [e.name for e in v.expressions]
112-
return [i.name if isinstance(i, exp.Identifier) else str(i) for i in v]
113-
114-
115-
unique_key_validator = field_validator("unique_key", mode="before")(_unique_key_validator)
106+
_unique_key_validator = field_validator("unique_key", mode="before")(parse_expressions)
116107

117108

118109
class _ModelKind(PydanticModel, ModelKindMixin):
@@ -216,7 +207,8 @@ def to_expression(self, dialect: str = "", **kwargs: t.Any) -> d.ModelKind:
216207

217208
class IncrementalByUniqueKeyKind(_Incremental):
218209
name: Literal[ModelKindName.INCREMENTAL_BY_UNIQUE_KEY] = ModelKindName.INCREMENTAL_BY_UNIQUE_KEY
219-
unique_key: SQLGlotListOfStrings
210+
unique_key: t.List[exp.Expression]
211+
_unique_key_validator = _unique_key_validator
220212

221213

222214
class IncrementalUnmanagedKind(_ModelKind):
@@ -281,13 +273,14 @@ class FullKind(_ModelKind):
281273

282274
class SCDType2Kind(_ModelKind):
283275
name: Literal[ModelKindName.SCD_TYPE_2] = ModelKindName.SCD_TYPE_2
284-
unique_key: SQLGlotListOfStrings
276+
unique_key: t.List[exp.Expression]
285277
valid_from_name: SQLGlotString = "valid_from"
286278
valid_to_name: SQLGlotString = "valid_to"
287279
updated_at_name: SQLGlotString = "updated_at"
288280

289281
forward_only: SQLGlotBool = True
290282
disable_restatement: SQLGlotBool = True
283+
_unique_key_validator = _unique_key_validator
291284

292285
@property
293286
def managed_columns(self) -> t.Dict[str, exp.DataType]:

sqlmesh/core/model/meta.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
99

1010
from sqlmesh.core import dialect as d
11-
from sqlmesh.core.model.common import bool_validator, properties_validator
11+
from sqlmesh.core.model.common import (
12+
bool_validator,
13+
parse_expressions,
14+
properties_validator,
15+
)
1216
from sqlmesh.core.model.kind import (
1317
IncrementalByUniqueKeyKind,
1418
ModelKind,
@@ -139,27 +143,9 @@ def _string_validator(cls, v: t.Any) -> t.Optional[str]:
139143
def _partition_by_validator(
140144
cls, v: t.Any, values: t.Dict[str, t.Any]
141145
) -> t.List[exp.Expression]:
142-
dialect = values.get("dialect")
143-
144-
if isinstance(v, (exp.Tuple, exp.Array)):
145-
partitions: t.List[exp.Expression] = v.expressions
146-
elif isinstance(v, exp.Expression):
147-
partitions = [v]
148-
else:
149-
partitions = [
150-
d.parse_one(entry, dialect=dialect) if isinstance(entry, str) else entry
151-
for entry in ensure_list(v)
152-
]
153-
154-
partitions = [
155-
normalize_identifiers(
156-
exp.to_column(expr.name) if isinstance(expr, exp.Identifier) else expr
157-
)
158-
for expr in partitions
159-
]
146+
partitions = parse_expressions(cls, v, values)
160147

161148
for partition in partitions:
162-
partition.meta["dialect"] = dialect
163149
num_cols = len(list(partition.find_all(exp.Column)))
164150

165151
error_msg: t.Optional[str] = None
@@ -289,7 +275,7 @@ def time_column(self) -> t.Optional[TimeColumn]:
289275
return getattr(self.kind, "time_column", None)
290276

291277
@property
292-
def unique_key(self) -> t.List[str]:
278+
def unique_key(self) -> t.List[exp.Expression]:
293279
if isinstance(self.kind, (IncrementalByUniqueKeyKind, SCDType2Kind)):
294280
return self.kind.unique_key
295281
return []

0 commit comments

Comments (0)