Skip to content

Commit dfaeffd

Browse files
authored
feat: support when matched expression for merge (#1569)
1 parent 074a928 commit dfaeffd

10 files changed

Lines changed: 211 additions & 16 deletions

File tree

docs/concepts/models/model_kinds.md

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ If a key is missing in the model's table, the new data row is inserted; otherwis
115115
* There is at most one record associated with each unique key.
116116
* It is appropriate to upsert records, so existing records can be overridden by new arrivals when their keys match.
117117

118-
A [Slowly Changing Dimension](../glossary.md#slowly-changing-dimension-scd) (SCD) is one approach that fits this description well.
118+
A [Slowly Changing Dimension](../glossary.md#slowly-changing-dimension-scd) (SCD) is one approach that fits this description well. See the [SCD Type 2](#scd-type-2) model kind for a specific model kind for SCD Type 2 models.
119119

120120
The name of the unique key column must be provided as part of the `MODEL` DDL, as in this example:
121121
```sql linenums="1" hl_lines="3-5"
@@ -156,6 +156,43 @@ WHERE
156156

157157
**Note:** Models of the `INCREMENTAL_BY_UNIQUE_KEY` kind are inherently [non-idempotent](../glossary.md#idempotency), which should be taken into consideration during data [restatement](../plans.md#restatement-plans).
158158

159+
### Unique Key Expressions
160+
161+
The `unique_key` values can either be column names or SQL expressions. For example, if you wanted to create a key based on coalescing a value with a default, you could do the following:
162+
163+
```sql linenums="1" hl_lines="4"
164+
MODEL (
165+
name db.employees,
166+
kind INCREMENTAL_BY_UNIQUE_KEY (
167+
unique_key (COALESCE("ds", ''))
168+
)
169+
);
170+
```
171+
172+
### When Matched Expression
173+
174+
By default, when a match occurs (the source and target match on the given keys), all columns are updated. This can be overridden with custom logic like below:
175+
176+
```sql linenums="1" hl_lines="4"
177+
MODEL (
178+
name db.employees,
179+
kind INCREMENTAL_BY_UNIQUE_KEY (
180+
unique_key name,
181+
when_matched WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)
182+
)
183+
);
184+
```
185+
186+
The `source` and `target` aliases are required when using the `when_matched` expression in order to distinguish between the source and target columns.
187+
188+
**Note**: `when_matched` is only available on engines that support the `MERGE` statement. Currently supported engines include:
189+
190+
* BigQuery
191+
* Databricks
192+
* Postgres
193+
* Snowflake
194+
* Spark
195+
159196
### Materialization strategy
160197
Depending on the target engine, models of the `INCREMENTAL_BY_UNIQUE_KEY` kind are materialized using the following strategies:
161198

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
"requests",
4747
"rich[jupyter]",
4848
"ruamel.yaml",
49-
"sqlglot~=18.12.0",
49+
"sqlglot~=18.13.0",
5050
],
5151
extras_require={
5252
"bigquery": [

sqlmesh/core/dialect.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,15 +294,17 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
294294
if not key:
295295
return None
296296

297-
if self._match(TokenType.L_PAREN):
298-
value: t.Optional[exp.Expression] = self.expression(
297+
name = key.name.lower()
298+
if name == "when_matched":
299+
value: t.Optional[exp.Expression] = self._parse_when_matched()[0]
300+
elif self._match(TokenType.L_PAREN):
301+
value = self.expression(
299302
exp.Tuple, expressions=self._parse_csv(lambda: _parse_prop_value(self))
300303
)
301304
self._match_r_paren()
302305
else:
303306
value = self._parse_bracket(self._parse_field(any_token=True))
304307

305-
name = key.name.lower()
306308
if name == "path" and value:
307309
# Make sure if we get a windows path that it is converted to posix
308310
value = exp.Literal.string(value.this.replace("\\", "/"))

sqlmesh/core/engine_adapter/base.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,6 +1159,7 @@ def merge(
11591159
source_table: QueryOrDF,
11601160
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
11611161
unique_key: t.Sequence[exp.Expression],
1162+
when_matched: t.Optional[exp.When] = None,
11621163
) -> None:
11631164
source_queries, columns_to_types = self._get_source_queries_and_columns_to_types(
11641165
source_table, columns_to_types, target_table=target_table
@@ -1170,16 +1171,17 @@ def merge(
11701171
for part in unique_key
11711172
)
11721173
)
1173-
when_matched = exp.When(
1174-
matched=True,
1175-
source=False,
1176-
then=exp.Update(
1177-
expressions=[
1178-
exp.column(col, MERGE_TARGET_ALIAS).eq(exp.column(col, MERGE_SOURCE_ALIAS))
1179-
for col in columns_to_types
1180-
],
1181-
),
1182-
)
1174+
if not when_matched:
1175+
when_matched = exp.When(
1176+
matched=True,
1177+
source=False,
1178+
then=exp.Update(
1179+
expressions=[
1180+
exp.column(col, MERGE_TARGET_ALIAS).eq(exp.column(col, MERGE_SOURCE_ALIAS))
1181+
for col in columns_to_types
1182+
],
1183+
),
1184+
)
11831185
when_not_matched = exp.When(
11841186
matched=False,
11851187
source=False,

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def merge(
2424
source_table: QueryOrDF,
2525
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
2626
unique_key: t.Sequence[exp.Expression],
27+
when_matched: t.Optional[exp.When] = None,
2728
) -> None:
2829
"""
2930
Merge implementation for engine adapters that do not support merge natively.
@@ -35,6 +36,10 @@ def merge(
3536
within the temporary table are omitted.
3637
4. Drop the temporary table.
3738
"""
39+
if when_matched:
40+
raise SQLMeshError(
41+
"This engine does not support MERGE expressions and therefore `when_matched` is not supported."
42+
)
3843
if columns_to_types is None:
3944
columns_to_types = self.columns(target_table)
4045

sqlmesh/core/model/kind.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
SQLGlotPositiveInt,
1919
SQLGlotString,
2020
field_validator,
21+
field_validator_v1_args,
2122
model_validator,
2223
model_validator_v1_args,
2324
)
@@ -208,8 +209,39 @@ def to_expression(self, dialect: str = "", **kwargs: t.Any) -> d.ModelKind:
208209
class IncrementalByUniqueKeyKind(_Incremental):
209210
name: Literal[ModelKindName.INCREMENTAL_BY_UNIQUE_KEY] = ModelKindName.INCREMENTAL_BY_UNIQUE_KEY
210211
unique_key: t.List[exp.Expression]
212+
when_matched: t.Optional[exp.When] = None
213+
211214
_unique_key_validator = _unique_key_validator
212215

216+
@field_validator("when_matched", mode="before")
217+
@field_validator_v1_args
218+
def _when_matched_validator(
219+
cls, v: t.Optional[exp.When], values: t.Dict[str, t.Any]
220+
) -> t.Optional[exp.When]:
221+
def replace_table_references(expression: exp.Expression) -> exp.Expression:
222+
from sqlmesh.core.engine_adapter.base import (
223+
MERGE_SOURCE_ALIAS,
224+
MERGE_TARGET_ALIAS,
225+
)
226+
227+
if isinstance(expression, exp.Column):
228+
if expression.table.lower() == "target":
229+
expression.set(
230+
"table",
231+
exp.to_identifier(MERGE_TARGET_ALIAS),
232+
)
233+
elif expression.table.lower() == "source":
234+
expression.set(
235+
"table",
236+
exp.to_identifier(MERGE_SOURCE_ALIAS),
237+
)
238+
return expression
239+
240+
if not v:
241+
return v
242+
v.meta["dialect"] = values.get("dialect")
243+
return v.transform(replace_table_references)
244+
213245

214246
class IncrementalUnmanagedKind(_ModelKind):
215247
name: Literal[ModelKindName.INCREMENTAL_UNMANAGED] = ModelKindName.INCREMENTAL_UNMANAGED

sqlmesh/core/model/meta.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,9 @@ def _partition_by_columns(self) -> t.List[exp.Column]:
354354
@property
355355
def managed_columns(self) -> t.Dict[str, exp.DataType]:
356356
return getattr(self.kind, "managed_columns", {})
357+
358+
@property
359+
def when_matched(self) -> t.Optional[exp.When]:
360+
if isinstance(self.kind, IncrementalByUniqueKeyKind):
361+
return self.kind.when_matched
362+
return None

sqlmesh/core/snapshot/evaluator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,7 @@ def insert(
925925
query_or_df,
926926
columns_to_types=model.columns_to_types,
927927
unique_key=model.unique_key,
928+
when_matched=model.when_matched,
928929
)
929930

930931
def append(
@@ -942,6 +943,7 @@ def append(
942943
query_or_df,
943944
columns_to_types=model.columns_to_types,
944945
unique_key=model.unique_key,
946+
when_matched=model.when_matched,
945947
)
946948

947949

tests/core/engine_adapter/test_base.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,53 @@ def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable):
814814
)
815815

816816

817+
def test_merge_when_matched(make_mocked_engine_adapter: t.Callable, assert_exp_eq):
818+
adapter = make_mocked_engine_adapter(EngineAdapter)
819+
820+
adapter.merge(
821+
target_table="target",
822+
source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')),
823+
columns_to_types={
824+
"ID": exp.DataType.Type.INT,
825+
"ts": exp.DataType.Type.TIMESTAMP,
826+
"val": exp.DataType.Type.INT,
827+
},
828+
unique_key=[exp.to_identifier("ID", quoted=True)],
829+
when_matched=exp.When(
830+
matched=True,
831+
source=False,
832+
then=exp.Update(
833+
expressions=[
834+
exp.column("val", "__MERGE_TARGET__").eq(exp.column("val", "__MERGE_SOURCE__")),
835+
exp.column("ts", "__MERGE_TARGET__").eq(
836+
exp.Coalesce(
837+
this=exp.column("ts", "__MERGE_SOURCE__"),
838+
expressions=[exp.column("ts", "__MERGE_TARGET__")],
839+
)
840+
),
841+
],
842+
),
843+
),
844+
)
845+
846+
assert_exp_eq(
847+
adapter.cursor.execute.call_args[0][0],
848+
"""
849+
MERGE INTO "target" AS "__MERGE_TARGET__" USING (
850+
SELECT
851+
"ID",
852+
"ts",
853+
"val"
854+
FROM "source"
855+
) AS "__MERGE_SOURCE__"
856+
ON "__MERGE_TARGET__"."ID" = "__MERGE_SOURCE__"."ID"
857+
WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val", "__MERGE_TARGET__"."ts" = COALESCE("__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."ts")
858+
WHEN NOT MATCHED THEN INSERT ("ID", "ts", "val")
859+
VALUES ("__MERGE_SOURCE__"."ID", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")
860+
""",
861+
)
862+
863+
817864
def test_scd_type_2(make_mocked_engine_adapter: t.Callable):
818865
adapter = make_mocked_engine_adapter(EngineAdapter)
819866

tests/core/test_snapshot_evaluator.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
from sqlmesh.core.audit import StandaloneAudit
1010
from sqlmesh.core.dialect import to_schema
1111
from sqlmesh.core.engine_adapter import EngineAdapter, create_engine_adapter
12-
from sqlmesh.core.engine_adapter.base import InsertOverwriteStrategy
12+
from sqlmesh.core.engine_adapter.base import (
13+
MERGE_SOURCE_ALIAS,
14+
MERGE_TARGET_ALIAS,
15+
InsertOverwriteStrategy,
16+
)
1317
from sqlmesh.core.environment import EnvironmentNamingInfo
1418
from sqlmesh.core.macros import RuntimeStage, macro
1519
from sqlmesh.core.model import (
@@ -1009,6 +1013,64 @@ def test_insert_into_scd_type_2(adapter_mock, make_snapshot):
10091013
)
10101014

10111015

1016+
def test_create_incremental_by_unique_key_updated_at_exp(adapter_mock, make_snapshot):
1017+
evaluator = SnapshotEvaluator(adapter_mock)
1018+
model = load_sql_based_model(
1019+
parse( # type: ignore
1020+
"""
1021+
MODEL (
1022+
name test_schema.test_model,
1023+
kind INCREMENTAL_BY_UNIQUE_KEY (
1024+
unique_key [id],
1025+
when_matched WHEN MATCHED THEN UPDATE SET target.name = source.name, target.updated_at = COALESCE(source.updated_at, target.updated_at)
1026+
)
1027+
);
1028+
1029+
SELECT id::int, name::string, updated_at::timestamp FROM tbl;
1030+
"""
1031+
)
1032+
)
1033+
1034+
snapshot = make_snapshot(model)
1035+
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
1036+
1037+
evaluator.evaluate(
1038+
snapshot,
1039+
"2020-01-01",
1040+
"2020-01-02",
1041+
"2020-01-02",
1042+
snapshots={},
1043+
)
1044+
1045+
adapter_mock.merge.assert_called_once_with(
1046+
snapshot.table_name(),
1047+
model.render_query(),
1048+
columns_to_types={
1049+
"id": exp.DataType.build("INT"),
1050+
"name": exp.DataType.build("STRING"),
1051+
"updated_at": exp.DataType.build("TIMESTAMP"),
1052+
},
1053+
unique_key=[exp.to_column("id")],
1054+
when_matched=exp.When(
1055+
matched=True,
1056+
source=False,
1057+
then=exp.Update(
1058+
expressions=[
1059+
exp.column("name", MERGE_TARGET_ALIAS).eq(
1060+
exp.column("name", MERGE_SOURCE_ALIAS)
1061+
),
1062+
exp.column("updated_at", MERGE_TARGET_ALIAS).eq(
1063+
exp.Coalesce(
1064+
this=exp.column("updated_at", MERGE_SOURCE_ALIAS),
1065+
expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)],
1066+
)
1067+
),
1068+
],
1069+
),
1070+
),
1071+
)
1072+
1073+
10121074
def test_standalone_audit(mocker: MockerFixture, adapter_mock, make_snapshot):
10131075
evaluator = SnapshotEvaluator(adapter_mock)
10141076

0 commit comments

Comments
 (0)