Skip to content

Commit d3c950d

Browse files
authored
Fix: ensure we keep identifiers in unique keys (#1537)
1 parent 15c8788 commit d3c950d

5 files changed

Lines changed: 55 additions & 23 deletions

File tree

sqlmesh/core/engine_adapter/base.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,10 +1130,7 @@ def merge(
11301130
source=False,
11311131
then=exp.Update(
11321132
expressions=[
1133-
exp.EQ(
1134-
this=exp.column(col, MERGE_TARGET_ALIAS),
1135-
expression=exp.column(col, MERGE_SOURCE_ALIAS),
1136-
)
1133+
exp.column(col, MERGE_TARGET_ALIAS).eq(exp.column(col, MERGE_SOURCE_ALIAS))
11371134
for col in columns_to_types
11381135
],
11391136
),

sqlmesh/core/model/common.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ def parse_expressions(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.Li
6363
results = []
6464

6565
for expr in expressions:
66-
expr = normalize_identifiers(
67-
exp.to_column(expr.name) if isinstance(expr, exp.Identifier) else expr
68-
)
66+
expr = normalize_identifiers(exp.column(expr) if isinstance(expr, exp.Identifier) else expr)
6967
expr.meta["dialect"] = dialect
7068
results.append(expr)
7169

sqlmesh/core/model/definition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def render_query(
245245
"""
246246
return exp.select(
247247
*(
248-
exp.cast(exp.Null(), column_type, copy=False).as_(name, copy=False)
248+
exp.cast(exp.Null(), column_type, copy=False).as_(name, copy=False, quoted=True)
249249
for name, column_type in (self.columns_to_types or {}).items()
250250
),
251251
copy=False,

tests/core/engine_adapter/test_base.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -722,23 +722,37 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]:
722722
adapter.cursor.execute.assert_has_calls([call(x) for x in expected])
723723

724724

725-
def test_merge_upsert(make_mocked_engine_adapter: t.Callable):
725+
def test_merge_upsert(make_mocked_engine_adapter: t.Callable, assert_exp_eq):
726726
adapter = make_mocked_engine_adapter(EngineAdapter)
727727

728728
adapter.merge(
729729
target_table="target",
730-
source_table=t.cast(exp.Select, parse_one("SELECT id, ts, val FROM source")),
730+
source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')),
731731
columns_to_types={
732-
"id": exp.DataType.Type.INT,
732+
"ID": exp.DataType.Type.INT,
733733
"ts": exp.DataType.Type.TIMESTAMP,
734734
"val": exp.DataType.Type.INT,
735735
},
736-
unique_key=[exp.to_identifier("id")],
736+
unique_key=[exp.to_identifier("ID", quoted=True)],
737737
)
738-
adapter.cursor.execute.assert_called_once_with(
739-
'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT "id", "ts", "val" FROM "source") AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" '
740-
'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id", "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val" '
741-
'WHEN NOT MATCHED THEN INSERT ("id", "ts", "val") VALUES ("__MERGE_SOURCE__"."id", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")'
738+
739+
assert_exp_eq(
740+
adapter.cursor.execute.call_args[0][0],
741+
"""
742+
MERGE INTO "target" AS "__MERGE_TARGET__" USING (
743+
SELECT
744+
"ID",
745+
"ts",
746+
"val"
747+
FROM "source"
748+
) AS "__MERGE_SOURCE__"
749+
ON "__MERGE_TARGET__"."ID" = "__MERGE_SOURCE__"."ID"
750+
WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."ID" = "__MERGE_SOURCE__"."ID",
751+
"__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts",
752+
"__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val"
753+
WHEN NOT MATCHED THEN INSERT ("ID", "ts", "val")
754+
VALUES ("__MERGE_SOURCE__"."ID", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")
755+
""",
742756
)
743757

744758
adapter.cursor.reset_mock()

tests/core/test_model.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,16 +1248,30 @@ def test_parse(assert_exp_eq):
12481248
CONST = "bar"
12491249

12501250

1251-
def test_python_model_deps() -> None:
1252-
@model(name="my_model", kind="full", columns={"foo": "int"})
1251+
def test_python_model(assert_exp_eq) -> None:
1252+
@model(name="my_model", kind="full", columns={'"COL"': "int"})
12531253
def my_model(context, **kwargs):
12541254
context.table("foo")
12551255
context.table(model_name=CONST + ".baz")
12561256

1257-
assert model.get_registry()["my_model"].model(
1257+
m = model.get_registry()["my_model"].model(
12581258
module_path=Path("."),
12591259
path=Path("."),
1260-
).depends_on == {"foo", "bar.baz"}
1260+
)
1261+
1262+
assert m.depends_on == {"foo", "bar.baz"}
1263+
assert m.columns_to_types == {"COL": exp.DataType.build("int")}
1264+
assert_exp_eq(
1265+
m.ctas_query(),
1266+
"""
1267+
SELECT
1268+
CAST(NULL AS INT) AS "COL"
1269+
FROM (VALUES
1270+
(1)) AS t(dummy)
1271+
WHERE
1272+
FALSE
1273+
""",
1274+
)
12611275

12621276

12631277
def test_python_models_returning_sql(assert_exp_eq) -> None:
@@ -2163,11 +2177,11 @@ def test_scd_type_2_defaults():
21632177
MODEL (
21642178
name db.table,
21652179
kind SCD_TYPE_2 (
2166-
unique_key id,
2180+
unique_key "ID",
21672181
),
21682182
);
21692183
SELECT
2170-
1 as id,
2184+
1 as "ID",
21712185
'2020-01-01' as ds,
21722186
'2020-01-01' as test_updated_at,
21732187
'2020-01-01' as test_valid_from,
@@ -2176,7 +2190,16 @@ def test_scd_type_2_defaults():
21762190
"""
21772191
)
21782192
scd_type_2_model = load_sql_based_model(view_model_expressions)
2179-
assert scd_type_2_model.unique_key == [exp.to_column("id")]
2193+
assert scd_type_2_model.unique_key == [exp.to_column("ID", quoted=True)]
2194+
assert scd_type_2_model.columns_to_types == {
2195+
"ID": exp.DataType.build("int"),
2196+
"ds": exp.DataType.build("varchar"),
2197+
"test_updated_at": exp.DataType.build("varchar"),
2198+
"test_valid_from": exp.DataType.build("varchar"),
2199+
"test_valid_to": exp.DataType.build("varchar"),
2200+
"valid_from": exp.DataType.build("TIMESTAMP"),
2201+
"valid_to": exp.DataType.build("TIMESTAMP"),
2202+
}
21802203
assert scd_type_2_model.managed_columns == {
21812204
"valid_from": exp.DataType.build("TIMESTAMP"),
21822205
"valid_to": exp.DataType.build("TIMESTAMP"),

0 commit comments

Comments
 (0)