Skip to content

Commit 5ce7ff8

Browse files
authored
Merge pull request #4 from azurechen97/feat/split-and-0.3.0
Feat/split and 0.3.0
2 parents ded0583 + bb6ac67 commit 5ce7ff8

8 files changed

Lines changed: 582 additions & 377 deletions

File tree

CHANGELOG.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,32 @@
11
# Changelog
22

3+
## [0.3.0] - 2026-03-31
4+
5+
### Fixed (generator correctness)
6+
7+
- `SPACE(n)` now emits `SPACE(n)` instead of `REPEAT(' ', n)`
8+
- `VAR_POP(x)` now emits `VAR_POP(x)` instead of `VARIANCE_POP(x)`
9+
- `VAR_SAMP(x)` / `VARIANCE(x)` now emits `VAR_SAMP(x)` instead of `VARIANCE(x)`
10+
- `INSTR(s, sub)` now emits `INSTR(s, sub)` instead of `LOCATE(sub, s)`
11+
- `SUBSTR(s, pos, len)` now emits `SUBSTR` instead of `SUBSTRING`
12+
13+
### Added (parser)
14+
15+
- `SUBSTR` as explicit MaxCompute parser alias for `exp.Substring`
16+
17+
### Changed (internal)
18+
19+
- Dialect split: `maxcompute.py` now delegates to `parser.py` and `generator.py` (mirrors sqlglot's own mypyc-compile refactor)
20+
21+
### Tests
22+
23+
- Regression coverage for ~20 functions previously relying on untested Hive inheritance:
24+
INITCAP, REVERSE, REPEAT, LPAD/RPAD, LTRIM/RTRIM, REGEXP_REPLACE,
25+
REGEXP_EXTRACT_ALL, INSTR, FIND_IN_SET, SUBSTR, SUBSTRING_INDEX,
26+
CONCAT_WS, FORMAT_NUMBER, COLLECT_LIST/SET, VAR_SAMP, VAR_POP,
27+
PERCENTILE, STDDEV, GREATEST/LEAST, CBRT, FACTORIAL, GET_JSON_OBJECT,
28+
JSON_TUPLE
29+
330
## [0.2.0] - 2026-03-31
431

532
### Added

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sqlglot-maxcompute"
3-
version = "0.2.0"
3+
version = "0.3.0"
44
description = "MaxCompute dialect plugin for SQLGlot"
55
readme = "README.md"
66
license = { text = "MIT" }

src/sqlglot_maxcompute/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
1-
def hello() -> str:
2-
return "Hello from sqlglot-maxcompute!"
1+
from sqlglot_maxcompute.maxcompute import MaxCompute
2+
from sqlglot_maxcompute.parser import MaxComputeParser
3+
from sqlglot_maxcompute.generator import MaxComputeGenerator
4+
5+
__all__ = ["MaxCompute", "MaxComputeParser", "MaxComputeGenerator"]
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
5+
from sqlglot import exp
6+
from sqlglot.dialects.hive import Hive
7+
from sqlglot.dialects.dialect import rename_func, unit_to_str
8+
from sqlglot.transforms import (
9+
move_schema_columns_to_partitioned_by,
10+
preprocess,
11+
remove_unique_constraints,
12+
ctas_with_tmp_tables_to_create_tmp_view,
13+
)
14+
15+
16+
# Node types that signal an AUTO PARTITIONED BY clause: a truncation of a
# date/time column, optionally wrapped in an alias.
_AUTO_PARTITION_TYPES = (exp.DateTrunc, exp.TimestampTrunc, exp.DatetimeTrunc, exp.Alias)


def _move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
    """Hive's schema-to-PARTITIONED-BY transform, skipping AUTO partitioning.

    When the PartitionedByProperty payload is a DateTrunc/Alias node the CREATE
    uses AUTO PARTITIONED BY, whose expression must stay where it is; the
    statement is returned untouched. Every other CREATE is delegated to the
    stock Hive transform.
    """
    assert isinstance(expression, exp.Create)
    partition_prop = expression.find(exp.PartitionedByProperty)
    is_auto = partition_prop and isinstance(partition_prop.this, _AUTO_PARTITION_TYPES)
    if is_auto:
        return expression
    return move_schema_columns_to_partitioned_by(expression)
26+
27+
28+
class MaxComputeGenerator(Hive.Generator):
    """SQL generator for the MaxCompute dialect, built on top of Hive's.

    Overrides Hive's type mapping, expression transforms, and a handful of
    *_sql methods wherever MaxCompute's function names or DDL syntax diverge
    from Hive (DATEADD/DATEDIFF/DATETRUNC, INSTR, WM_CONCAT, AUTO
    PARTITIONED BY, bare LIFECYCLE-style properties, ...).
    """

    # MaxCompute has a native DATETIME type, and collapses VARCHAR/CHAR/TEXT
    # into STRING (no length parameters).
    TYPE_MAPPING = {
        **Hive.Generator.TYPE_MAPPING,
        exp.DType.DATETIME: "DATETIME",
        exp.DType.VARCHAR: "STRING",
        exp.DType.CHAR: "STRING",
        exp.DType.TEXT: "STRING",
    }

    TRANSFORMS = {
        **Hive.Generator.TRANSFORMS,
        # CREATE: same pipeline as Hive's, but with the AUTO-PARTITION-aware
        # schema-column transform swapped in.
        exp.Create: preprocess(
            [
                remove_unique_constraints,
                ctas_with_tmp_tables_to_create_tmp_view,
                _move_schema_columns_to_partitioned_by,
            ]
        ),
        exp.PartitionedByProperty: lambda self, e: self._partitioned_by_sql(e),
        # Date/time transforms — all additive variants funnel into DATEADD.
        exp.TsOrDsAdd: lambda self, e: self._dateadd_sql(e),
        exp.DateAdd: lambda self, e: self._dateadd_sql(e),
        exp.TimestampAdd: lambda self, e: self._dateadd_sql(e),
        exp.DatetimeAdd: lambda self, e: self._dateadd_sql(e),
        exp.DateSub: lambda self, e: self._dateadd_sql(e),
        exp.DateDiff: lambda self, e: self._datediff_sql(e),
        exp.DateTrunc: lambda self, e: self._datetrunc_sql(e),
        exp.TimestampTrunc: lambda self, e: self._datetrunc_sql(e),
        exp.DatetimeTrunc: lambda self, e: self._datetrunc_sql(e),
        exp.CurrentTimestamp: lambda self, e: "GETDATE()",
        exp.CurrentDatetime: lambda self, e: "NOW()",
        # String transforms
        exp.Lower: rename_func("TOLOWER"),
        exp.Upper: rename_func("TOUPPER"),
        # JSON / misc
        exp.ParseJSON: rename_func("FROM_JSON"),
        exp.CurrentUser: lambda self, e: "GET_USER_ID()",
        exp.UnixMillis: rename_func("TO_MILLIS"),
        # Aggregate
        exp.ApproxDistinct: rename_func("APPROX_DISTINCT"),
        exp.ArgMax: lambda self, e: self.func("ARG_MAX", e.this, e.expression),
        exp.ArgMin: lambda self, e: self.func("ARG_MIN", e.this, e.expression),
        # Statistical aggregate fixes (Hive emits wrong names)
        exp.Space: rename_func("SPACE"),
        exp.VariancePop: rename_func("VAR_POP"),
        exp.Variance: rename_func("VAR_SAMP"),
        # String position: MaxCompute uses INSTR(str, substr), not LOCATE(substr, str)
        exp.StrPosition: lambda self, e: self.func("INSTR", e.this, e.args.get("substr")),
        # TO_DATE(str, fmt) returns DATETIME — modeled as StrToTime; emit TO_DATE in MaxCompute
        exp.StrToTime: lambda self, e: self.func("TO_DATE", e.this, e.args.get("format")),
    }

    def _dateadd_sql(
        self,
        expression: exp.TsOrDsAdd | exp.DateAdd | exp.DateSub | exp.TimestampAdd | exp.DatetimeAdd,
    ) -> str:
        """Render any date-arithmetic node as DATEADD(date, delta, unit).

        A missing unit defaults to 'DAY'; DateSub is handled by negating the
        delta rather than emitting a separate DATESUB call.
        """
        unit = unit_to_str(expression) if expression.args.get("unit") else "'DAY'"
        delta = expression.expression
        if isinstance(expression, exp.DateSub):
            # DateSub magnitude is positive; negate it so DATEADD subtracts.
            # Some dialects (e.g. BigQuery) store the magnitude as a string
            # literal — normalize to a number first so we emit -3 not -'3'.
            if isinstance(delta, exp.Literal) and delta.is_string:
                delta = exp.Literal.number(delta.this)
            delta = exp.Neg(this=delta)
        return self.func("DATEADD", expression.this, delta, unit)

    def _datediff_sql(self, expression: exp.DateDiff) -> str:
        """Render DateDiff as DATEDIFF(a, b[, unit]); unit is omitted when absent."""
        unit = unit_to_str(expression) if expression.args.get("unit") else None
        return self.func("DATEDIFF", expression.this, expression.expression, unit)

    def _datetrunc_sql(
        self, expression: exp.DateTrunc | exp.TimestampTrunc | exp.DatetimeTrunc
    ) -> str:
        """Render truncation nodes as DATETRUNC(value, 'unit')."""
        unit = expression.args.get("unit")
        # WeekStart units must be emitted as 'week(day)' string literals.
        # unit_to_str returns the raw node name which would produce DATETRUNC(dt, WEEK(MONDAY))
        # — invalid MaxCompute SQL. Reconstruct the canonical 'week(day)' form instead.
        if isinstance(unit, exp.WeekStart):
            day = unit.this.name.lower() if unit.args.get("this") else "monday"
            unit_sql = exp.Literal.string(f"week({day})")
        else:
            unit_sql = unit_to_str(expression)
        return self.func("DATETRUNC", expression.this, unit_sql)

    def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
        """GROUP_CONCAT → WM_CONCAT(sep, expr); separator defaults to ','."""
        sep = expression.args.get("separator") or exp.Literal.string(",")
        return self.func("WM_CONCAT", sep, expression.this)

    def tochar_sql(self, expression: exp.ToChar) -> str:
        """Render ToChar as TO_CHAR(value[, format])."""
        return self.func("TO_CHAR", expression.this, expression.args.get("format"))

    def substring_sql(self, expression: exp.Substring) -> str:
        """Emit SUBSTR (MaxCompute's name) rather than Hive's SUBSTRING."""
        return self.func("SUBSTR", expression.this, expression.args.get("start"), expression.args.get("length"))

    def extract_sql(self, expression: exp.Extract) -> str:
        """EXTRACT(unit FROM x) → DATEPART(x, 'unit')."""
        unit = expression.this
        return self.func("DATEPART", expression.expression, exp.Literal.string(unit.name))

    def mod_sql(self, expression: exp.Mod) -> str:
        # Reverse the WEEKDAY parser transform: (DAYOFWEEK(x) + 5) % 7 → WEEKDAY(x)
        # Only the exact shape produced by the parser is matched (literal 5
        # and literal 7); anything else falls through to Hive's mod rendering.
        rhs = expression.expression
        lhs = expression.this
        if (
            isinstance(rhs, exp.Literal) and rhs.this == "7"
            and isinstance(lhs, exp.Paren)
            and isinstance(lhs.this, exp.Add)
            and isinstance(lhs.this.this, exp.DayOfWeek)
            and isinstance(lhs.this.expression, exp.Literal)
            and lhs.this.expression.this == "5"
        ):
            return self.func("WEEKDAY", lhs.this.this.this)
        return super().mod_sql(expression)

    def _partitioned_by_sql(self, expression: exp.PartitionedByProperty) -> str:
        """Render PARTITIONED BY, using AUTO PARTITIONED BY for trunc-style payloads.

        A DateTrunc/Alias payload becomes
        AUTO PARTITIONED BY (TRUNC_TIME(col, 'unit')[ AS alias]); any other
        payload is rendered as a plain PARTITIONED BY clause.
        """
        inner = expression.this
        if isinstance(inner, _AUTO_PARTITION_TYPES):
            alias_sql = ""
            if isinstance(inner, exp.Alias):
                alias_sql = f" AS {inner.alias}"
                inner = inner.this  # unwrap to the truncation node beneath the alias
            unit = inner.args.get("unit")
            unit_str = unit.name.lower() if unit else ""
            trunc_sql = self.func("TRUNC_TIME", inner.this, exp.Literal.string(unit_str))
            return f"AUTO PARTITIONED BY ({trunc_sql}{alias_sql})"
        return f"PARTITIONED BY {self.sql(expression, 'this')}"

    def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str:
        """Prefix Hive's CLUSTERED BY rendering with RANGE when the node is range-clustered."""
        sql = super().clusteredbyproperty_sql(expression)
        return f"RANGE {sql}" if expression.args.get("range") else sql

    def datatype_sql(self, expression: exp.DataType) -> str:
        # VARCHAR and CHAR map to STRING in MaxCompute, with no length parameters
        # — bypass the default rendering so a (n) size suffix is never emitted.
        if expression.this in (exp.DType.VARCHAR, exp.DType.CHAR):
            return self.TYPE_MAPPING.get(expression.this, super().datatype_sql(expression))
        return super().datatype_sql(expression)

    def properties_sql(self, expression: exp.Properties) -> str:
        # Var-keyed exp.Property instances (e.g. LIFECYCLE 30) render as bare
        # KEY value after the schema. String-keyed ones stay in TBLPROPERTIES.
        var_keyed = [
            p
            for p in expression.expressions
            if isinstance(p, exp.Property) and isinstance(p.this, exp.Var)
        ]
        other = [p for p in expression.expressions if p not in var_keyed]

        # Re-parent the remainder so the base renderer sees the same context
        # (placement of the properties block depends on the parent node).
        other_node = exp.Properties(expressions=other)
        other_node.parent = expression.parent
        base_sql = super().properties_sql(other_node) if other else ""

        bare_sql = " ".join(f"{p.name} {self.sql(p, 'value')}" for p in var_keyed)

        if base_sql and bare_sql:
            return f"{base_sql} {bare_sql}"
        return base_sql or bare_sql

0 commit comments

Comments
 (0)