Skip to content

Commit 15d1cb2

Browse files
authored
Merge pull request #5 from azurechen97/feat/0.3.2-bug-fixes
Feat/0.3.2 bug fixes
2 parents d889985 + fac8f26 commit 15d1cb2

10 files changed

Lines changed: 154 additions & 39 deletions

File tree

CHANGELOG.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
# Changelog
22

3+
## [0.4.0] - 2026-04-01
4+
5+
### Changed (internal refactor)
6+
7+
- `maxcompute.py` renamed to `dialect.py` — the coordinator class `MaxCompute` is now in `src/sqlglot_maxcompute/dialect.py`
8+
- `MaxComputeParser` now inherits from `HiveParser` (imported from `sqlglot.parsers.hive`) instead of `Hive.Parser`
9+
- `MaxComputeGenerator` now inherits from `HiveGenerator` (imported from `sqlglot.generators.hive`) instead of `Hive.Generator`
10+
- `sqlglot` dependency floor raised to `>=30.1.0` (first release with split `parsers/` and `generators/` modules)
11+
12+
### Fixed (parser + generator correctness)
13+
14+
- `TRUNC(n, d)` now maps to `exp.Trunc` for numeric truncation; `TRUNC(dt, 'unit')` still routes to date truncation
15+
- `BOOL_AND(col)` / `BOOL_OR(col)` now emit correct MaxCompute names instead of `LOGICAL_AND` / `LOGICAL_OR`
16+
- `LOCATE(sub, str, start)` now passes the start position through to `INSTR(str, sub, start)` instead of silently dropping it
17+
318
## [0.3.1] - 2026-04-01
419

520
### Fixed (parser correctness)
@@ -24,7 +39,7 @@
2439

2540
### Changed (internal)
2641

27-
- Dialect split: `maxcompute.py` now delegates to `parser.py` and `generator.py` (mirrors sqlglot's own mypyc-compile refactor)
42+
- Dialect split: `dialect.py` now delegates to `parser.py` and `generator.py` (mirrors sqlglot's own mypyc-compile refactor)
2843

2944
### Tests
3045

CLAUDE.md

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,26 @@ uv run pytest tests/test_foo.py::test_bar
2828

2929
The dialect is split across three files in `src/sqlglot_maxcompute/`:
3030

31-
- **`parser.py`**`MaxComputeParser(Hive.Parser)`: `FUNCTIONS` dict mapping MaxCompute function names to canonical `sqlglot.exp` nodes; `PROPERTY_PARSERS` for `LIFECYCLE`, `RANGE`, and `AUTO`; helper builders `_build_dateadd`, `_build_datetrunc`.
32-
- **`generator.py`**`MaxComputeGenerator(Hive.Generator)`: `TYPE_MAPPING`, `TRANSFORMS`, and named `_sql` methods that map canonical AST nodes back to MaxCompute SQL.
33-
- **`maxcompute.py`**`MaxCompute(Hive)`: slim coordinator that sets `TIME_MAPPING`/`DATE_FORMAT`/`TIME_FORMAT`, adds `Tokenizer` keywords (`EXPORT`, `LIFECYCLE`, `OPTION`), and wires `Parser = MaxComputeParser` / `Generator = MaxComputeGenerator`.
31+
- **`parser.py`**`MaxComputeParser(HiveParser)`: `FUNCTIONS` dict mapping MaxCompute function names to canonical `sqlglot.exp` nodes; `PROPERTY_PARSERS` for `LIFECYCLE`, `RANGE`, and `AUTO`; helper builders `_build_dateadd`, `_build_datetrunc`.
32+
- **`generator.py`**`MaxComputeGenerator(HiveGenerator)`: `TYPE_MAPPING`, `TRANSFORMS`, and named `_sql` methods that map canonical AST nodes back to MaxCompute SQL.
33+
- **`dialect.py`**`MaxCompute(Hive)`: slim coordinator that sets `TIME_MAPPING`/`DATE_FORMAT`/`TIME_FORMAT`, adds `Tokenizer` keywords (`EXPORT`, `LIFECYCLE`, `OPTION`), and wires `Parser = MaxComputeParser` / `Generator = MaxComputeGenerator`.
3434

3535
The dialect is registered as a plugin in `pyproject.toml` under `[project.entry-points."sqlglot.dialects"]`, so after installation it is automatically discoverable by sqlglot as `"maxcompute"`.
3636

37-
This split mirrors sqlglot's own mypyc-compile refactor (parsers/generators split by file) and is required for compatibility with sqlglot ≥ 31 compiled wheels.
37+
This split mirrors sqlglot's own mypyc-compile refactor (parsers/generators split into `sqlglot.parsers.*` / `sqlglot.generators.*` modules) and requires sqlglot ≥ 30.1.0.
3838

3939
`local/` contains development scratch files and references — **not part of the package**:
4040
- `scratch.py` — keyword comparison scratch script
41-
- `sqlglot/` — full clone of the sqlglot repo for reference (expressions, dialects, generator internals); `sqlglot/posts/` contains official guides (`onboarding.md` for architecture deep-dive, `ast_primer.md` for AST tutorial). Note: local clone is newer than installed (30.0.1) — dialect parsers moved to `parsers/`, expressions split into `expressions/` package
41+
- `sqlglot/` — full clone of the sqlglot repo for reference (expressions, dialects, generator internals); `sqlglot/posts/` contains official guides (`onboarding.md` for architecture deep-dive, `ast_primer.md` for AST tutorial). Parsers live in `parsers/`, generators in `generators/`, expressions in `expressions/` package
4242
- `ydb-sqlglot-plugin/` — YDB dialect plugin, used as reference for how a well-behaved plugin is structured
4343
- `maxcompute_doc/` — MaxCompute official function documentation (e.g., `date_func.md`, `func_comparison.md`)
4444

4545
## Implementation Status
4646

47-
The dialect is complete at v0.3.1:
47+
The dialect is complete at v0.4.0:
4848
- **Parser**: ~65 functions explicitly mapped (date/time, string, aggregate, array, map); remainder inherited from Hive.
4949
- **Generator**: `TRANSFORMS` + named `_sql` methods for all major expression types; Hive handles the rest.
50-
- **Tests**: 39 test methods, 180+ subtests covering parse, round-trip, and cross-dialect transpilation.
50+
- **Tests**: 40 test methods, 186 subtests covering parse, round-trip, and cross-dialect transpilation.
5151

5252
## Key sqlglot patterns
5353

@@ -60,6 +60,14 @@ When adding generator transforms in `Generator.TRANSFORMS`, use `self.func(name,
6060
Tests use a `Validator` base class (inline in `tests/test_maxcompute.py`) mirroring sqlglot's pattern:
6161
- `validate_all(sql, write={dialect: expected})` — cross-dialect transpilation assertions
6262
- `assertIsInstance(parse_one(sql, read="maxcompute"), exp.SomeClass)` — parse node assertions
63+
- **`read=` must be a dict**`read={"spark": "LOCATE(...)"}`, not `read="spark"`. Bare string is silently ignored by `validate_all`.
64+
- **Pyright false positive**`assertIsNotNone(x)` does not narrow types in Pyright; `x.field` after it shows "attribute of None" errors that are noise, not real bugs.
65+
66+
**Development is test-driven (TDD).** For every fix or feature:
67+
1. Write the failing test first and run it to confirm it fails
68+
2. Implement the minimal change to make it pass
69+
3. Run the full suite to confirm no regressions
70+
4. Commit
6371

6472
Before writing `validate_all` assertions, probe actual output first:
6573
```bash

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Registers the `maxcompute` dialect via Python entry points so that SQLGlot can p
1010
pip install sqlglot-maxcompute
1111
```
1212

13-
Requires Python ≥ 3.9 and SQLGlot ≥ 29.
13+
Requires Python ≥ 3.9 and SQLGlot ≥ 30.1.
1414

1515
## Usage
1616

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sqlglot-maxcompute"
3-
version = "0.3.1"
3+
version = "0.4.0"
44
description = "MaxCompute dialect plugin for SQLGlot"
55
readme = "README.md"
66
license = { text = "MIT" }
@@ -9,7 +9,7 @@ authors = [
99
]
1010
requires-python = ">=3.9"
1111
dependencies = [
12-
"sqlglot>=29.0.0,<31",
12+
"sqlglot>=30.1.0,<31",
1313
]
1414
classifiers = [
1515
"Development Status :: 3 - Alpha",
@@ -38,4 +38,4 @@ dev = [
3838
testpaths = ["tests"]
3939

4040
[project.entry-points."sqlglot.dialects"]
41-
maxcompute = "sqlglot_maxcompute.maxcompute:MaxCompute"
41+
maxcompute = "sqlglot_maxcompute.dialect:MaxCompute"

src/sqlglot_maxcompute/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlglot_maxcompute.maxcompute import MaxCompute
1+
from sqlglot_maxcompute.dialect import MaxCompute
22
from sqlglot_maxcompute.parser import MaxComputeParser
33
from sqlglot_maxcompute.generator import MaxComputeGenerator
44

src/sqlglot_maxcompute/generator.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import typing as t
44

55
from sqlglot import exp
6-
from sqlglot.dialects.hive import Hive
6+
from sqlglot.generators.hive import HiveGenerator
77
from sqlglot.dialects.dialect import rename_func, unit_to_str
88
from sqlglot.transforms import (
99
move_schema_columns_to_partitioned_by,
@@ -13,7 +13,12 @@
1313
)
1414

1515

16-
_AUTO_PARTITION_TYPES = (exp.DateTrunc, exp.TimestampTrunc, exp.DatetimeTrunc, exp.Alias)
16+
_AUTO_PARTITION_TYPES = (
17+
exp.DateTrunc,
18+
exp.TimestampTrunc,
19+
exp.DatetimeTrunc,
20+
exp.Alias,
21+
)
1722

1823

1924
def _move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
@@ -25,17 +30,17 @@ def _move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
2530
return move_schema_columns_to_partitioned_by(expression)
2631

2732

28-
class MaxComputeGenerator(Hive.Generator):
33+
class MaxComputeGenerator(HiveGenerator):
2934
TYPE_MAPPING = {
30-
**Hive.Generator.TYPE_MAPPING,
35+
**HiveGenerator.TYPE_MAPPING,
3136
exp.DType.DATETIME: "DATETIME",
3237
exp.DType.VARCHAR: "STRING",
3338
exp.DType.CHAR: "STRING",
3439
exp.DType.TEXT: "STRING",
3540
}
3641

3742
TRANSFORMS = {
38-
**Hive.Generator.TRANSFORMS,
43+
**HiveGenerator.TRANSFORMS,
3944
exp.Create: preprocess(
4045
[
4146
remove_unique_constraints,
@@ -67,19 +72,33 @@ class MaxComputeGenerator(Hive.Generator):
6772
exp.ApproxDistinct: rename_func("APPROX_DISTINCT"),
6873
exp.ArgMax: lambda self, e: self.func("ARG_MAX", e.this, e.expression),
6974
exp.ArgMin: lambda self, e: self.func("ARG_MIN", e.this, e.expression),
75+
exp.LogicalAnd: rename_func("BOOL_AND"),
76+
exp.LogicalOr: rename_func("BOOL_OR"),
7077
# Statistical aggregate fixes (Hive emits wrong names)
7178
exp.Space: rename_func("SPACE"),
7279
exp.VariancePop: rename_func("VAR_POP"),
7380
exp.Variance: rename_func("VAR_SAMP"),
81+
# Numeric truncation: TRUNC(n, d)
82+
exp.Trunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("decimals")),
7483
# String position: MaxCompute uses INSTR(str, substr), not LOCATE(substr, str)
75-
exp.StrPosition: lambda self, e: self.func("INSTR", e.this, e.args.get("substr")),
84+
exp.StrPosition: lambda self, e: self.func(
85+
"INSTR", e.this, e.args.get("substr"), e.args.get("position")
86+
),
7687
# TO_DATE(str, fmt) returns DATETIME — modeled as StrToTime; emit TO_DATE in MaxCompute
77-
exp.StrToTime: lambda self, e: self.func("TO_DATE", e.this, e.args.get("format")),
88+
exp.StrToTime: lambda self, e: self.func(
89+
"TO_DATE", e.this, e.args.get("format")
90+
),
7891
}
7992

8093
def _dateadd_sql(
8194
self,
82-
expression: exp.TsOrDsAdd | exp.DateAdd | exp.DateSub | exp.TimestampAdd | exp.DatetimeAdd,
95+
expression: (
96+
exp.TsOrDsAdd
97+
| exp.DateAdd
98+
| exp.DateSub
99+
| exp.TimestampAdd
100+
| exp.DatetimeAdd
101+
),
83102
) -> str:
84103
unit = unit_to_str(expression) if expression.args.get("unit") else "'DAY'"
85104
delta = expression.expression
@@ -118,18 +137,26 @@ def tochar_sql(self, expression: exp.ToChar) -> str:
118137
return self.func("TO_CHAR", expression.this, expression.args.get("format"))
119138

120139
def substring_sql(self, expression: exp.Substring) -> str:
121-
return self.func("SUBSTR", expression.this, expression.args.get("start"), expression.args.get("length"))
140+
return self.func(
141+
"SUBSTR",
142+
expression.this,
143+
expression.args.get("start"),
144+
expression.args.get("length"),
145+
)
122146

123147
def extract_sql(self, expression: exp.Extract) -> str:
124148
unit = expression.this
125-
return self.func("DATEPART", expression.expression, exp.Literal.string(unit.name))
149+
return self.func(
150+
"DATEPART", expression.expression, exp.Literal.string(unit.name)
151+
)
126152

127153
def mod_sql(self, expression: exp.Mod) -> str:
128154
# Reverse the WEEKDAY parser transform: (DAYOFWEEK(x) + 5) % 7 → WEEKDAY(x)
129155
rhs = expression.expression
130156
lhs = expression.this
131157
if (
132-
isinstance(rhs, exp.Literal) and rhs.this == "7"
158+
isinstance(rhs, exp.Literal)
159+
and rhs.this == "7"
133160
and isinstance(lhs, exp.Paren)
134161
and isinstance(lhs.this, exp.Add)
135162
and isinstance(lhs.this.this, exp.DayOfWeek)
@@ -148,7 +175,9 @@ def _partitioned_by_sql(self, expression: exp.PartitionedByProperty) -> str:
148175
inner = inner.this
149176
unit = inner.args.get("unit")
150177
unit_str = unit.name.lower() if unit else ""
151-
trunc_sql = self.func("TRUNC_TIME", inner.this, exp.Literal.string(unit_str))
178+
trunc_sql = self.func(
179+
"TRUNC_TIME", inner.this, exp.Literal.string(unit_str)
180+
)
152181
return f"AUTO PARTITIONED BY ({trunc_sql}{alias_sql})"
153182
return f"PARTITIONED BY {self.sql(expression, 'this')}"
154183

@@ -159,7 +188,9 @@ def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str:
159188
def datatype_sql(self, expression: exp.DataType) -> str:
160189
# VARCHAR and CHAR map to STRING in MaxCompute, with no length parameters
161190
if expression.this in (exp.DType.VARCHAR, exp.DType.CHAR):
162-
return self.TYPE_MAPPING.get(expression.this, super().datatype_sql(expression))
191+
return self.TYPE_MAPPING.get(
192+
expression.this, super().datatype_sql(expression)
193+
)
163194
return super().datatype_sql(expression)
164195

165196
def properties_sql(self, expression: exp.Properties) -> str:

src/sqlglot_maxcompute/parser.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import typing as t
55

66
from sqlglot import exp
7-
from sqlglot.dialects.hive import Hive
7+
from sqlglot.parsers.hive import HiveParser
88
from sqlglot.dialects.dialect import build_timetostr_or_tochar
99
from sqlglot.helper import seq_get
1010
from sqlglot.tokens import TokenType
@@ -61,9 +61,9 @@ def _build_datetrunc(
6161
return exp.DateTrunc(unit=unit, this=this)
6262

6363

64-
class MaxComputeParser(Hive.Parser):
64+
class MaxComputeParser(HiveParser):
6565
FUNCTIONS = {
66-
**Hive.Parser.FUNCTIONS,
66+
**HiveParser.FUNCTIONS,
6767
# Hive overrides: MaxCompute accepts date/datetime/timestamp/string directly
6868
# without needing TsOrDsToDate wrapping
6969
"DAY": exp.Day.from_arg_list,
@@ -86,7 +86,9 @@ class MaxComputeParser(Hive.Parser):
8686
# Hive override: produce exp.DateSub so _dateadd_sql emits DATEADD(date, -n, unit)
8787
# cleanly. Hive maps DATE_SUB to TsOrDsAdd(expression=Mul(n, -1)) which generates
8888
# "3 * -1" in the output.
89-
"DATE_SUB": lambda args: exp.DateSub(this=seq_get(args, 0), expression=seq_get(args, 1)),
89+
"DATE_SUB": lambda args: exp.DateSub(
90+
this=seq_get(args, 0), expression=seq_get(args, 1)
91+
),
9092
# Date arithmetic
9193
"DATEADD": _build_dateadd,
9294
"DATEDIFF": lambda args: exp.DateDiff(
@@ -104,14 +106,24 @@ class MaxComputeParser(Hive.Parser):
104106
),
105107
"DATETRUNC": _build_datetrunc,
106108
"TRUNC_TIME": _build_datetrunc,
109+
# TRUNC(n, d) → exp.Trunc (numeric truncation)
110+
# TRUNC(dt, 'u') → _build_datetrunc (date truncation, same as TRUNC_TIME)
111+
"TRUNC": lambda args: (
112+
_build_datetrunc(args)
113+
if seq_get(args, 1) is not None and seq_get(args, 1).is_string
114+
else exp.Trunc(this=seq_get(args, 0), decimals=seq_get(args, 1))
115+
),
107116
"DAYOFMONTH": exp.DayOfMonth.from_arg_list,
108117
"DAYOFWEEK": exp.DayOfWeek.from_arg_list,
109118
"DAYOFYEAR": exp.DayOfYear.from_arg_list,
110119
"HOUR": exp.Hour.from_arg_list,
111120
"MINUTE": exp.Minute.from_arg_list,
112121
"SECOND": exp.Second.from_arg_list,
113122
"QUARTER": exp.Quarter.from_arg_list,
114-
"WEEKDAY": lambda args: exp.paren(exp.DayOfWeek(this=seq_get(args, 0)) + 5, copy=False) % 7,
123+
"WEEKDAY": lambda args: exp.paren(
124+
exp.DayOfWeek(this=seq_get(args, 0)) + 5, copy=False
125+
)
126+
% 7,
115127
"WEEKOFYEAR": exp.WeekOfYear.from_arg_list,
116128
# Last/next day
117129
"LAST_DAY": exp.LastDay.from_arg_list,
@@ -134,7 +146,9 @@ class MaxComputeParser(Hive.Parser):
134146
),
135147
"ISDATE": lambda args: exp.not_(
136148
exp.Is(
137-
this=exp.TsOrDsToDate(this=seq_get(args, 0), format=seq_get(args, 1), safe=True),
149+
this=exp.TsOrDsToDate(
150+
this=seq_get(args, 0), format=seq_get(args, 1), safe=True
151+
),
138152
expression=exp.Null(),
139153
)
140154
),
@@ -191,7 +205,7 @@ class MaxComputeParser(Hive.Parser):
191205
}
192206

193207
PROPERTY_PARSERS = {
194-
**Hive.Parser.PROPERTY_PARSERS,
208+
**HiveParser.PROPERTY_PARSERS,
195209
# LIFECYCLE n — MaxCompute table retention in days. Stored as a generic
196210
# exp.Property with a Var key so no custom expression class is needed and
197211
# sqlglot's PROPERTIES_LOCATION contract is not broken.
@@ -202,7 +216,9 @@ class MaxComputeParser(Hive.Parser):
202216
"AUTO": lambda self: self._parse_auto_partition(),
203217
}
204218

205-
def _parse_auto_partition(self) -> exp.PartitionedByProperty | exp.AutoRefreshProperty | None:
219+
def _parse_auto_partition(
220+
self,
221+
) -> exp.PartitionedByProperty | exp.AutoRefreshProperty | None:
206222
if self._match(TokenType.PARTITION_BY):
207223
self._match(TokenType.L_PAREN)
208224
expr = self._parse_conjunction()

tests/test_maxcompute.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,51 @@ def test_031_fixes(self):
229229
},
230230
)
231231

232+
def test_032_fixes(self):
233+
# Bug 1: TRUNC(n, d) — numeric truncation, not date truncation
234+
expr = self.parse_one("TRUNC(3.14, 2)")
235+
self.assertIsInstance(expr, exp.Trunc)
236+
self.validate_all(
237+
"TRUNC(3.14, 2)",
238+
write={
239+
"maxcompute": "TRUNC(3.14, 2)",
240+
},
241+
)
242+
# TRUNC with a string unit still routes to DATETRUNC
243+
expr2 = self.parse_one("TRUNC(dt, 'MONTH')")
244+
self.assertIsInstance(expr2, exp.DateTrunc)
245+
246+
# Bug 2: BOOL_AND / BOOL_OR — aggregate, not infix AND/OR
247+
self.validate_all(
248+
"SELECT BOOL_AND(flag) FROM t",
249+
write={
250+
"maxcompute": "SELECT BOOL_AND(flag) FROM t",
251+
},
252+
)
253+
self.validate_all(
254+
"SELECT BOOL_OR(flag) FROM t",
255+
write={
256+
"maxcompute": "SELECT BOOL_OR(flag) FROM t",
257+
},
258+
)
259+
260+
# Bug 3: LOCATE(sub, str, start) — start position must pass through to INSTR
261+
self.validate_all(
262+
"LOCATE('bc', 'abcd', 2)",
263+
read={"spark": "LOCATE('bc', 'abcd', 2)"},
264+
write={
265+
"maxcompute": "INSTR('abcd', 'bc', 2)",
266+
},
267+
)
268+
# Without start position, INSTR(str, sub) is unchanged
269+
self.validate_all(
270+
"LOCATE('bc', 'abcd')",
271+
read={"spark": "LOCATE('bc', 'abcd')"},
272+
write={
273+
"maxcompute": "INSTR('abcd', 'bc')",
274+
},
275+
)
276+
232277
# -------------------------------------------------------------------------
233278
# Date/time conversion
234279
# -------------------------------------------------------------------------

0 commit comments

Comments
 (0)