Merged
17 changes: 16 additions & 1 deletion CHANGELOG.md
@@ -1,5 +1,20 @@
# Changelog

## [0.4.0] - 2026-04-01

### Changed (internal refactor)

- `maxcompute.py` renamed to `dialect.py` — the coordinator class `MaxCompute` is now in `src/sqlglot_maxcompute/dialect.py`
- `MaxComputeParser` now inherits from `HiveParser` (imported from `sqlglot.parsers.hive`) instead of `Hive.Parser`
- `MaxComputeGenerator` now inherits from `HiveGenerator` (imported from `sqlglot.generators.hive`) instead of `Hive.Generator`
- `sqlglot` dependency floor raised to `>=30.1.0` (first release with split `parsers/` and `generators/` modules)

### Fixed (parser + generator correctness)

- `TRUNC(n, d)` now maps to `exp.Trunc` for numeric truncation; `TRUNC(dt, 'unit')` still routes to date truncation
- `BOOL_AND(col)` / `BOOL_OR(col)` now emit correct MaxCompute names instead of `LOGICAL_AND` / `LOGICAL_OR`
- `LOCATE(sub, str, start)` now passes the start position through to `INSTR(str, sub, start)` instead of silently dropping it
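
The argument-handling rules behind the `TRUNC` and `LOCATE` fixes can be modeled without sqlglot. The following is a minimal standalone sketch, not the plugin's code: `classify_trunc` and `locate_to_instr` are hypothetical helper names, and SQL arguments are represented as plain Python values and strings purely for illustration.

```python
from typing import Optional, Sequence, Union

Arg = Union[int, float, str]


def classify_trunc(args: Sequence[Arg]) -> str:
    """Classify a TRUNC(...) call as 'date' or 'numeric' (hypothetical helper).

    Assumption for this sketch: a string second argument stands for a date
    unit such as 'MONTH'; anything else (or no second argument) is numeric.
    """
    second = args[1] if len(args) > 1 else None
    return "date" if isinstance(second, str) else "numeric"


def locate_to_instr(sub: str, s: str, start: Optional[int] = None) -> str:
    """Render LOCATE(sub, str[, start]) as INSTR(str, sub[, start]).

    The first two arguments swap places; the optional start position is
    appended unchanged instead of being dropped.
    """
    parts = [s, sub]
    if start is not None:
        parts.append(str(start))
    return f"INSTR({', '.join(parts)})"


print(classify_trunc([3.14, 2]))              # numeric
print(classify_trunc(["dt", "MONTH"]))        # date
print(locate_to_instr("'bc'", "'abcd'", 2))   # INSTR('abcd', 'bc', 2)
print(locate_to_instr("'bc'", "'abcd'"))      # INSTR('abcd', 'bc')
```

The rendered strings match the expected outputs asserted in `test_032_fixes`.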

## [0.3.1] - 2026-04-01

### Fixed (parser correctness)
@@ -24,7 +39,7 @@

### Changed (internal)

- Dialect split: `maxcompute.py` now delegates to `parser.py` and `generator.py` (mirrors sqlglot's own mypyc-compile refactor)
- Dialect split: `dialect.py` now delegates to `parser.py` and `generator.py` (mirrors sqlglot's own mypyc-compile refactor)

### Tests

22 changes: 15 additions & 7 deletions CLAUDE.md
@@ -28,26 +28,26 @@ uv run pytest tests/test_foo.py::test_bar

The dialect is split across three files in `src/sqlglot_maxcompute/`:

- **`parser.py`** — `MaxComputeParser(Hive.Parser)`: `FUNCTIONS` dict mapping MaxCompute function names to canonical `sqlglot.exp` nodes; `PROPERTY_PARSERS` for `LIFECYCLE`, `RANGE`, and `AUTO`; helper builders `_build_dateadd`, `_build_datetrunc`.
- **`generator.py`** — `MaxComputeGenerator(Hive.Generator)`: `TYPE_MAPPING`, `TRANSFORMS`, and named `_sql` methods that map canonical AST nodes back to MaxCompute SQL.
- **`maxcompute.py`** — `MaxCompute(Hive)`: slim coordinator that sets `TIME_MAPPING`/`DATE_FORMAT`/`TIME_FORMAT`, adds `Tokenizer` keywords (`EXPORT`, `LIFECYCLE`, `OPTION`), and wires `Parser = MaxComputeParser` / `Generator = MaxComputeGenerator`.
- **`parser.py`** — `MaxComputeParser(HiveParser)`: `FUNCTIONS` dict mapping MaxCompute function names to canonical `sqlglot.exp` nodes; `PROPERTY_PARSERS` for `LIFECYCLE`, `RANGE`, and `AUTO`; helper builders `_build_dateadd`, `_build_datetrunc`.
- **`generator.py`** — `MaxComputeGenerator(HiveGenerator)`: `TYPE_MAPPING`, `TRANSFORMS`, and named `_sql` methods that map canonical AST nodes back to MaxCompute SQL.
- **`dialect.py`** — `MaxCompute(Hive)`: slim coordinator that sets `TIME_MAPPING`/`DATE_FORMAT`/`TIME_FORMAT`, adds `Tokenizer` keywords (`EXPORT`, `LIFECYCLE`, `OPTION`), and wires `Parser = MaxComputeParser` / `Generator = MaxComputeGenerator`.

The dialect is registered as a plugin in `pyproject.toml` under `[project.entry-points."sqlglot.dialects"]`, so after installation it is automatically discoverable by sqlglot as `"maxcompute"`.
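
As a rough illustration of what that entry-point declaration encodes, the standard library can split the target string into a module path and an attribute name. This sketch only constructs an `EntryPoint` by hand from the value in `pyproject.toml`; it does not show how sqlglot itself performs discovery, and it deliberately avoids calling `load()` so it runs even when the package is not installed.

```python
from importlib.metadata import EntryPoint

# Values copied from the [project.entry-points."sqlglot.dialects"] table.
ep = EntryPoint(
    name="maxcompute",
    value="sqlglot_maxcompute.dialect:MaxCompute",
    group="sqlglot.dialects",
)

# The part before ':' is the module to import, the part after is the object.
print(ep.module)  # sqlglot_maxcompute.dialect
print(ep.attr)    # MaxCompute
```

If the module is renamed again, the `value` string in `pyproject.toml` has to change with it, otherwise the entry point resolves to a module that no longer exists.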

This split mirrors sqlglot's own mypyc-compile refactor (parsers/generators split by file) and is required for compatibility with sqlglot ≥ 31 compiled wheels.
This split mirrors sqlglot's own mypyc-compile refactor (parsers/generators split into `sqlglot.parsers.*` / `sqlglot.generators.*` modules) and requires sqlglot ≥ 30.1.0.
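
For a quick sanity check of the supported range (`>=30.1.0,<31` in `pyproject.toml`), a small stdlib-only sketch such as the one below may help. `parse_release` and `in_supported_range` are hypothetical names, and the parser assumes purely numeric release strings (no pre-release or local version tags).

```python
from typing import Tuple


def parse_release(version: str) -> Tuple[int, int, int]:
    """Turn '30.1' or '30.1.0' into a three-part integer tuple."""
    parts = [int(p) for p in version.split(".")[:3]]
    parts += [0] * (3 - len(parts))  # pad '30.1' to (30, 1, 0)
    return (parts[0], parts[1], parts[2])


def in_supported_range(version: str) -> bool:
    """True when the version is at least 30.1.0 and below 31.0.0."""
    return (30, 1, 0) <= parse_release(version) < (31, 0, 0)


for candidate in ("29.0.0", "30.0.1", "30.1.0", "30.4.2", "31.0.0"):
    print(candidate, in_supported_range(candidate))
```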

`local/` contains development scratch files and references — **not part of the package**:
- `scratch.py` — keyword comparison scratch script
- `sqlglot/` — full clone of the sqlglot repo for reference (expressions, dialects, generator internals); `sqlglot/posts/` contains official guides (`onboarding.md` for architecture deep-dive, `ast_primer.md` for AST tutorial). Note: local clone is newer than installed (30.0.1) — dialect parsers moved to `parsers/`, expressions split into `expressions/` package
- `sqlglot/` — full clone of the sqlglot repo for reference (expressions, dialects, generator internals); `sqlglot/posts/` contains official guides (`onboarding.md` for architecture deep-dive, `ast_primer.md` for AST tutorial). Parsers live in `parsers/`, generators in `generators/`, expressions in `expressions/` package
- `ydb-sqlglot-plugin/` — YDB dialect plugin, used as reference for how a well-behaved plugin is structured
- `maxcompute_doc/` — MaxCompute official function documentation (e.g., `date_func.md`, `func_comparison.md`)

## Implementation Status

The dialect is complete at v0.3.1:
The dialect is complete at v0.4.0:
- **Parser**: ~65 functions explicitly mapped (date/time, string, aggregate, array, map); remainder inherited from Hive.
- **Generator**: `TRANSFORMS` + named `_sql` methods for all major expression types; Hive handles the rest.
- **Tests**: 39 test methods, 180+ subtests covering parse, round-trip, and cross-dialect transpilation.
- **Tests**: 40 test methods, 186 subtests covering parse, round-trip, and cross-dialect transpilation.

## Key sqlglot patterns

@@ -60,6 +60,14 @@ When adding generator transforms in `Generator.TRANSFORMS`, use `self.func(name,
Tests use a `Validator` base class (inline in `tests/test_maxcompute.py`) mirroring sqlglot's pattern:
- `validate_all(sql, write={dialect: expected})` — cross-dialect transpilation assertions
- `assertIsInstance(parse_one(sql, read="maxcompute"), exp.SomeClass)` — parse node assertions
- **`read=` must be a dict** — `read={"spark": "LOCATE(...)"}`, not `read="spark"`. Bare string is silently ignored by `validate_all`.
- **Pyright false positive** — `assertIsNotNone(x)` does not narrow types in Pyright; `x.field` after it shows "attribute of None" errors that are noise, not real bugs.
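
One possible way to quiet the Pyright noise, shown here as a sketch rather than a project convention, is to follow `assertIsNotNone` with a plain `assert ... is not None`, which Pyright does treat as a narrowing check. `find_unit` and `NarrowingExample` are hypothetical names used only for this illustration.

```python
import unittest
from typing import Optional


def find_unit(name: str) -> Optional[str]:
    """Hypothetical lookup that may return None."""
    units = {"month": "MONTH", "day": "DAY"}
    return units.get(name)


class NarrowingExample(unittest.TestCase):
    def test_unit_is_uppercased(self) -> None:
        unit = find_unit("month")
        # Runtime check only: the static type is still Optional[str] here.
        self.assertIsNotNone(unit)
        # A plain assert narrows Optional[str] to str for the type checker.
        assert unit is not None
        self.assertEqual(unit.upper(), "MONTH")
```

The extra `assert` is redundant at runtime, so whether it is worth adding is a style choice.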

**Development is test-driven (TDD).** For every fix or feature:
1. Write the failing test first and run it to confirm it fails
2. Implement the minimal change to make it pass
3. Run the full suite to confirm no regressions
4. Commit

Before writing `validate_all` assertions, probe actual output first:
```bash
2 changes: 1 addition & 1 deletion README.md
@@ -10,7 +10,7 @@ Registers the `maxcompute` dialect via Python entry points so that SQLGlot can p
pip install sqlglot-maxcompute
```

Requires Python ≥ 3.9 and SQLGlot ≥ 29.
Requires Python ≥ 3.9 and SQLGlot ≥ 30.1.

## Usage

6 changes: 3 additions & 3 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "sqlglot-maxcompute"
version = "0.3.1"
version = "0.4.0"
description = "MaxCompute dialect plugin for SQLGlot"
readme = "README.md"
license = { text = "MIT" }
@@ -9,7 +9,7 @@ authors = [
]
requires-python = ">=3.9"
dependencies = [
"sqlglot>=29.0.0,<31",
"sqlglot>=30.1.0,<31",
]
classifiers = [
"Development Status :: 3 - Alpha",
@@ -38,4 +38,4 @@ dev = [
testpaths = ["tests"]

[project.entry-points."sqlglot.dialects"]
maxcompute = "sqlglot_maxcompute.maxcompute:MaxCompute"
maxcompute = "sqlglot_maxcompute.dialect:MaxCompute"
2 changes: 1 addition & 1 deletion src/sqlglot_maxcompute/__init__.py
@@ -1,4 +1,4 @@
from sqlglot_maxcompute.maxcompute import MaxCompute
from sqlglot_maxcompute.dialect import MaxCompute
from sqlglot_maxcompute.parser import MaxComputeParser
from sqlglot_maxcompute.generator import MaxComputeGenerator

File renamed without changes.
57 changes: 44 additions & 13 deletions src/sqlglot_maxcompute/generator.py
@@ -3,7 +3,7 @@
import typing as t

from sqlglot import exp
from sqlglot.dialects.hive import Hive
from sqlglot.generators.hive import HiveGenerator
from sqlglot.dialects.dialect import rename_func, unit_to_str
from sqlglot.transforms import (
move_schema_columns_to_partitioned_by,
@@ -13,7 +13,12 @@
)


_AUTO_PARTITION_TYPES = (exp.DateTrunc, exp.TimestampTrunc, exp.DatetimeTrunc, exp.Alias)
_AUTO_PARTITION_TYPES = (
exp.DateTrunc,
exp.TimestampTrunc,
exp.DatetimeTrunc,
exp.Alias,
)


def _move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
@@ -25,17 +30,17 @@ def _move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
return move_schema_columns_to_partitioned_by(expression)


class MaxComputeGenerator(Hive.Generator):
class MaxComputeGenerator(HiveGenerator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
**HiveGenerator.TYPE_MAPPING,
exp.DType.DATETIME: "DATETIME",
exp.DType.VARCHAR: "STRING",
exp.DType.CHAR: "STRING",
exp.DType.TEXT: "STRING",
}

TRANSFORMS = {
**Hive.Generator.TRANSFORMS,
**HiveGenerator.TRANSFORMS,
exp.Create: preprocess(
[
remove_unique_constraints,
@@ -67,19 +72,33 @@ class MaxComputeGenerator(Hive.Generator):
exp.ApproxDistinct: rename_func("APPROX_DISTINCT"),
exp.ArgMax: lambda self, e: self.func("ARG_MAX", e.this, e.expression),
exp.ArgMin: lambda self, e: self.func("ARG_MIN", e.this, e.expression),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
# Statistical aggregate fixes (Hive emits wrong names)
exp.Space: rename_func("SPACE"),
exp.VariancePop: rename_func("VAR_POP"),
exp.Variance: rename_func("VAR_SAMP"),
# Numeric truncation: TRUNC(n, d)
exp.Trunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("decimals")),
# String position: MaxCompute uses INSTR(str, substr), not LOCATE(substr, str)
exp.StrPosition: lambda self, e: self.func("INSTR", e.this, e.args.get("substr")),
exp.StrPosition: lambda self, e: self.func(
"INSTR", e.this, e.args.get("substr"), e.args.get("position")
),
# TO_DATE(str, fmt) returns DATETIME — modeled as StrToTime; emit TO_DATE in MaxCompute
exp.StrToTime: lambda self, e: self.func("TO_DATE", e.this, e.args.get("format")),
exp.StrToTime: lambda self, e: self.func(
"TO_DATE", e.this, e.args.get("format")
),
}

def _dateadd_sql(
self,
expression: exp.TsOrDsAdd | exp.DateAdd | exp.DateSub | exp.TimestampAdd | exp.DatetimeAdd,
expression: (
exp.TsOrDsAdd
| exp.DateAdd
| exp.DateSub
| exp.TimestampAdd
| exp.DatetimeAdd
),
) -> str:
unit = unit_to_str(expression) if expression.args.get("unit") else "'DAY'"
delta = expression.expression
@@ -118,18 +137,26 @@ def tochar_sql(self, expression: exp.ToChar) -> str:
return self.func("TO_CHAR", expression.this, expression.args.get("format"))

def substring_sql(self, expression: exp.Substring) -> str:
return self.func("SUBSTR", expression.this, expression.args.get("start"), expression.args.get("length"))
return self.func(
"SUBSTR",
expression.this,
expression.args.get("start"),
expression.args.get("length"),
)

def extract_sql(self, expression: exp.Extract) -> str:
unit = expression.this
return self.func("DATEPART", expression.expression, exp.Literal.string(unit.name))
return self.func(
"DATEPART", expression.expression, exp.Literal.string(unit.name)
)

def mod_sql(self, expression: exp.Mod) -> str:
# Reverse the WEEKDAY parser transform: (DAYOFWEEK(x) + 5) % 7 → WEEKDAY(x)
rhs = expression.expression
lhs = expression.this
if (
isinstance(rhs, exp.Literal) and rhs.this == "7"
isinstance(rhs, exp.Literal)
and rhs.this == "7"
and isinstance(lhs, exp.Paren)
and isinstance(lhs.this, exp.Add)
and isinstance(lhs.this.this, exp.DayOfWeek)
@@ -148,7 +175,9 @@ def _partitioned_by_sql(self, expression: exp.PartitionedByProperty) -> str:
inner = inner.this
unit = inner.args.get("unit")
unit_str = unit.name.lower() if unit else ""
trunc_sql = self.func("TRUNC_TIME", inner.this, exp.Literal.string(unit_str))
trunc_sql = self.func(
"TRUNC_TIME", inner.this, exp.Literal.string(unit_str)
)
return f"AUTO PARTITIONED BY ({trunc_sql}{alias_sql})"
return f"PARTITIONED BY {self.sql(expression, 'this')}"

@@ -159,7 +188,9 @@ def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str:
def datatype_sql(self, expression: exp.DataType) -> str:
# VARCHAR and CHAR map to STRING in MaxCompute, with no length parameters
if expression.this in (exp.DType.VARCHAR, exp.DType.CHAR):
return self.TYPE_MAPPING.get(expression.this, super().datatype_sql(expression))
return self.TYPE_MAPPING.get(
expression.this, super().datatype_sql(expression)
)
return super().datatype_sql(expression)

def properties_sql(self, expression: exp.Properties) -> str:
32 changes: 24 additions & 8 deletions src/sqlglot_maxcompute/parser.py
@@ -4,7 +4,7 @@
import typing as t

from sqlglot import exp
from sqlglot.dialects.hive import Hive
from sqlglot.parsers.hive import HiveParser
from sqlglot.dialects.dialect import build_timetostr_or_tochar
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -61,9 +61,9 @@ def _build_datetrunc(
return exp.DateTrunc(unit=unit, this=this)


class MaxComputeParser(Hive.Parser):
class MaxComputeParser(HiveParser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,
**HiveParser.FUNCTIONS,
# Hive overrides: MaxCompute accepts date/datetime/timestamp/string directly
# without needing TsOrDsToDate wrapping
"DAY": exp.Day.from_arg_list,
@@ -86,7 +86,9 @@ class MaxComputeParser(Hive.Parser):
# Hive override: produce exp.DateSub so _dateadd_sql emits DATEADD(date, -n, unit)
# cleanly. Hive maps DATE_SUB to TsOrDsAdd(expression=Mul(n, -1)) which generates
# "3 * -1" in the output.
"DATE_SUB": lambda args: exp.DateSub(this=seq_get(args, 0), expression=seq_get(args, 1)),
"DATE_SUB": lambda args: exp.DateSub(
this=seq_get(args, 0), expression=seq_get(args, 1)
),
# Date arithmetic
"DATEADD": _build_dateadd,
"DATEDIFF": lambda args: exp.DateDiff(
@@ -104,14 +106,24 @@ class MaxComputeParser(Hive.Parser):
),
"DATETRUNC": _build_datetrunc,
"TRUNC_TIME": _build_datetrunc,
# TRUNC(n, d) → exp.Trunc (numeric truncation)
# TRUNC(dt, 'u') → _build_datetrunc (date truncation, same as TRUNC_TIME)
"TRUNC": lambda args: (
_build_datetrunc(args)
if seq_get(args, 1) is not None and seq_get(args, 1).is_string
else exp.Trunc(this=seq_get(args, 0), decimals=seq_get(args, 1))
),
"DAYOFMONTH": exp.DayOfMonth.from_arg_list,
"DAYOFWEEK": exp.DayOfWeek.from_arg_list,
"DAYOFYEAR": exp.DayOfYear.from_arg_list,
"HOUR": exp.Hour.from_arg_list,
"MINUTE": exp.Minute.from_arg_list,
"SECOND": exp.Second.from_arg_list,
"QUARTER": exp.Quarter.from_arg_list,
"WEEKDAY": lambda args: exp.paren(exp.DayOfWeek(this=seq_get(args, 0)) + 5, copy=False) % 7,
"WEEKDAY": lambda args: exp.paren(
exp.DayOfWeek(this=seq_get(args, 0)) + 5, copy=False
)
% 7,
"WEEKOFYEAR": exp.WeekOfYear.from_arg_list,
# Last/next day
"LAST_DAY": exp.LastDay.from_arg_list,
@@ -134,7 +146,9 @@ class MaxComputeParser(Hive.Parser):
),
"ISDATE": lambda args: exp.not_(
exp.Is(
this=exp.TsOrDsToDate(this=seq_get(args, 0), format=seq_get(args, 1), safe=True),
this=exp.TsOrDsToDate(
this=seq_get(args, 0), format=seq_get(args, 1), safe=True
),
expression=exp.Null(),
)
),
@@ -191,7 +205,7 @@ class MaxComputeParser(Hive.Parser):
}

PROPERTY_PARSERS = {
**Hive.Parser.PROPERTY_PARSERS,
**HiveParser.PROPERTY_PARSERS,
# LIFECYCLE n — MaxCompute table retention in days. Stored as a generic
# exp.Property with a Var key so no custom expression class is needed and
# sqlglot's PROPERTIES_LOCATION contract is not broken.
@@ -202,7 +216,9 @@ class MaxComputeParser(Hive.Parser):
"AUTO": lambda self: self._parse_auto_partition(),
}

def _parse_auto_partition(self) -> exp.PartitionedByProperty | exp.AutoRefreshProperty | None:
def _parse_auto_partition(
self,
) -> exp.PartitionedByProperty | exp.AutoRefreshProperty | None:
if self._match(TokenType.PARTITION_BY):
self._match(TokenType.L_PAREN)
expr = self._parse_conjunction()
45 changes: 45 additions & 0 deletions tests/test_maxcompute.py
@@ -229,6 +229,51 @@ def test_031_fixes(self):
},
)

def test_032_fixes(self):
# Bug 1: TRUNC(n, d) — numeric truncation, not date truncation
expr = self.parse_one("TRUNC(3.14, 2)")
self.assertIsInstance(expr, exp.Trunc)
self.validate_all(
"TRUNC(3.14, 2)",
write={
"maxcompute": "TRUNC(3.14, 2)",
},
)
# TRUNC with a string unit still routes to DATETRUNC
expr2 = self.parse_one("TRUNC(dt, 'MONTH')")
self.assertIsInstance(expr2, exp.DateTrunc)

# Bug 2: BOOL_AND / BOOL_OR — aggregate, not infix AND/OR
self.validate_all(
"SELECT BOOL_AND(flag) FROM t",
write={
"maxcompute": "SELECT BOOL_AND(flag) FROM t",
},
)
self.validate_all(
"SELECT BOOL_OR(flag) FROM t",
write={
"maxcompute": "SELECT BOOL_OR(flag) FROM t",
},
)

# Bug 3: LOCATE(sub, str, start) — start position must pass through to INSTR
self.validate_all(
"LOCATE('bc', 'abcd', 2)",
read={"spark": "LOCATE('bc', 'abcd', 2)"},
write={
"maxcompute": "INSTR('abcd', 'bc', 2)",
},
)
# Without start position, INSTR(str, sub) is unchanged
self.validate_all(
"LOCATE('bc', 'abcd')",
read={"spark": "LOCATE('bc', 'abcd')"},
write={
"maxcompute": "INSTR('abcd', 'bc')",
},
)

# -------------------------------------------------------------------------
# Date/time conversion
# -------------------------------------------------------------------------