|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import typing as t |
| 4 | + |
| 5 | +from sqlglot import exp |
| 6 | +from sqlglot.dialects.hive import Hive |
| 7 | +from sqlglot.dialects.dialect import rename_func, unit_to_str |
| 8 | +from sqlglot.transforms import ( |
| 9 | + move_schema_columns_to_partitioned_by, |
| 10 | + preprocess, |
| 11 | + remove_unique_constraints, |
| 12 | + ctas_with_tmp_tables_to_create_tmp_view, |
| 13 | +) |
| 14 | + |
| 15 | + |
| 16 | +_AUTO_PARTITION_TYPES = (exp.DateTrunc, exp.TimestampTrunc, exp.DatetimeTrunc, exp.Alias) |
| 17 | + |
| 18 | + |
def _move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
    """Variant of Hive's transform that leaves AUTO PARTITIONED BY creates untouched.

    When the PartitionedByProperty wraps a generated expression (one of
    ``_AUTO_PARTITION_TYPES``), the statement is MaxCompute's AUTO PARTITIONED BY
    form and the schema columns must stay in place. Otherwise defer to the
    stock Hive transform, which relocates partition columns out of the schema.
    """
    assert isinstance(expression, exp.Create)
    partition_prop = expression.find(exp.PartitionedByProperty)
    is_auto_partition = partition_prop is not None and isinstance(
        partition_prop.this, _AUTO_PARTITION_TYPES
    )
    if is_auto_partition:
        return expression
    return move_schema_columns_to_partitioned_by(expression)
| 26 | + |
| 27 | + |
class MaxComputeGenerator(Hive.Generator):
    """SQL generator for Alibaba MaxCompute (ODPS), derived from Hive.

    Overrides Hive where MaxCompute's SQL surface diverges: character types
    collapse to STRING, date arithmetic uses DATEADD/DATEDIFF/DATETRUNC,
    generated-expression partitioning renders as AUTO PARTITIONED BY, and a
    handful of functions carry different names or argument orders.
    """

    # DATETIME is a first-class type in MaxCompute; all character types
    # (VARCHAR/CHAR/TEXT) collapse to STRING.
    TYPE_MAPPING = {
        **Hive.Generator.TYPE_MAPPING,
        exp.DType.DATETIME: "DATETIME",
        exp.DType.VARCHAR: "STRING",
        exp.DType.CHAR: "STRING",
        exp.DType.TEXT: "STRING",
    }

    TRANSFORMS = {
        **Hive.Generator.TRANSFORMS,
        # CREATE statements go through the Hive-style preprocess pipeline, but
        # with the partition-column move replaced by the AUTO-PARTITION-aware
        # variant defined above.
        exp.Create: preprocess(
            [
                remove_unique_constraints,
                ctas_with_tmp_tables_to_create_tmp_view,
                _move_schema_columns_to_partitioned_by,
            ]
        ),
        exp.PartitionedByProperty: lambda self, e: self._partitioned_by_sql(e),
        # Date/time transforms — all add/sub variants funnel into DATEADD,
        # all truncations into DATETRUNC.
        exp.TsOrDsAdd: lambda self, e: self._dateadd_sql(e),
        exp.DateAdd: lambda self, e: self._dateadd_sql(e),
        exp.TimestampAdd: lambda self, e: self._dateadd_sql(e),
        exp.DatetimeAdd: lambda self, e: self._dateadd_sql(e),
        exp.DateSub: lambda self, e: self._dateadd_sql(e),
        exp.DateDiff: lambda self, e: self._datediff_sql(e),
        exp.DateTrunc: lambda self, e: self._datetrunc_sql(e),
        exp.TimestampTrunc: lambda self, e: self._datetrunc_sql(e),
        exp.DatetimeTrunc: lambda self, e: self._datetrunc_sql(e),
        exp.CurrentTimestamp: lambda self, e: "GETDATE()",
        exp.CurrentDatetime: lambda self, e: "NOW()",
        # String transforms
        exp.Lower: rename_func("TOLOWER"),
        exp.Upper: rename_func("TOUPPER"),
        # JSON / misc
        exp.ParseJSON: rename_func("FROM_JSON"),
        exp.CurrentUser: lambda self, e: "GET_USER_ID()",
        exp.UnixMillis: rename_func("TO_MILLIS"),
        # Aggregate
        exp.ApproxDistinct: rename_func("APPROX_DISTINCT"),
        exp.ArgMax: lambda self, e: self.func("ARG_MAX", e.this, e.expression),
        exp.ArgMin: lambda self, e: self.func("ARG_MIN", e.this, e.expression),
        # Statistical aggregate fixes (Hive emits wrong names)
        exp.Space: rename_func("SPACE"),
        exp.VariancePop: rename_func("VAR_POP"),
        exp.Variance: rename_func("VAR_SAMP"),
        # String position: MaxCompute uses INSTR(str, substr), not LOCATE(substr, str)
        exp.StrPosition: lambda self, e: self.func("INSTR", e.this, e.args.get("substr")),
        # TO_DATE(str, fmt) returns DATETIME — modeled as StrToTime; emit TO_DATE in MaxCompute
        exp.StrToTime: lambda self, e: self.func("TO_DATE", e.this, e.args.get("format")),
    }

    def _dateadd_sql(
        self,
        expression: exp.TsOrDsAdd | exp.DateAdd | exp.DateSub | exp.TimestampAdd | exp.DatetimeAdd,
    ) -> str:
        """Render every add/sub variant as DATEADD(date, delta, unit).

        Falls back to 'DAY' when the source node carries no unit. DateSub is
        expressed by negating the delta rather than emitting a separate
        subtraction function.
        """
        unit = unit_to_str(expression) if expression.args.get("unit") else "'DAY'"
        delta = expression.expression
        if isinstance(expression, exp.DateSub):
            # DateSub magnitude is positive; negate it so DATEADD subtracts.
            # Some dialects (e.g. BigQuery) store the magnitude as a string
            # literal — normalize to a number first so we emit -3 not -'3'.
            if isinstance(delta, exp.Literal) and delta.is_string:
                delta = exp.Literal.number(delta.this)
            delta = exp.Neg(this=delta)
        return self.func("DATEADD", expression.this, delta, unit)

    def _datediff_sql(self, expression: exp.DateDiff) -> str:
        """Render DATEDIFF(a, b[, unit]).

        NOTE(review): when no unit is present this passes None as the third
        argument — relies on self.func dropping empty args; confirm against
        the sqlglot Generator.func contract.
        """
        unit = unit_to_str(expression) if expression.args.get("unit") else None
        return self.func("DATEDIFF", expression.this, expression.expression, unit)

    def _datetrunc_sql(
        self, expression: exp.DateTrunc | exp.TimestampTrunc | exp.DatetimeTrunc
    ) -> str:
        """Render all truncation variants as DATETRUNC(dt, 'unit')."""
        unit = expression.args.get("unit")
        # WeekStart units must be emitted as 'week(day)' string literals.
        # unit_to_str returns the raw node name which would produce DATETRUNC(dt, WEEK(MONDAY))
        # — invalid MaxCompute SQL. Reconstruct the canonical 'week(day)' form instead.
        if isinstance(unit, exp.WeekStart):
            day = unit.this.name.lower() if unit.args.get("this") else "monday"
            unit_sql = exp.Literal.string(f"week({day})")
        else:
            unit_sql = unit_to_str(expression)
        return self.func("DATETRUNC", expression.this, unit_sql)

    def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
        """Render GROUP_CONCAT as WM_CONCAT(separator, expr), defaulting to ','.

        Note the separator comes FIRST in MaxCompute's WM_CONCAT signature.
        """
        sep = expression.args.get("separator") or exp.Literal.string(",")
        return self.func("WM_CONCAT", sep, expression.this)

    def tochar_sql(self, expression: exp.ToChar) -> str:
        """Render TO_CHAR(value[, format]); format is omitted when absent."""
        return self.func("TO_CHAR", expression.this, expression.args.get("format"))

    def substring_sql(self, expression: exp.Substring) -> str:
        """Render substring as SUBSTR(str[, start[, length]])."""
        return self.func("SUBSTR", expression.this, expression.args.get("start"), expression.args.get("length"))

    def extract_sql(self, expression: exp.Extract) -> str:
        """Render EXTRACT(unit FROM x) as DATEPART(x, 'unit').

        NOTE(review): the unit's .name is emitted verbatim inside a string
        literal — confirm MaxCompute accepts the casing sqlglot stores here.
        """
        unit = expression.this
        return self.func("DATEPART", expression.expression, exp.Literal.string(unit.name))

    def mod_sql(self, expression: exp.Mod) -> str:
        """Collapse the canonical weekday expansion back into WEEKDAY(x)."""
        # Reverse the WEEKDAY parser transform: (DAYOFWEEK(x) + 5) % 7 → WEEKDAY(x)
        rhs = expression.expression
        lhs = expression.this
        if (
            isinstance(rhs, exp.Literal) and rhs.this == "7"
            and isinstance(lhs, exp.Paren)
            and isinstance(lhs.this, exp.Add)
            and isinstance(lhs.this.this, exp.DayOfWeek)
            and isinstance(lhs.this.expression, exp.Literal)
            and lhs.this.expression.this == "5"
        ):
            return self.func("WEEKDAY", lhs.this.this.this)
        return super().mod_sql(expression)

    def _partitioned_by_sql(self, expression: exp.PartitionedByProperty) -> str:
        """Render PARTITIONED BY, using the AUTO PARTITIONED BY form for
        generated-expression partitions (see _AUTO_PARTITION_TYPES).

        The trunc expression is rewritten to TRUNC_TIME(col, 'unit').
        NOTE(review): the alias is interpolated raw (f" AS {inner.alias}"),
        not routed through self.sql — confirm identifiers never need quoting
        here.
        """
        inner = expression.this
        if isinstance(inner, _AUTO_PARTITION_TYPES):
            alias_sql = ""
            if isinstance(inner, exp.Alias):
                alias_sql = f" AS {inner.alias}"
                inner = inner.this
            unit = inner.args.get("unit")
            unit_str = unit.name.lower() if unit else ""
            trunc_sql = self.func("TRUNC_TIME", inner.this, exp.Literal.string(unit_str))
            return f"AUTO PARTITIONED BY ({trunc_sql}{alias_sql})"
        return f"PARTITIONED BY {self.sql(expression, 'this')}"

    def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str:
        """Render CLUSTERED BY, prefixing RANGE when the node carries a
        'range' flag (MaxCompute range clustering)."""
        sql = super().clusteredbyproperty_sql(expression)
        return f"RANGE {sql}" if expression.args.get("range") else sql

    def datatype_sql(self, expression: exp.DataType) -> str:
        """Render data types; strip length parameters from character types."""
        # VARCHAR and CHAR map to STRING in MaxCompute, with no length parameters
        if expression.this in (exp.DType.VARCHAR, exp.DType.CHAR):
            return self.TYPE_MAPPING.get(expression.this, super().datatype_sql(expression))
        return super().datatype_sql(expression)

    def properties_sql(self, expression: exp.Properties) -> str:
        """Split properties into bare `KEY value` clauses vs TBLPROPERTIES.

        NOTE(review): `p not in var_keyed` relies on sqlglot Expression
        equality semantics for membership — confirm that structural equality
        cannot drop a distinct-but-equal property here.
        """
        # Var-keyed exp.Property instances (e.g. LIFECYCLE 30) render as bare
        # KEY value after the schema. String-keyed ones stay in TBLPROPERTIES.
        var_keyed = [
            p
            for p in expression.expressions
            if isinstance(p, exp.Property) and isinstance(p.this, exp.Var)
        ]
        other = [p for p in expression.expressions if p not in var_keyed]

        # Re-parent the remainder so the base generator sees the original
        # placement context (e.g. relative to the schema).
        other_node = exp.Properties(expressions=other)
        other_node.parent = expression.parent
        base_sql = super().properties_sql(other_node) if other else ""

        bare_sql = " ".join(f"{p.name} {self.sql(p, 'value')}" for p in var_keyed)

        if base_sql and bare_sql:
            return f"{base_sql} {bare_sql}"
        return base_sql or bare_sql
0 commit comments