diff --git a/CHANGELOG.md b/CHANGELOG.md index 140061b..aface07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ # Changelog +## [0.4.0] - 2026-04-01 + +### Changed (internal refactor) + +- `maxcompute.py` renamed to `dialect.py` — the coordinator class `MaxCompute` is now in `src/sqlglot_maxcompute/dialect.py` +- `MaxComputeParser` now inherits from `HiveParser` (imported from `sqlglot.parsers.hive`) instead of `Hive.Parser` +- `MaxComputeGenerator` now inherits from `HiveGenerator` (imported from `sqlglot.generators.hive`) instead of `Hive.Generator` +- `sqlglot` dependency floor raised to `>=30.1.0` (first release with split `parsers/` and `generators/` modules) + +### Fixed (parser + generator correctness) + +- `TRUNC(n, d)` now maps to `exp.Trunc` for numeric truncation; `TRUNC(dt, 'unit')` still routes to date truncation +- `BOOL_AND(col)` / `BOOL_OR(col)` now emit correct MaxCompute names instead of `LOGICAL_AND` / `LOGICAL_OR` +- `LOCATE(sub, str, start)` now passes the start position through to `INSTR(str, sub, start)` instead of silently dropping it + ## [0.3.1] - 2026-04-01 ### Fixed (parser correctness) @@ -24,7 +39,7 @@ ### Changed (internal) -- Dialect split: `maxcompute.py` now delegates to `parser.py` and `generator.py` (mirrors sqlglot's own mypyc-compile refactor) +- Dialect split: `dialect.py` now delegates to `parser.py` and `generator.py` (mirrors sqlglot's own mypyc-compile refactor) ### Tests diff --git a/CLAUDE.md b/CLAUDE.md index c328199..bf256ba 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,26 +28,26 @@ uv run pytest tests/test_foo.py::test_bar The dialect is split across three files in `src/sqlglot_maxcompute/`: -- **`parser.py`** — `MaxComputeParser(Hive.Parser)`: `FUNCTIONS` dict mapping MaxCompute function names to canonical `sqlglot.exp` nodes; `PROPERTY_PARSERS` for `LIFECYCLE`, `RANGE`, and `AUTO`; helper builders `_build_dateadd`, `_build_datetrunc`. -- **`generator.py`** — `MaxComputeGenerator(Hive.Generator)`: `TYPE_MAPPING`, `TRANSFORMS`, and named `_sql` methods that map canonical AST nodes back to MaxCompute SQL. -- **`maxcompute.py`** — `MaxCompute(Hive)`: slim coordinator that sets `TIME_MAPPING`/`DATE_FORMAT`/`TIME_FORMAT`, adds `Tokenizer` keywords (`EXPORT`, `LIFECYCLE`, `OPTION`), and wires `Parser = MaxComputeParser` / `Generator = MaxComputeGenerator`. +- **`parser.py`** — `MaxComputeParser(HiveParser)`: `FUNCTIONS` dict mapping MaxCompute function names to canonical `sqlglot.exp` nodes; `PROPERTY_PARSERS` for `LIFECYCLE`, `RANGE`, and `AUTO`; helper builders `_build_dateadd`, `_build_datetrunc`. +- **`generator.py`** — `MaxComputeGenerator(HiveGenerator)`: `TYPE_MAPPING`, `TRANSFORMS`, and named `_sql` methods that map canonical AST nodes back to MaxCompute SQL. +- **`dialect.py`** — `MaxCompute(Hive)`: slim coordinator that sets `TIME_MAPPING`/`DATE_FORMAT`/`TIME_FORMAT`, adds `Tokenizer` keywords (`EXPORT`, `LIFECYCLE`, `OPTION`), and wires `Parser = MaxComputeParser` / `Generator = MaxComputeGenerator`. The dialect is registered as a plugin in `pyproject.toml` under `[project.entry-points."sqlglot.dialects"]`, so after installation it is automatically discoverable by sqlglot as `"maxcompute"`. -This split mirrors sqlglot's own mypyc-compile refactor (parsers/generators split by file) and is required for compatibility with sqlglot ≥ 31 compiled wheels. +This split mirrors sqlglot's own mypyc-compile refactor (parsers/generators split into `sqlglot.parsers.*` / `sqlglot.generators.*` modules) and requires sqlglot ≥ 30.1.0. `local/` contains development scratch files and references — **not part of the package**: - `scratch.py` — keyword comparison scratch script -- `sqlglot/` — full clone of the sqlglot repo for reference (expressions, dialects, generator internals); `sqlglot/posts/` contains official guides (`onboarding.md` for architecture deep-dive, `ast_primer.md` for AST tutorial). Note: local clone is newer than installed (30.0.1) — dialect parsers moved to `parsers/`, expressions split into `expressions/` package +- `sqlglot/` — full clone of the sqlglot repo for reference (expressions, dialects, generator internals); `sqlglot/posts/` contains official guides (`onboarding.md` for architecture deep-dive, `ast_primer.md` for AST tutorial). Parsers live in `parsers/`, generators in `generators/`, expressions in `expressions/` package - `ydb-sqlglot-plugin/` — YDB dialect plugin, used as reference for how a well-behaved plugin is structured - `maxcompute_doc/` — MaxCompute official function documentation (e.g., `date_func.md`, `func_comparison.md`) ## Implementation Status -The dialect is complete at v0.3.1: +The dialect is complete at v0.4.0: - **Parser**: ~65 functions explicitly mapped (date/time, string, aggregate, array, map); remainder inherited from Hive. - **Generator**: `TRANSFORMS` + named `_sql` methods for all major expression types; Hive handles the rest. -- **Tests**: 39 test methods, 180+ subtests covering parse, round-trip, and cross-dialect transpilation. +- **Tests**: 40 test methods, 186 subtests covering parse, round-trip, and cross-dialect transpilation. ## Key sqlglot patterns @@ -60,6 +60,14 @@ When adding generator transforms in `Generator.TRANSFORMS`, use `self.func(name, Tests use a `Validator` base class (inline in `tests/test_maxcompute.py`) mirroring sqlglot's pattern: - `validate_all(sql, write={dialect: expected})` — cross-dialect transpilation assertions - `assertIsInstance(parse_one(sql, read="maxcompute"), exp.SomeClass)` — parse node assertions +- **`read=` must be a dict** — `read={"spark": "LOCATE(...)"}`, not `read="spark"`. Bare string is silently ignored by `validate_all`. +- **Pyright false positive** — `assertIsNotNone(x)` does not narrow types in Pyright; `x.field` after it shows "attribute of None" errors that are noise, not real bugs. + +**Development is test-driven (TDD).** For every fix or feature: +1. Write the failing test first and run it to confirm it fails +2. Implement the minimal change to make it pass +3. Run the full suite to confirm no regressions +4. Commit Before writing `validate_all` assertions, probe actual output first: ```bash diff --git a/README.md b/README.md index b8a04a7..1957390 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Registers the `maxcompute` dialect via Python entry points so that SQLGlot can p pip install sqlglot-maxcompute ``` -Requires Python ≥ 3.9 and SQLGlot ≥ 29. +Requires Python ≥ 3.9 and SQLGlot ≥ 30.1. ## Usage diff --git a/pyproject.toml b/pyproject.toml index 299de56..f6ead6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sqlglot-maxcompute" -version = "0.3.1" +version = "0.4.0" description = "MaxCompute dialect plugin for SQLGlot" readme = "README.md" license = { text = "MIT" } @@ -9,7 +9,7 @@ authors = [ ] requires-python = ">=3.9" dependencies = [ - "sqlglot>=29.0.0,<31", + "sqlglot>=30.1.0,<31", ] classifiers = [ "Development Status :: 3 - Alpha", @@ -38,4 +38,4 @@ dev = [ testpaths = ["tests"] [project.entry-points."sqlglot.dialects"] -maxcompute = "sqlglot_maxcompute.maxcompute:MaxCompute" +maxcompute = "sqlglot_maxcompute.dialect:MaxCompute" diff --git a/src/sqlglot_maxcompute/__init__.py b/src/sqlglot_maxcompute/__init__.py index 8ac5f2e..33472f0 100644 --- a/src/sqlglot_maxcompute/__init__.py +++ b/src/sqlglot_maxcompute/__init__.py @@ -1,4 +1,4 @@ -from sqlglot_maxcompute.maxcompute import MaxCompute +from sqlglot_maxcompute.dialect import MaxCompute from sqlglot_maxcompute.parser import MaxComputeParser from sqlglot_maxcompute.generator import MaxComputeGenerator diff --git a/src/sqlglot_maxcompute/maxcompute.py b/src/sqlglot_maxcompute/dialect.py similarity index 100% rename from src/sqlglot_maxcompute/maxcompute.py rename to src/sqlglot_maxcompute/dialect.py diff --git a/src/sqlglot_maxcompute/generator.py b/src/sqlglot_maxcompute/generator.py index 4dc9145..c8d82f7 100644 --- a/src/sqlglot_maxcompute/generator.py +++ b/src/sqlglot_maxcompute/generator.py @@ -3,7 +3,7 @@ import typing as t from sqlglot import exp -from sqlglot.dialects.hive import Hive +from sqlglot.generators.hive import HiveGenerator from sqlglot.dialects.dialect import rename_func, unit_to_str from sqlglot.transforms import ( move_schema_columns_to_partitioned_by, @@ -13,7 +13,12 @@ ) -_AUTO_PARTITION_TYPES = (exp.DateTrunc, exp.TimestampTrunc, exp.DatetimeTrunc, exp.Alias) +_AUTO_PARTITION_TYPES = ( + exp.DateTrunc, + exp.TimestampTrunc, + exp.DatetimeTrunc, + exp.Alias, +) def _move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr: @@ -25,9 +30,9 @@ def _move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr: return move_schema_columns_to_partitioned_by(expression) -class MaxComputeGenerator(Hive.Generator): +class MaxComputeGenerator(HiveGenerator): TYPE_MAPPING = { - **Hive.Generator.TYPE_MAPPING, + **HiveGenerator.TYPE_MAPPING, exp.DType.DATETIME: "DATETIME", exp.DType.VARCHAR: "STRING", exp.DType.CHAR: "STRING", @@ -35,7 +40,7 @@ class MaxComputeGenerator(Hive.Generator): } TRANSFORMS = { - **Hive.Generator.TRANSFORMS, + **HiveGenerator.TRANSFORMS, exp.Create: preprocess( [ remove_unique_constraints, @@ -67,19 +72,33 @@ class MaxComputeGenerator(Hive.Generator): exp.ApproxDistinct: rename_func("APPROX_DISTINCT"), exp.ArgMax: lambda self, e: self.func("ARG_MAX", e.this, e.expression), exp.ArgMin: lambda self, e: self.func("ARG_MIN", e.this, e.expression), + exp.LogicalAnd: rename_func("BOOL_AND"), + exp.LogicalOr: rename_func("BOOL_OR"), # Statistical aggregate fixes (Hive emits wrong names) exp.Space: rename_func("SPACE"), exp.VariancePop: rename_func("VAR_POP"), exp.Variance: rename_func("VAR_SAMP"), + # Numeric truncation: TRUNC(n, d) + exp.Trunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("decimals")), # String position: MaxCompute uses INSTR(str, substr), not LOCATE(substr, str) - exp.StrPosition: lambda self, e: self.func("INSTR", e.this, e.args.get("substr")), + exp.StrPosition: lambda self, e: self.func( + "INSTR", e.this, e.args.get("substr"), e.args.get("position") + ), # TO_DATE(str, fmt) returns DATETIME — modeled as StrToTime; emit TO_DATE in MaxCompute - exp.StrToTime: lambda self, e: self.func("TO_DATE", e.this, e.args.get("format")), + exp.StrToTime: lambda self, e: self.func( + "TO_DATE", e.this, e.args.get("format") + ), } def _dateadd_sql( self, - expression: exp.TsOrDsAdd | exp.DateAdd | exp.DateSub | exp.TimestampAdd | exp.DatetimeAdd, + expression: ( + exp.TsOrDsAdd + | exp.DateAdd + | exp.DateSub + | exp.TimestampAdd + | exp.DatetimeAdd + ), ) -> str: unit = unit_to_str(expression) if expression.args.get("unit") else "'DAY'" delta = expression.expression @@ -118,18 +137,26 @@ def tochar_sql(self, expression: exp.ToChar) -> str: return self.func("TO_CHAR", expression.this, expression.args.get("format")) def substring_sql(self, expression: exp.Substring) -> str: - return self.func("SUBSTR", expression.this, expression.args.get("start"), expression.args.get("length")) + return self.func( + "SUBSTR", + expression.this, + expression.args.get("start"), + expression.args.get("length"), + ) def extract_sql(self, expression: exp.Extract) -> str: unit = expression.this - return self.func("DATEPART", expression.expression, exp.Literal.string(unit.name)) + return self.func( + "DATEPART", expression.expression, exp.Literal.string(unit.name) + ) def mod_sql(self, expression: exp.Mod) -> str: # Reverse the WEEKDAY parser transform: (DAYOFWEEK(x) + 5) % 7 → WEEKDAY(x) rhs = expression.expression lhs = expression.this if ( - isinstance(rhs, exp.Literal) and rhs.this == "7" + isinstance(rhs, exp.Literal) + and rhs.this == "7" and isinstance(lhs, exp.Paren) and isinstance(lhs.this, exp.Add) and isinstance(lhs.this.this, exp.DayOfWeek) @@ -148,7 +175,9 @@ def _partitioned_by_sql(self, expression: exp.PartitionedByProperty) -> str: inner = inner.this unit = inner.args.get("unit") unit_str = unit.name.lower() if unit else "" - trunc_sql = self.func("TRUNC_TIME", inner.this, exp.Literal.string(unit_str)) + trunc_sql = self.func( + "TRUNC_TIME", inner.this, exp.Literal.string(unit_str) + ) return f"AUTO PARTITIONED BY ({trunc_sql}{alias_sql})" return f"PARTITIONED BY {self.sql(expression, 'this')}" @@ -159,7 +188,9 @@ def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str: def datatype_sql(self, expression: exp.DataType) -> str: # VARCHAR and CHAR map to STRING in MaxCompute, with no length parameters if expression.this in (exp.DType.VARCHAR, exp.DType.CHAR): - return self.TYPE_MAPPING.get(expression.this, super().datatype_sql(expression)) + return self.TYPE_MAPPING.get( + expression.this, super().datatype_sql(expression) + ) return super().datatype_sql(expression) def properties_sql(self, expression: exp.Properties) -> str: diff --git a/src/sqlglot_maxcompute/parser.py b/src/sqlglot_maxcompute/parser.py index 60c4b00..e11922c 100644 --- a/src/sqlglot_maxcompute/parser.py +++ b/src/sqlglot_maxcompute/parser.py @@ -4,7 +4,7 @@ import typing as t from sqlglot import exp -from sqlglot.dialects.hive import Hive +from sqlglot.parsers.hive import HiveParser from sqlglot.dialects.dialect import build_timetostr_or_tochar from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -61,9 +61,9 @@ def _build_datetrunc( return exp.DateTrunc(unit=unit, this=this) -class MaxComputeParser(Hive.Parser): +class MaxComputeParser(HiveParser): FUNCTIONS = { - **Hive.Parser.FUNCTIONS, + **HiveParser.FUNCTIONS, # Hive overrides: MaxCompute accepts date/datetime/timestamp/string directly # without needing TsOrDsToDate wrapping "DAY": exp.Day.from_arg_list, @@ -86,7 +86,9 @@ class MaxComputeParser(Hive.Parser): # Hive override: produce exp.DateSub so _dateadd_sql emits DATEADD(date, -n, unit) # cleanly. Hive maps DATE_SUB to TsOrDsAdd(expression=Mul(n, -1)) which generates # "3 * -1" in the output. - "DATE_SUB": lambda args: exp.DateSub(this=seq_get(args, 0), expression=seq_get(args, 1)), + "DATE_SUB": lambda args: exp.DateSub( + this=seq_get(args, 0), expression=seq_get(args, 1) + ), # Date arithmetic "DATEADD": _build_dateadd, "DATEDIFF": lambda args: exp.DateDiff( @@ -104,6 +106,13 @@ class MaxComputeParser(Hive.Parser): ), "DATETRUNC": _build_datetrunc, "TRUNC_TIME": _build_datetrunc, + # TRUNC(n, d) → exp.Trunc (numeric truncation) + # TRUNC(dt, 'u') → _build_datetrunc (date truncation, same as TRUNC_TIME) + "TRUNC": lambda args: ( + _build_datetrunc(args) + if seq_get(args, 1) is not None and seq_get(args, 1).is_string + else exp.Trunc(this=seq_get(args, 0), decimals=seq_get(args, 1)) + ), "DAYOFMONTH": exp.DayOfMonth.from_arg_list, "DAYOFWEEK": exp.DayOfWeek.from_arg_list, "DAYOFYEAR": exp.DayOfYear.from_arg_list, @@ -111,7 +120,10 @@ class MaxComputeParser(Hive.Parser): "MINUTE": exp.Minute.from_arg_list, "SECOND": exp.Second.from_arg_list, "QUARTER": exp.Quarter.from_arg_list, - "WEEKDAY": lambda args: exp.paren(exp.DayOfWeek(this=seq_get(args, 0)) + 5, copy=False) % 7, + "WEEKDAY": lambda args: exp.paren( + exp.DayOfWeek(this=seq_get(args, 0)) + 5, copy=False + ) + % 7, "WEEKOFYEAR": exp.WeekOfYear.from_arg_list, # Last/next day "LAST_DAY": exp.LastDay.from_arg_list, @@ -134,7 +146,9 @@ class MaxComputeParser(Hive.Parser): ), "ISDATE": lambda args: exp.not_( exp.Is( - this=exp.TsOrDsToDate(this=seq_get(args, 0), format=seq_get(args, 1), safe=True), + this=exp.TsOrDsToDate( + this=seq_get(args, 0), format=seq_get(args, 1), safe=True + ), expression=exp.Null(), ) ), @@ -191,7 +205,7 @@ class MaxComputeParser(Hive.Parser): } PROPERTY_PARSERS = { - **Hive.Parser.PROPERTY_PARSERS, + **HiveParser.PROPERTY_PARSERS, # LIFECYCLE n — MaxCompute table retention in days. Stored as a generic # exp.Property with a Var key so no custom expression class is needed and # sqlglot's PROPERTIES_LOCATION contract is not broken. @@ -202,7 +216,9 @@ class MaxComputeParser(Hive.Parser): "AUTO": lambda self: self._parse_auto_partition(), } - def _parse_auto_partition(self) -> exp.PartitionedByProperty | exp.AutoRefreshProperty | None: + def _parse_auto_partition( + self, + ) -> exp.PartitionedByProperty | exp.AutoRefreshProperty | None: if self._match(TokenType.PARTITION_BY): self._match(TokenType.L_PAREN) expr = self._parse_conjunction() diff --git a/tests/test_maxcompute.py b/tests/test_maxcompute.py index cf45683..d309d68 100644 --- a/tests/test_maxcompute.py +++ b/tests/test_maxcompute.py @@ -229,6 +229,51 @@ def test_031_fixes(self): }, ) + def test_032_fixes(self): + # Bug 1: TRUNC(n, d) — numeric truncation, not date truncation + expr = self.parse_one("TRUNC(3.14, 2)") + self.assertIsInstance(expr, exp.Trunc) + self.validate_all( + "TRUNC(3.14, 2)", + write={ + "maxcompute": "TRUNC(3.14, 2)", + }, + ) + # TRUNC with a string unit still routes to DATETRUNC + expr2 = self.parse_one("TRUNC(dt, 'MONTH')") + self.assertIsInstance(expr2, exp.DateTrunc) + + # Bug 2: BOOL_AND / BOOL_OR — aggregate, not infix AND/OR + self.validate_all( + "SELECT BOOL_AND(flag) FROM t", + write={ + "maxcompute": "SELECT BOOL_AND(flag) FROM t", + }, + ) + self.validate_all( + "SELECT BOOL_OR(flag) FROM t", + write={ + "maxcompute": "SELECT BOOL_OR(flag) FROM t", + }, + ) + + # Bug 3: LOCATE(sub, str, start) — start position must pass through to INSTR + self.validate_all( + "LOCATE('bc', 'abcd', 2)", + read={"spark": "LOCATE('bc', 'abcd', 2)"}, + write={ + "maxcompute": "INSTR('abcd', 'bc', 2)", + }, + ) + # Without start position, INSTR(str, sub) is unchanged + self.validate_all( + "LOCATE('bc', 'abcd')", + read={"spark": "LOCATE('bc', 'abcd')"}, + write={ + "maxcompute": "INSTR('abcd', 'bc')", + }, + ) + # ------------------------------------------------------------------------- # Date/time conversion # ------------------------------------------------------------------------- diff --git a/uv.lock b/uv.lock index b1d7206..8470128 100644 --- a/uv.lock +++ b/uv.lock @@ -122,16 +122,16 @@ wheels = [ [[package]] name = "sqlglot" -version = "30.0.1" +version = "30.1.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2c/32/ffa8390ac039de6e18e6874b1464c4012db78d9a15790d0c56c2bf5d65bb/sqlglot-30.0.1.tar.gz", hash = "sha256:1191cc37654c944b9a1d020347b9e435e3b39bdbade9129f82aa5827e3641332", size = 5793328, upload-time = "2026-03-16T22:07:33.565Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/ae/afee950eff42a9c8ceab4a2e25abfeaa8278c578f967201824287cf530ce/sqlglot-30.1.0.tar.gz", hash = "sha256:7593aea85349c577b269d540ba245024f91464afdcf61c6ef7765f4691c46ef8", size = 5812093, upload-time = "2026-03-26T19:25:45.065Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4f/6a/e7cf2f648d7217359cd21129101208218d6245f9863d36b2c9049211676f/sqlglot-30.0.1-py3-none-any.whl", hash = "sha256:379bb16020573aa7fa4730b9c04d5ee79d1c4cf50f8d203d6fb03e49ac1e4ff1", size = 648788, upload-time = "2026-03-16T22:07:31.239Z" }, + { url = "https://files.pythonhosted.org/packages/29/31/f1cad1972a8eb4b1a9bc904e4a8d440af1eef064160fe10ba0ae81f4693f/sqlglot-30.1.0-py3-none-any.whl", hash = "sha256:6c2d58d0cc68b5f96900058e8866ef4959f89f9e66e4096e0ba746830dda4f40", size = 665823, upload-time = "2026-03-26T19:25:42.794Z" }, ] [[package]] name = "sqlglot-maxcompute" -version = "0.3.1" +version = "0.3.2" source = { editable = "." } dependencies = [ { name = "sqlglot" }, @@ -144,7 +144,7 @@ dev = [ ] [package.metadata] -requires-dist = [{ name = "sqlglot", specifier = ">=29.0.0,<31" }] +requires-dist = [{ name = "sqlglot", specifier = ">=30.1.0,<31" }] [package.metadata.requires-dev] dev = [{ name = "pytest", specifier = ">=7.0" }]