Skip to content

Commit 589713c

Browse files
committed
refactor partition
1 parent cc3090a commit 589713c

File tree

2 files changed

+250
-66
lines changed

2 files changed

+250
-66
lines changed

sqlmesh/core/engine_adapter/doris.py

Lines changed: 87 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,51 @@ def _create_table_from_columns(
485485
**kwargs,
486486
)
487487

488+
def _parse_partition_expressions(
    self, partitioned_by: t.List[exp.Expression]
) -> t.Tuple[t.List[exp.Expression], t.Optional[str]]:
    """Unwrap ``RANGE(...)`` / ``LIST(...)`` partition wrappers into plain columns.

    Two wrapper spellings are recognised:

    * an ``exp.Anonymous`` function call whose name is ``RANGE`` or ``LIST``
      (case-insensitive); its arguments become the partition columns, with
      non-column arguments converted via ``exp.to_column(str(arg))``;
    * a string ``exp.Literal`` whose text looks like ``RANGE(a, b)`` or
      ``LIST(a, b)``; the comma-separated names (surrounding whitespace and
      backticks removed) become the partition columns.

    Any other expression is passed through unchanged.

    Args:
        partitioned_by: The partition expressions to normalize.

    Returns:
        Tuple of (normalized_partitioned_by, partition_kind). ``partition_kind``
        is ``"RANGE"``, ``"LIST"``, or ``None`` when no wrapper was recognised.
        If several wrappers are present, the kind of the last one wins.
    """
    import logging

    # Compiled once per call instead of once per expression.
    wrapper_pattern = re.compile(r"^\s*(RANGE|LIST)\s*\((.*?)\)\s*$", flags=re.IGNORECASE)

    parsed_partitioned_by: t.List[exp.Expression] = []
    partition_kind: t.Optional[str] = None

    for expr in partitioned_by:
        # Results for this expression are staged locally and only committed once
        # the whole wrapper has been processed, so a failure halfway through can
        # never leave a mix of extracted columns plus the original expression.
        kind: t.Optional[str] = None
        columns: t.List[exp.Expression] = []
        try:
            if isinstance(expr, exp.Anonymous) and expr.this:
                # Function-call form, e.g. RANGE(col) or LIST(col).
                func_name = str(expr.this).upper()
                if func_name in ("RANGE", "LIST"):
                    columns = [
                        arg if isinstance(arg, exp.Column) else exp.to_column(str(arg))
                        for arg in expr.expressions
                    ]
                    kind = func_name
            elif isinstance(expr, exp.Literal) and getattr(expr, "is_string", False):
                # String-literal form, e.g. "RANGE(col)" or "LIST(col)".
                match = wrapper_pattern.match(str(expr.this))
                if match:
                    columns = [
                        exp.to_column(c.strip().strip("`"))
                        for c in match.group(2).split(",")
                        if c.strip()
                    ]
                    kind = match.group(1).upper()
        except Exception:
            # Deliberate best-effort: fall back to the untouched expression, but
            # leave a trace instead of failing silently.
            logging.getLogger(__name__).debug(
                "Could not normalize partition expression %r; keeping it as-is",
                expr,
                exc_info=True,
            )
            kind = None

        if kind is None:
            parsed_partitioned_by.append(expr)
        else:
            partition_kind = kind
            parsed_partitioned_by.extend(columns)

    return parsed_partitioned_by, partition_kind
532+
488533
def _build_partitioned_by_exp(
489534
self,
490535
partitioned_by: t.List[exp.Expression],
@@ -505,12 +550,11 @@ def _build_partitioned_by_exp(
505550
506551
Supports both RANGE and LIST partition syntaxes using sqlglot's doris dialect nodes.
507552
The partition kind is chosen by:
508-
- kwargs["partition_kind"] if provided (expects 'RANGE' or 'LIST', case-insensitive)
553+
- inferred from partitioned_by expressions like 'RANGE(col)' or 'LIST(col)'
509554
- otherwise inferred from the provided 'partitions' strings: if any contains 'VALUES IN' -> LIST; else RANGE.
510555
"""
511556
partitions = kwargs.get("partitions")
512557
create_expressions = None
513-
partition_kind: t.Optional[str] = kwargs.get("partition_kind")
514558

515559
def to_raw_sql(expr: t.Union[exp.Literal, exp.Var, str, t.Any]) -> exp.Var:
516560
# If it's a Literal, extract the string and wrap as Var (no quotes)
@@ -525,6 +569,9 @@ def to_raw_sql(expr: t.Union[exp.Literal, exp.Var, str, t.Any]) -> exp.Var:
525569
# Fallback: return as is
526570
return expr
527571

572+
# Parse partition kind and columns from partitioned_by expressions
573+
partitioned_by, partition_kind = self._parse_partition_expressions(partitioned_by)
574+
528575
if partitions:
529576
if isinstance(partitions, exp.Tuple):
530577
create_expressions = [
@@ -555,8 +602,8 @@ def to_raw_sql(expr: t.Union[exp.Literal, exp.Var, str, t.Any]) -> exp.Var:
555602
try:
556603
if is_list:
557604
return exp.PartitionByListProperty(
558-
this=exp.Schema(expressions=partitioned_by),
559-
partitions=create_expressions,
605+
partition_expressions=partitioned_by,
606+
create_expressions=create_expressions,
560607
)
561608
return exp.PartitionByRangeProperty(
562609
partition_expressions=partitioned_by,
@@ -800,46 +847,43 @@ def _parse_trigger_string(
800847
)
801848

802849
# Handle duplicate_key - only handle Tuple expressions or single Column expressions
850+
# Both tables and materialized views support duplicate keys in Doris
803851
duplicate_key = table_properties_copy.pop("duplicate_key", None)
804852
if duplicate_key is not None:
805-
if not is_materialized_view:
806-
if isinstance(duplicate_key, exp.Tuple):
807-
# Extract column names from Tuple expressions
808-
column_names = []
809-
for expr in duplicate_key.expressions:
810-
if (
811-
isinstance(expr, exp.Column)
812-
and hasattr(expr, "this")
813-
and hasattr(expr.this, "this")
814-
):
815-
column_names.append(str(expr.this.this))
816-
elif hasattr(expr, "this"):
817-
column_names.append(str(expr.this))
818-
else:
819-
column_names.append(str(expr))
820-
properties.append(
821-
exp.DuplicateKeyProperty(
822-
expressions=[exp.to_column(k) for k in column_names]
823-
)
824-
)
825-
elif isinstance(duplicate_key, exp.Column):
826-
# Handle as single column
827-
if hasattr(duplicate_key, "this") and hasattr(duplicate_key.this, "this"):
828-
column_name = str(duplicate_key.this.this)
853+
if isinstance(duplicate_key, exp.Tuple):
854+
# Extract column names from Tuple expressions
855+
column_names = []
856+
for expr in duplicate_key.expressions:
857+
if (
858+
isinstance(expr, exp.Column)
859+
and hasattr(expr, "this")
860+
and hasattr(expr.this, "this")
861+
):
862+
column_names.append(str(expr.this.this))
863+
elif hasattr(expr, "this"):
864+
column_names.append(str(expr.this))
829865
else:
830-
column_name = str(duplicate_key.this)
831-
properties.append(
832-
exp.DuplicateKeyProperty(expressions=[exp.to_column(column_name)])
833-
)
834-
elif isinstance(duplicate_key, exp.Literal):
835-
properties.append(
836-
exp.DuplicateKeyProperty(expressions=[exp.to_column(duplicate_key.this)])
837-
)
838-
elif isinstance(duplicate_key, str):
839-
properties.append(
840-
exp.DuplicateKeyProperty(expressions=[exp.to_column(duplicate_key)])
841-
)
842-
# Note: Materialized views don't typically use duplicate_key, so we skip it
866+
column_names.append(str(expr))
867+
properties.append(
868+
exp.DuplicateKeyProperty(expressions=[exp.to_column(k) for k in column_names])
869+
)
870+
elif isinstance(duplicate_key, exp.Column):
871+
# Handle as single column
872+
if hasattr(duplicate_key, "this") and hasattr(duplicate_key.this, "this"):
873+
column_name = str(duplicate_key.this.this)
874+
else:
875+
column_name = str(duplicate_key.this)
876+
properties.append(
877+
exp.DuplicateKeyProperty(expressions=[exp.to_column(column_name)])
878+
)
879+
elif isinstance(duplicate_key, exp.Literal):
880+
properties.append(
881+
exp.DuplicateKeyProperty(expressions=[exp.to_column(duplicate_key.this)])
882+
)
883+
elif isinstance(duplicate_key, str):
884+
properties.append(
885+
exp.DuplicateKeyProperty(expressions=[exp.to_column(duplicate_key)])
886+
)
843887

844888
if table_description:
845889
properties.append(
@@ -851,30 +895,8 @@ def _parse_trigger_string(
851895
# Handle partitioning
852896
add_partition = True
853897
if partitioned_by:
854-
normalized_partitioned_by: t.List[exp.Expression] = []
855-
for expr in partitioned_by:
856-
try:
857-
# Handle literal strings like "RANGE(col)" or "LIST(col)"
858-
if isinstance(expr, exp.Literal) and getattr(expr, "is_string", False):
859-
text = str(expr.this)
860-
match = re.match(
861-
r"^\s*(RANGE|LIST)\s*\((.*?)\)\s*$", text, flags=re.IGNORECASE
862-
)
863-
if match:
864-
inner = match.group(2)
865-
inner_cols = [
866-
c.strip().strip("`") for c in inner.split(",") if c.strip()
867-
]
868-
for col in inner_cols:
869-
normalized_partitioned_by.append(exp.to_column(col))
870-
continue
871-
except Exception:
872-
# If anything goes wrong, keep the original expr
873-
pass
874-
normalized_partitioned_by.append(expr)
875-
876-
# Replace with normalized expressions
877-
partitioned_by = normalized_partitioned_by
898+
# Parse and normalize partition expressions
899+
partitioned_by, _ = self._parse_partition_expressions(partitioned_by)
878900
# For tables, check if partitioned_by columns are in unique_key; for materialized views, allow regardless
879901
if unique_key is not None and not is_materialized_view:
880902
# Extract key column names from unique_key (only Tuple or Column expressions)

tests/core/engine_adapter/test_doris.py

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_create_table_with_partitioned_by(
280280
adapter.create_table(
281281
"test_table",
282282
target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("DATE")},
283-
partitioned_by=[exp.to_column("b")],
283+
partitioned_by=[exp.Literal.string("RANGE(b)")],
284284
table_properties={
285285
"partitions": exp.Literal.string(
286286
"FROM ('2000-11-14') TO ('2021-11-14') INTERVAL 2 YEAR"
@@ -293,6 +293,168 @@ def test_create_table_with_partitioned_by(
293293
]
294294

295295

296+
def test_create_table_with_range_partitioned_by_with_partitions(
    make_mocked_engine_adapter: t.Callable[..., DorisEngineAdapter],
):
    """A ``RANGE(ds)`` string literal plus explicit partition clauses renders one CREATE TABLE."""

    def quoted(name: str) -> exp.Column:
        # Column wrapping a quoted identifier, as used on both sides of the
        # distributed_by key/value pairs below.
        return exp.Column(this=exp.Identifier(this=name, quoted=True))

    partition_clauses = [
        'PARTITION `p2023` VALUES [("2023-01-01"), ("2024-01-01"))',
        'PARTITION `p2024` VALUES [("2024-01-01"), ("2025-01-01"))',
        'PARTITION `p2025` VALUES [("2025-01-01"), ("2026-01-01"))',
        "PARTITION `other` VALUES LESS THAN MAXVALUE",
    ]
    # Insertion order matters: it is the order expected inside PROPERTIES (...).
    storage_settings = {
        "replication_allocation": "tag.location.default: 3",
        "in_memory": "false",
        "storage_format": "V2",
        "disable_auto_compaction": "false",
    }

    distribution = exp.Tuple(
        expressions=[
            exp.EQ(this=quoted("kind"), expression=exp.Literal.string("HASH")),
            exp.EQ(this=quoted("expressions"), expression=quoted("id")),
            exp.EQ(this=quoted("buckets"), expression=exp.Literal.number(10)),
        ]
    )
    table_properties = {
        "partitions": exp.Tuple(
            expressions=[exp.Literal.string(clause) for clause in partition_clauses]
        ),
        "distributed_by": distribution,
    }
    for key, value in storage_settings.items():
        table_properties[key] = exp.Literal.string(value)

    column_types = {
        "id": exp.DataType.build("INT"),
        "waiter_id": exp.DataType.build("INT"),
        "customer_id": exp.DataType.build("INT"),
        "ds": exp.DataType.build("DATETIME"),
    }

    adapter = make_mocked_engine_adapter(DorisEngineAdapter)
    adapter.create_table(
        "test_table",
        target_columns_to_types=column_types,
        partitioned_by=[exp.Literal.string("RANGE(ds)")],
        table_properties=table_properties,
    )

    rendered_partitions = ", ".join(partition_clauses)
    rendered_settings = ", ".join(
        f"'{key}'='{value}'" for key, value in storage_settings.items()
    )
    expected_sql = (
        "CREATE TABLE IF NOT EXISTS `test_table` "
        "(`id` INT, `waiter_id` INT, `customer_id` INT, `ds` DATETIME) "
        "PARTITION BY RANGE (`ds`) "
        f"({rendered_partitions}) "
        "DISTRIBUTED BY HASH (`id`) BUCKETS 10 "
        f"PROPERTIES ({rendered_settings})"
    )

    assert to_sql_calls(adapter) == [expected_sql]
358+
359+
360+
def test_create_table_with_list_partitioned_by(
    make_mocked_engine_adapter: t.Callable[..., DorisEngineAdapter],
):
    """A ``LIST(status)`` string literal with ``VALUES IN`` clauses renders PARTITION BY LIST."""
    partition_clauses = [
        'PARTITION `active` VALUES IN ("active", "pending")',
        'PARTITION `inactive` VALUES IN ("inactive", "disabled")',
    ]
    column_types = {
        "id": exp.DataType.build("INT"),
        "status": exp.DataType.build("VARCHAR(10)"),
    }
    partitions = exp.Tuple(
        expressions=[exp.Literal.string(clause) for clause in partition_clauses]
    )

    adapter = make_mocked_engine_adapter(DorisEngineAdapter)
    adapter.create_table(
        "test_table",
        target_columns_to_types=column_types,
        partitioned_by=[exp.Literal.string("LIST(status)")],
        table_properties={"partitions": partitions},
    )

    rendered_partitions = ", ".join(partition_clauses)
    expected_sql = (
        "CREATE TABLE IF NOT EXISTS `test_table` "
        "(`id` INT, `status` VARCHAR(10)) "
        "PARTITION BY LIST (`status`) "
        f"({rendered_partitions})"
    )

    assert to_sql_calls(adapter) == [expected_sql]
390+
391+
392+
def test_create_table_with_range_partitioned_by_anonymous_function(
    make_mocked_engine_adapter: t.Callable[..., DorisEngineAdapter],
):
    """Check that a RANGE(ds) function-call expression renders a single RANGE keyword."""
    interval_spec = 'FROM ("2000-11-14") TO ("2099-11-14") INTERVAL 1 MONTH'
    # Mirrors what `partitioned_by RANGE(ds)` looks like once parsed from a model definition.
    range_call = exp.Anonymous(this="RANGE", expressions=[exp.to_column("ds")])
    column_types = {
        "id": exp.DataType.build("INT"),
        "ds": exp.DataType.build("DATETIME"),
    }

    adapter = make_mocked_engine_adapter(DorisEngineAdapter)
    adapter.create_table(
        "test_table",
        target_columns_to_types=column_types,
        partitioned_by=[range_call],
        table_properties={"partitions": exp.Literal.string(interval_spec)},
    )

    expected_sql = (
        "CREATE TABLE IF NOT EXISTS `test_table` "
        "(`id` INT, `ds` DATETIME) "
        "PARTITION BY RANGE (`ds`) "
        f"({interval_spec})"
    )

    assert to_sql_calls(adapter) == [expected_sql]
420+
421+
422+
def test_create_materialized_view_with_duplicate_key(
    make_mocked_engine_adapter: t.Callable[..., DorisEngineAdapter],
):
    """A ``duplicate_key`` view property is rendered as DUPLICATE KEY on a materialized view."""
    query = parse_one("SELECT id, status, COUNT(*) as cnt FROM orders GROUP BY id, status")
    column_types = {
        "id": exp.DataType.build("INT"),
        "status": exp.DataType.build("VARCHAR(10)"),
        "cnt": exp.DataType.build("BIGINT"),
    }
    key_columns = exp.Tuple(expressions=[exp.to_column(name) for name in ("id", "status")])

    adapter = make_mocked_engine_adapter(DorisEngineAdapter)
    adapter.create_view(
        "test_mv",
        query,
        materialized=True,
        target_columns_to_types=column_types,
        view_properties={"duplicate_key": key_columns},
    )

    drop_sql = "DROP MATERIALIZED VIEW IF EXISTS `test_mv`"
    create_sql = (
        "CREATE MATERIALIZED VIEW `test_mv` "
        "(`id`, `status`, `cnt`) "
        "DUPLICATE KEY (`id`, `status`) "
        "AS SELECT `id`, `status`, COUNT(*) AS `cnt` FROM `orders` GROUP BY `id`, `status`"
    )

    assert to_sql_calls(adapter) == [drop_sql, create_sql]
456+
457+
296458
def test_create_full_materialized_view(
297459
make_mocked_engine_adapter: t.Callable[..., DorisEngineAdapter],
298460
):

0 commit comments

Comments
 (0)