Skip to content

Commit 3c1a44a

Browse files
authored
Fix!: Change how partitioned_by is parsed so that partition expressions with specialized AST nodes are captured (#4224)
1 parent af974f8 commit 3c1a44a

7 files changed

Lines changed: 331 additions & 4 deletions

File tree

sqlmesh/core/dialect.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,12 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
610610
value = self.expression(ModelKind, this=kind.value, expressions=props)
611611
elif key == "expression":
612612
value = self._parse_conjunction()
613+
elif key == "partitioned_by":
614+
partitioned_by = self._parse_partitioned_by()
615+
if isinstance(partitioned_by.this, exp.Schema):
616+
value = exp.tuple_(*partitioned_by.this.expressions)
617+
else:
618+
value = partitioned_by.this
613619
else:
614620
value = self._parse_bracket(self._parse_field(any_token=True))
615621

sqlmesh/core/model/meta.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing_extensions import Self
66

77
from pydantic import Field
8-
from sqlglot import Dialect, exp
8+
from sqlglot import Dialect, exp, parse_one
99
from sqlglot.helper import ensure_collection, ensure_list
1010
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1111

@@ -39,6 +39,7 @@
3939
field_validator,
4040
list_of_fields_validator,
4141
model_validator,
42+
get_dialect,
4243
)
4344

4445
if t.TYPE_CHECKING:
@@ -182,6 +183,22 @@ def _gateway_validator(cls, v: t.Any) -> t.Optional[str]:
182183
def _partition_and_cluster_validator(
183184
cls, v: t.Any, info: ValidationInfo
184185
) -> t.List[exp.Expression]:
186+
if (
187+
isinstance(v, list)
188+
and all(isinstance(i, str) for i in v)
189+
and info.field_name == "partitioned_by_"
190+
):
191+
# this branch gets hit when we are deserializing from json because `partitioned_by` is stored as a List[str]
192+
# however, we should only invoke this if the list contains strings because this validator is also
193+
# called by Python models which might pass a List[exp.Expression]
194+
string_to_parse = (
195+
f"({','.join(v)})" # recreate the (a, b, c) part of "partitioned_by (a, b, c)"
196+
)
197+
parsed = parse_one(
198+
string_to_parse, into=exp.PartitionedByProperty, dialect=get_dialect(info)
199+
)
200+
v = parsed.this.expressions if isinstance(parsed.this, exp.Schema) else v
201+
185202
expressions = list_of_fields_validator(v, info.data)
186203

187204
for expression in expressions:
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Remove superfluous exp.Paren references from partitioned_by"""
2+
3+
import json
4+
5+
import pandas as pd
6+
from sqlglot import exp
7+
8+
from sqlmesh.utils.migration import index_text_type
9+
from sqlmesh.utils.migration import blob_text_type
10+
11+
12+
def migrate(state_sync, **kwargs): # type: ignore
13+
engine_adapter = state_sync.engine_adapter
14+
schema = state_sync.schema
15+
snapshots_table = "_snapshots"
16+
index_type = index_text_type(engine_adapter.dialect)
17+
if schema:
18+
snapshots_table = f"{schema}.{snapshots_table}"
19+
20+
new_snapshots = []
21+
updated = False
22+
23+
for (
24+
name,
25+
identifier,
26+
version,
27+
snapshot,
28+
kind_name,
29+
updated_ts,
30+
unpaused_ts,
31+
ttl_ms,
32+
unrestorable,
33+
) in engine_adapter.fetchall(
34+
exp.select(
35+
"name",
36+
"identifier",
37+
"version",
38+
"snapshot",
39+
"kind_name",
40+
"updated_ts",
41+
"unpaused_ts",
42+
"ttl_ms",
43+
"unrestorable",
44+
).from_(snapshots_table),
45+
quote_identifiers=True,
46+
):
47+
parsed_snapshot = json.loads(snapshot)
48+
49+
if partitioned_by := parsed_snapshot["node"].get("partitioned_by"):
50+
new_partitioned_by = []
51+
for item in partitioned_by:
52+
# rewrite '(foo)' to 'foo'
53+
if item.startswith("(") and item.endswith(")"):
54+
item = item[1:-1]
55+
updated = True
56+
new_partitioned_by.append(item)
57+
parsed_snapshot["node"]["partitioned_by"] = new_partitioned_by
58+
59+
new_snapshots.append(
60+
{
61+
"name": name,
62+
"identifier": identifier,
63+
"version": version,
64+
"snapshot": json.dumps(parsed_snapshot),
65+
"kind_name": kind_name,
66+
"updated_ts": updated_ts,
67+
"unpaused_ts": unpaused_ts,
68+
"ttl_ms": ttl_ms,
69+
"unrestorable": unrestorable,
70+
}
71+
)
72+
73+
if new_snapshots and updated:
74+
engine_adapter.delete_from(snapshots_table, "TRUE")
75+
blob_type = blob_text_type(engine_adapter.dialect)
76+
77+
engine_adapter.insert_append(
78+
snapshots_table,
79+
pd.DataFrame(new_snapshots),
80+
columns_to_types={
81+
"name": exp.DataType.build(index_type),
82+
"identifier": exp.DataType.build(index_type),
83+
"version": exp.DataType.build(index_type),
84+
"snapshot": exp.DataType.build(blob_type),
85+
"kind_name": exp.DataType.build(index_type),
86+
"updated_ts": exp.DataType.build("bigint"),
87+
"unpaused_ts": exp.DataType.build("bigint"),
88+
"ttl_ms": exp.DataType.build("bigint"),
89+
"unrestorable": exp.DataType.build("boolean"),
90+
},
91+
)

tests/core/engine_adapter/test_athena.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,51 @@ def test_drop_partitions_from_metastore_uses_batches(
435435
# third call 50-62
436436
assert calls[2][1]["PartitionsToDelete"][0]["Values"][0] == "50"
437437
assert calls[2][1]["PartitionsToDelete"][-1]["Values"][0] == "62"
438+
439+
440+
def test_iceberg_partition_transforms(adapter: AthenaEngineAdapter):
441+
expressions = d.parse(
442+
"""
443+
MODEL (
444+
name test_table,
445+
kind FULL,
446+
table_format iceberg,
447+
partitioned_by (month(business_date), bucket(4, colb), colc)
448+
);
449+
450+
SELECT 1::timestamp AS business_date, 2::varchar as colb, 'foo' as colc;
451+
"""
452+
)
453+
model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions))
454+
455+
assert model.partitioned_by == [
456+
exp.Month(this=exp.column("business_date", quoted=True)),
457+
exp.PartitionedByBucket(
458+
this=exp.column("colb", quoted=True), expression=exp.Literal.number(4)
459+
),
460+
exp.column("colc", quoted=True),
461+
]
462+
463+
adapter.s3_warehouse_location = "s3://bucket/prefix/"
464+
465+
adapter.create_table(
466+
table_name=model.name,
467+
columns_to_types=model.columns_to_types_or_raise,
468+
partitioned_by=model.partitioned_by,
469+
table_format=model.table_format,
470+
)
471+
472+
adapter.ctas(
473+
table_name=model.name,
474+
columns_to_types=model.columns_to_types_or_raise,
475+
partitioned_by=model.partitioned_by,
476+
query_or_df=model.ctas_query(),
477+
table_format=model.table_format,
478+
)
479+
480+
assert to_sql_calls(adapter) == [
481+
# Hive syntax - create table
482+
"""CREATE TABLE IF NOT EXISTS `test_table` (`business_date` TIMESTAMP, `colb` STRING, `colc` STRING) PARTITIONED BY (MONTH(`business_date`), BUCKET(4, `colb`), `colc`) LOCATION 's3://bucket/prefix/test_table/' TBLPROPERTIES ('table_type'='iceberg')""",
483+
# Trino syntax - CTAS
484+
"""CREATE TABLE IF NOT EXISTS "test_table" WITH (table_type='iceberg', partitioning=ARRAY['MONTH(business_date)', 'BUCKET(colb, 4)', 'colc'], location='s3://bucket/prefix/test_table/', is_external=false) AS SELECT CAST("business_date" AS TIMESTAMP) AS "business_date", CAST("colb" AS VARCHAR) AS "colb", CAST("colc" AS VARCHAR) AS "colc" FROM (SELECT CAST(1 AS TIMESTAMP) AS "business_date", CAST(2 AS VARCHAR) AS "colb", 'foo' AS "colc" LIMIT 0) AS "_subquery\"""",
485+
]

tests/core/test_context.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1978,19 +1978,22 @@ def test_plan_audit_intervals(tmp_path: pathlib.Path, capsys, caplog):
19781978
)
19791979
)
19801980

1981-
ctx.plan(
1981+
plan = ctx.plan(
19821982
environment="dev", auto_apply=True, no_prompts=True, start="2025-02-01", end="2025-02-01"
19831983
)
19841984

1985+
date_snapshot = next(s for s in plan.new_snapshots if "date_example" in s.name)
1986+
timestamp_snapshot = next(s for s in plan.new_snapshots if "timestamp_example" in s.name)
1987+
19851988
# Case 1: The timestamp audit should be in the inclusive range ['2025-02-01 00:00:00', '2025-02-01 23:59:59.999999']
19861989
assert (
1987-
"""SELECT COUNT(*) FROM (SELECT ("timestamp_id") AS "timestamp_id" FROM (SELECT * FROM "sqlmesh__sqlmesh_audit"."sqlmesh_audit__timestamp_example__2797548448" AS "sqlmesh_audit__timestamp_example__2797548448" WHERE "timestamp_id" BETWEEN CAST('2025-02-01 00:00:00' AS TIMESTAMP) AND CAST('2025-02-01 23:59:59.999999' AS TIMESTAMP)) AS "_q_0" WHERE TRUE GROUP BY ("timestamp_id") HAVING COUNT(*) > 1) AS "audit\""""
1990+
f"""SELECT COUNT(*) FROM (SELECT ("timestamp_id") AS "timestamp_id" FROM (SELECT * FROM "sqlmesh__sqlmesh_audit"."sqlmesh_audit__timestamp_example__{timestamp_snapshot.version}" AS "sqlmesh_audit__timestamp_example__{timestamp_snapshot.version}" WHERE "timestamp_id" BETWEEN CAST('2025-02-01 00:00:00' AS TIMESTAMP) AND CAST('2025-02-01 23:59:59.999999' AS TIMESTAMP)) AS "_q_0" WHERE TRUE GROUP BY ("timestamp_id") HAVING COUNT(*) > 1) AS "audit\""""
19881991
in caplog.text
19891992
)
19901993

19911994
# Case 2: The date audit should be in the inclusive range ['2025-02-01', '2025-02-01']
19921995
assert (
1993-
"""SELECT COUNT(*) FROM (SELECT ("date_id") AS "date_id" FROM (SELECT * FROM "sqlmesh__sqlmesh_audit"."sqlmesh_audit__date_example__4100277424" AS "sqlmesh_audit__date_example__4100277424" WHERE "date_id" BETWEEN CAST('2025-02-01' AS DATE) AND CAST('2025-02-01' AS DATE)) AS "_q_0" WHERE TRUE GROUP BY ("date_id") HAVING COUNT(*) > 1) AS "audit\""""
1996+
f"""SELECT COUNT(*) FROM (SELECT ("date_id") AS "date_id" FROM (SELECT * FROM "sqlmesh__sqlmesh_audit"."sqlmesh_audit__date_example__{date_snapshot.version}" AS "sqlmesh_audit__date_example__{date_snapshot.version}" WHERE "date_id" BETWEEN CAST('2025-02-01' AS DATE) AND CAST('2025-02-01' AS DATE)) AS "_q_0" WHERE TRUE GROUP BY ("date_id") HAVING COUNT(*) > 1) AS "audit\""""
19941997
in caplog.text
19951998
)
19961999

tests/core/test_model.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,134 @@ def test_render_definition_with_defaults():
15141514
) == d.format_model_expressions(expected_expressions)
15151515

15161516

1517+
def test_render_definition_partitioned_by():
1518+
# no parenthesis in definition, no parenthesis when rendered
1519+
model = load_sql_based_model(
1520+
d.parse(
1521+
f"""
1522+
MODEL (
1523+
name db.table,
1524+
kind FULL,
1525+
partitioned_by a
1526+
);
1527+
1528+
select 1 as a;
1529+
"""
1530+
)
1531+
)
1532+
1533+
assert model.partitioned_by == [exp.column("a", quoted=True)]
1534+
assert (
1535+
model.render_definition()[0].sql(pretty=True)
1536+
== """MODEL (
1537+
name db.table,
1538+
kind FULL,
1539+
partitioned_by "a"
1540+
)"""
1541+
)
1542+
1543+
    # single column wrapped in parenthesis in definition, no parenthesis in rendered
1544+
model = load_sql_based_model(
1545+
d.parse(
1546+
f"""
1547+
MODEL (
1548+
name db.table,
1549+
kind FULL,
1550+
partitioned_by (a)
1551+
);
1552+
1553+
select 1 as a;
1554+
"""
1555+
)
1556+
)
1557+
1558+
assert model.partitioned_by == [exp.column("a", quoted=True)]
1559+
assert (
1560+
model.render_definition()[0].sql(pretty=True)
1561+
== """MODEL (
1562+
name db.table,
1563+
kind FULL,
1564+
partitioned_by "a"
1565+
)"""
1566+
)
1567+
1568+
# multiple columns wrapped in parenthesis in definition, parenthesis in rendered
1569+
model = load_sql_based_model(
1570+
d.parse(
1571+
f"""
1572+
MODEL (
1573+
name db.table,
1574+
kind FULL,
1575+
partitioned_by (a, b)
1576+
);
1577+
1578+
select 1 as a, 2 as b;
1579+
"""
1580+
)
1581+
)
1582+
1583+
assert model.partitioned_by == [exp.column("a", quoted=True), exp.column("b", quoted=True)]
1584+
assert (
1585+
model.render_definition()[0].sql(pretty=True)
1586+
== """MODEL (
1587+
name db.table,
1588+
kind FULL,
1589+
partitioned_by ("a", "b")
1590+
)"""
1591+
)
1592+
1593+
# multiple columns not wrapped in parenthesis in the definition is an error
1594+
with pytest.raises(ParseError, match=r"keyword: 'value' missing"):
1595+
load_sql_based_model(
1596+
d.parse(
1597+
f"""
1598+
MODEL (
1599+
name db.table,
1600+
kind FULL,
1601+
partitioned_by a, b
1602+
);
1603+
1604+
select 1 as a, 2 as b;
1605+
"""
1606+
)
1607+
)
1608+
1609+
# Iceberg transforms / functions
1610+
model = load_sql_based_model(
1611+
d.parse(
1612+
f"""
1613+
MODEL (
1614+
name db.table,
1615+
kind FULL,
1616+
partitioned_by (day(a), truncate(b, 4), bucket(c, 3))
1617+
);
1618+
1619+
select 1 as a, 2 as b, 3 as c;
1620+
"""
1621+
),
1622+
dialect="trino",
1623+
)
1624+
1625+
assert model.partitioned_by == [
1626+
exp.Day(this=exp.column("a", quoted=True)),
1627+
exp.PartitionByTruncate(
1628+
this=exp.column("b", quoted=True), expression=exp.Literal.number(4)
1629+
),
1630+
exp.PartitionedByBucket(
1631+
this=exp.column("c", quoted=True), expression=exp.Literal.number(3)
1632+
),
1633+
]
1634+
assert (
1635+
model.render_definition()[0].sql(pretty=True)
1636+
== """MODEL (
1637+
name db.table,
1638+
dialect trino,
1639+
kind FULL,
1640+
partitioned_by (DAY("a"), TRUNCATE("b", 4), BUCKET("c", 3))
1641+
)"""
1642+
)
1643+
1644+
15171645
def test_cron():
15181646
daily = _Node(name="x", cron="@daily")
15191647
assert to_datetime(daily.cron_prev("2020-01-01")) == to_datetime("2019-12-31")

tests/core/test_snapshot.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2938,3 +2938,37 @@ def check_types(batch, env: str, sql: list[SQL], table: exp.Table, default: int
29382938
)
29392939
snapshot_a = make_snapshot(sql_model)
29402940
assert snapshot_a.check_ready_intervals([(0, 1)], mocker.Mock()) == [(0, 1)]
2941+
2942+
2943+
def test_partitioned_by_roundtrip(make_snapshot: t.Callable):
2944+
sql_model = load_sql_based_model(
2945+
parse("""
2946+
MODEL (
2947+
name test_schema.test_model,
2948+
kind full,
2949+
partitioned_by (a, bucket(4, b), truncate(3, c), month(d))
2950+
);
2951+
SELECT a, b, c, d FROM tbl;
2952+
""")
2953+
)
2954+
snapshot = make_snapshot(sql_model)
2955+
assert isinstance(snapshot, Snapshot)
2956+
assert isinstance(snapshot.node, SqlModel)
2957+
2958+
assert snapshot.node.partitioned_by == [
2959+
exp.column("a", quoted=True),
2960+
exp.PartitionedByBucket(
2961+
this=exp.column("b", quoted=True), expression=exp.Literal.number(4)
2962+
),
2963+
exp.PartitionByTruncate(
2964+
this=exp.column("c", quoted=True), expression=exp.Literal.number(3)
2965+
),
2966+
exp.Month(this=exp.column("d", quoted=True)),
2967+
]
2968+
2969+
# roundtrip through json and ensure we get correct AST nodes on the other end
2970+
serialized = snapshot.json()
2971+
deserialized = snapshot.parse_raw(serialized)
2972+
2973+
assert isinstance(deserialized.node, SqlModel)
2974+
assert deserialized.node.partitioned_by == snapshot.node.partitioned_by

0 commit comments

Comments
 (0)