diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 3ab40d673c..74fb8802c8 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -26,7 +26,7 @@ from sqlglot.jsonpath import ALL_JSON_PATH_PARTS, JSONPathTokenizer, parse as parse_json_path from sqlglot.parser import Parser from sqlglot.parsers.base import BaseParser -from sqlglot.time import TIMEZONES, format_time, subsecond_precision +from sqlglot.time import STRICT_TIME_FORMATS, TIMEZONES, format_time, subsecond_precision from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import new_trie from sqlglot.typing import EXPRESSION_METADATA @@ -134,6 +134,16 @@ class NormalizationStrategy(str, AutoName): """Always case-insensitive (uppercase), regardless of quotes.""" +def _with_strict_time_fallback(inverse_mapping: dict[str, str]) -> dict[str, str]: + # Dialects that define a "strict" format (e.g. Spark) keep their own mapping; + # everyone else degrades it to the lax counterpart's mapping, so the internal + # token never leaks into generated SQL. + for strict_format, lax_format in STRICT_TIME_FORMATS.items(): + inverse_mapping.setdefault(strict_format, inverse_mapping.get(lax_format, lax_format)) + + return inverse_mapping + + class _Dialect(type): _classes: dict[str, Type[Dialect]] = {} @@ -232,16 +242,21 @@ def __new__(cls, clsname, bases, attrs): cls._classes[enum.value if enum is not None else clsname.lower()] = klass klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) + klass.STRICT_TIME_TRIE = new_trie(klass.STRICT_TIME_MAPPING) + klass.LENIENT_INVERSE_TIME_TRIE = new_trie(klass.LENIENT_INVERSE_TIME_MAPPING) klass.FORMAT_TRIE = ( new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE ) # Merge class-defined INVERSE_TIME_MAPPING with auto-generated mappings # This allows dialects to define custom inverse mappings for roundtrip correctness - klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} | ( - klass.__dict__.get("INVERSE_TIME_MAPPING") or {} + klass.INVERSE_TIME_MAPPING = _with_strict_time_fallback( + {v: k for k, v in klass.TIME_MAPPING.items()} + | (klass.__dict__.get("INVERSE_TIME_MAPPING") or {}) ) klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) - klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()} + klass.INVERSE_FORMAT_MAPPING = _with_strict_time_fallback( + {v: k for k, v in klass.FORMAT_MAPPING.items()} + ) klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING) klass.INVERSE_CREATABLE_KIND_MAPPING = { @@ -412,6 +427,20 @@ class Dialect(metaclass=_Dialect): TIME_MAPPING: dict[str, str] = {} """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" + STRICT_TIME_MAPPING: dict[str, str] = {} + """ + Variant of `TIME_MAPPING` used when *parsing* a string with a format (e.g. `StrToTime`). + Lets dialects with strict parsing (e.g. Spark 3+'s zero-padded `MM`/`dd`) map those to a + distinct canonical format, preserving the roundtrip. Empty means `TIME_MAPPING` is used. + """ + + LENIENT_INVERSE_TIME_MAPPING: dict[str, str] = {} + """ + Inverse mapping used when *generating* a parse format (e.g. `StrToTime`) for dialects that + parse leniently (e.g. Spark). Maps the canonical specifiers to their lenient single-letter + forms, and the strict tokens back to the padded forms. + """ + # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Exprs-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE FORMAT_MAPPING: dict[str, str] = {} @@ -770,10 +799,12 @@ class Dialect(metaclass=_Dialect): # A trie of the time_mapping keys TIME_TRIE: dict = {} + STRICT_TIME_TRIE: dict = {} FORMAT_TRIE: dict = {} INVERSE_TIME_MAPPING: dict[str, str] = {} INVERSE_TIME_TRIE: dict = {} + LENIENT_INVERSE_TIME_TRIE: dict = {} INVERSE_FORMAT_MAPPING: dict[str, str] = {} INVERSE_FORMAT_TRIE: dict = {} @@ -966,16 +997,23 @@ def get_or_raise(cls, dialect: DialectType) -> Dialect: raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") @classmethod - def format_time(cls, expression: str | exp.Expr | None) -> exp.Expr | None: + def format_time( + cls, expression: str | exp.Expr | None, strict: bool = False + ) -> exp.Expr | None: """Converts a time format in this dialect to its equivalent Python `strftime` format.""" + if strict and cls.STRICT_TIME_MAPPING: + mapping, trie = cls.STRICT_TIME_MAPPING, cls.STRICT_TIME_TRIE + else: + mapping, trie = cls.TIME_MAPPING, cls.TIME_TRIE + if isinstance(expression, str): return exp.Literal.string( # the time formats are quoted - format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) + format_time(expression[1:-1], mapping, trie) ) if expression and expression.is_string: - return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) + return exp.Literal.string(format_time(expression.this, mapping, trie)) return expression @@ -1544,6 +1582,13 @@ def months_between_sql(self: Generator, expression: exp.MonthsBetween) -> str: return self.sql(result) +# Expressions that parse a string with a format (vs. formatting one, like TimeToStr). +# Dialects with strict parsing semantics (STRICT_TIME_MAPPING) use it for these on the +# parser side, and the corresponding generator (e.g. SparkGenerator.format_time) reuses +# this same set to emit the lenient inverse, which is what preserves the roundtrip. +STRICT_PARSE_TIME_EXPRESSIONS = (exp.StrToTime, exp.StrToDate, exp.TsOrDsToDate) + + def build_formatted_time( exp_class: Type[E], dialect_override: str | None = None, default: bool | str | None = None ) -> t.Callable[[BuilderArgs, Dialect], E]: @@ -1569,7 +1614,10 @@ def _builder(args: BuilderArgs, dialect: Dialect) -> E: if not fmt: fmt = target_dialect.TIME_FORMAT if default is True else default or None - return exp_class(this=seq_get(args, 0), format=target_dialect.format_time(fmt)) + strict = exp_class in STRICT_PARSE_TIME_EXPRESSIONS + return exp_class( + this=seq_get(args, 0), format=target_dialect.format_time(fmt, strict=strict) + ) return _builder diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 0cc9ef7d07..0c591c650b 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -16,6 +16,23 @@ class Spark(Spark2): ARRAY_FUNCS_PROPAGATES_NULLS = True EXPRESSION_METADATA = EXPRESSION_METADATA.copy() + # Spark 3+ parses MM/dd strictly (single-digit months/days don't parse), unlike the + # lax %m/%d other dialects produce. When *parsing* (StrToTime/StrToDate/...), MM/dd + # map to a distinct canonical token so the strict roundtrip is preserved; formatting + # keeps the regular padded %m/%d -> MM/dd (TIME_MAPPING is unchanged). + STRICT_TIME_MAPPING = { + **Spark2.TIME_MAPPING, + "MM": "%mstrict", + "dd": "%dstrict", + } + # Generating a parse format is lenient: %m/%d -> M/d (matching strptime), while the + # strict tokens map back to MM/dd. + LENIENT_INVERSE_TIME_MAPPING = { + **{v: k for k, v in STRICT_TIME_MAPPING.items()}, + "%m": "M", + "%d": "d", + } + class Tokenizer(Spark2.Tokenizer): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 80deeedffc..128c268591 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -11,7 +11,7 @@ from sqlglot.expressions import apply_index_offset from sqlglot.helper import csv, name_sequence, seq_get from sqlglot.jsonpath import ALL_JSON_PATH_PARTS, JSON_PATH_PART_TRANSFORMS -from sqlglot.time import format_time +from sqlglot.time import STRICT_TIME_FORMATS, STRICT_TIME_TRIE, format_time from sqlglot.tokens import TokenType if t.TYPE_CHECKING: @@ -4052,7 +4052,15 @@ def cast_sql(self, expression: exp.Cast, safe_prefix: str | None = None) -> str: # Base implementation that excludes safe, zone, and target_type metadata args def strtotime_sql(self, expression: exp.StrToTime) -> str: - return self.func("STR_TO_TIME", expression.this, expression.args.get("format")) + # STR_TO_TIME is sqlglot's canonical form, so the format must stay canonical + # strftime - we only strip the internal "strict" tokens (e.g. Spark's %mstrict) + # rather than routing through self.format_time(), which would also rewrite every + # other specifier into the dialect's INVERSE_TIME_MAPPING. + return self.func( + "STR_TO_TIME", + expression.this, + self.format_time(expression, STRICT_TIME_FORMATS, STRICT_TIME_TRIE), + ) def currentdate_sql(self, expression: exp.CurrentDate) -> str: zone = self.sql(expression, "this") diff --git a/sqlglot/generators/spark.py b/sqlglot/generators/spark.py index 206eced7b9..52b4a5e5ca 100644 --- a/sqlglot/generators/spark.py +++ b/sqlglot/generators/spark.py @@ -1,9 +1,9 @@ from __future__ import annotations - from sqlglot import exp from sqlglot import generator from sqlglot.dialects.dialect import ( + STRICT_PARSE_TIME_EXPRESSIONS, array_append_sql, rename_func, unit_to_var, @@ -89,6 +89,21 @@ class SparkGenerator(Spark2Generator): exp.DType.SMALLMONEY: ((6, 4), ()), } + def format_time( + self, + expression: exp.Expr, + inverse_time_mapping: dict[str, str] | None = None, + inverse_time_trie: dict | None = None, + ) -> str | None: + # Spark 3+ parses these leniently, so emit M/d (not the padded MM/dd used for + # formatting) for the canonical %m/%d. The expression set is shared with the parser + # (STRICT_PARSE_TIME_EXPRESSIONS), which is what guarantees the strict roundtrip. + if isinstance(expression, STRICT_PARSE_TIME_EXPRESSIONS): + inverse_time_mapping = inverse_time_mapping or self.dialect.LENIENT_INVERSE_TIME_MAPPING + inverse_time_trie = inverse_time_trie or self.dialect.LENIENT_INVERSE_TIME_TRIE + + return super().format_time(expression, inverse_time_mapping, inverse_time_trie) + TRANSFORMS = { k: v for k, v in { diff --git a/sqlglot/time.py b/sqlglot/time.py index 520734ff16..3c1da05759 100644 --- a/sqlglot/time.py +++ b/sqlglot/time.py @@ -6,6 +6,13 @@ # https://docs.python.org/3/library/time.html#time.strftime from sqlglot.trie import TrieResult, in_trie, new_trie +# "Strict" canonical time formats round-trip in dialects that define them (e.g. +# Spark 3+'s zero-padded MM/dd, which don't parse single-digit values) and degrade +# to their lax counterpart elsewhere. These are sqlglot-internal tokens, not valid +# strftime directives, so they must be normalized away when emitting generic SQL. +STRICT_TIME_FORMATS = {"%mstrict": "%m", "%dstrict": "%d"} +STRICT_TIME_TRIE = new_trie(STRICT_TIME_FORMATS) + def format_time( string: str, mapping: dict[str, str], trie: dict[t.Any, t.Any] | None = None diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 6a0abcae06..90b5e80fa7 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -763,7 +763,7 @@ def test_time(self): "presto": "DATE_PARSE(x, '%Y-%m-%dT%T')", "drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')", "redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH24:MI:SS')", - "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')", + "spark": "TO_TIMESTAMP(x, 'yyyy-M-dTHH:mm:ss')", }, ) self.validate_all( @@ -776,7 +776,7 @@ def test_time(self): "postgres": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')", "redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')", - "spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')", + "spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-M-d')", }, ) self.validate_all( @@ -1219,7 +1219,7 @@ def test_time(self): "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%T')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)", "presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%T') AS DATE)", - "spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')", + "spark": "TO_DATE(x, 'yyyy-M-dTHH:mm:ss')", "doris": "STR_TO_DATE(x, '%Y-%m-%dT%T')", }, ) @@ -1231,7 +1231,7 @@ def test_time(self): "starrocks": "STR_TO_DATE(x, '%Y-%m-%d')", "hive": "CAST(x AS DATE)", "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", - "spark": "TO_DATE(x)", + "spark": "TO_DATE(x, 'yyyy-M-d')", "doris": "STR_TO_DATE(x, '%Y-%m-%d')", }, ) diff --git a/tests/dialects/test_exasol.py b/tests/dialects/test_exasol.py index 1404e1a3ca..aaf6db0b85 100644 --- a/tests/dialects/test_exasol.py +++ b/tests/dialects/test_exasol.py @@ -480,9 +480,9 @@ def test_datetime_functions(self): "duckdb": "CAST(x AS DATE)", "hive": "TO_DATE(x)", "presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)", - "spark": "TO_DATE(x)", + "spark": "TO_DATE(x, 'yyyy-M-d')", "snowflake": "TO_DATE(x, 'yyyy-mm-DD')", - "databricks": "TO_DATE(x)", + "databricks": "TO_DATE(x, 'yyyy-M-d')", }, ) self.validate_all( diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 3aa75f2e82..152c569103 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -306,7 +306,7 @@ def test_time(self): "duckdb": "STRPTIME(x, '%Y-%m-%d %H:%M:%S')", "presto": "DATE_PARSE(x, '%Y-%m-%d %T')", "hive": "CAST(x AS TIMESTAMP)", - "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss')", + "spark": "TO_TIMESTAMP(x, 'yyyy-M-d HH:mm:ss')", }, ) self.validate_all( @@ -315,7 +315,7 @@ def test_time(self): "duckdb": "STRPTIME(x, '%Y-%m-%d')", "presto": "DATE_PARSE(x, '%Y-%m-%d')", "hive": "CAST(x AS TIMESTAMP)", - "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd')", + "spark": "TO_TIMESTAMP(x, 'yyyy-M-d')", }, ) self.validate_all( @@ -330,7 +330,7 @@ def test_time(self): "duckdb": "STRPTIME(SUBSTRING(x, 1, 10), '%Y-%m-%d')", "presto": "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')", "hive": "CAST(SUBSTRING(x, 1, 10) AS TIMESTAMP)", - "spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-MM-dd')", + "spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-M-d')", }, ) self.validate_all( @@ -339,7 +339,7 @@ def test_time(self): "duckdb": "STRPTIME(SUBSTRING(x, 1, 10), '%Y-%m-%d')", "presto": "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')", "hive": "CAST(SUBSTRING(x, 1, 10) AS TIMESTAMP)", - "spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-MM-dd')", + "spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-M-d')", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 8eccf066cf..c33b832049 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -1839,7 +1839,7 @@ def test_snowflake(self): "bigquery": "SELECT PARSE_TIMESTAMP('%d-%m-%Y %I:%M:%S', col) FROM t", "duckdb": "SELECT STRPTIME(col, '%d-%m-%Y %I:%M:%S') FROM t", "snowflake": "SELECT TO_TIMESTAMP(col, 'DD-mm-yyyy hh12:mi:ss') FROM t", - "spark": "SELECT TO_TIMESTAMP(col, 'dd-MM-yyyy hh:mm:ss') FROM t", + "spark": "SELECT TO_TIMESTAMP(col, 'd-M-yyyy hh:mm:ss') FROM t", }, ) self.validate_all( @@ -1904,7 +1904,7 @@ def test_snowflake(self): write={ "bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %T', '04/05/2013 01:02:03')", "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/DD/yyyy hh24:mi:ss')", - "spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')", + "spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'M/d/yyyy HH:mm:ss')", }, ) self.validate_all( diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 8d0b890211..e029768f82 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -659,14 +659,39 @@ def test_spark(self): }, ) self.validate_all( - "SELECT TO_TIMESTAMP('2016-12-31', 'yyyy-MM-dd')", + "SELECT TO_TIMESTAMP('2016-1-1', 'yyyy-M-d')", read={ - "duckdb": "SELECT STRPTIME('2016-12-31', '%Y-%m-%d')", + "duckdb": "SELECT STRPTIME('2016-1-1', '%Y-%m-%d')", }, + write={ + "": "SELECT STR_TO_TIME('2016-1-1', '%Y-%-m-%-d')", + "duckdb": "SELECT STRPTIME('2016-1-1', '%Y-%-m-%-d')", + "spark": "SELECT TO_TIMESTAMP('2016-1-1', 'yyyy-M-d')", + }, + ) + # Spark 3+ parses MM/dd strictly, so the strict parse format roundtrips, but + # widens to the lax %m/%d for dialects that parse leniently (e.g. duckdb). + self.validate_all( + "SELECT TO_TIMESTAMP('2016-12-31', 'yyyy-MM-dd')", write={ "": "SELECT STR_TO_TIME('2016-12-31', '%Y-%m-%d')", "duckdb": "SELECT STRPTIME('2016-12-31', '%Y-%m-%d')", "spark": "SELECT TO_TIMESTAMP('2016-12-31', 'yyyy-MM-dd')", + "databricks": "SELECT TO_TIMESTAMP('2016-12-31', 'yyyy-MM-dd')", + }, + ) + # Formatting keeps zero-padded MM/dd, unlike the lenient parsing above. + self.validate_identity("SELECT DATE_FORMAT(x, 'yyyy-MM-dd')") + # The strict canonical token must degrade in BigQuery's FORMAT clause too, + # not just INVERSE_TIME_MAPPING (it previously leaked as 'MMstrict/DDstrict'). + self.validate_all( + "SELECT TO_DATE(x, 'MM/dd/yyyy')", + write={ + "": "SELECT CAST(STR_TO_TIME(x, '%m/%d/%Y') AS DATE)", + "duckdb": "SELECT CAST(CAST(TRY_STRPTIME(x, '%m/%d/%Y') AS TIMESTAMP) AS DATE)", + "bigquery": "SELECT CAST(SAFE_CAST(x AS TIMESTAMP FORMAT 'MM/DD/YYYY') AS DATE)", + "spark": "SELECT TO_DATE(x, 'MM/dd/yyyy')", + "databricks": "SELECT TO_DATE(x, 'MM/dd/yyyy')", }, ) self.validate_all( diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index cf0e32478c..2cf8b1d0d8 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -233,9 +233,9 @@ def test_cast(self): write={ "teradata": "CAST('1992-01' AS DATE FORMAT 'YYYY-DD')", "bigquery": "PARSE_DATE('%Y-%d', '1992-01')", - "databricks": "TO_DATE('1992-01', 'yyyy-dd')", + "databricks": "TO_DATE('1992-01', 'yyyy-d')", "mysql": "STR_TO_DATE('1992-01', '%Y-%d')", - "spark": "TO_DATE('1992-01', 'yyyy-dd')", + "spark": "TO_DATE('1992-01', 'yyyy-d')", "": "STR_TO_DATE('1992-01', '%Y-%d')", }, ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index d6040dac30..ee0ad38f87 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1749,21 +1749,21 @@ def test_convert(self): self.validate_all( "CONVERT(DATE, x, 121)", write={ - "spark": "TO_DATE(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + "spark": "TO_DATE(x, 'yyyy-M-d HH:mm:ss.SSSSSS')", "tsql": "CONVERT(DATE, x, 121)", }, ) self.validate_all( "CONVERT(DATETIME, x, 121)", write={ - "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + "spark": "TO_TIMESTAMP(x, 'yyyy-M-d HH:mm:ss.SSSSSS')", "tsql": "CONVERT(DATETIME, x, 121)", }, ) self.validate_all( "CONVERT(DATETIME2, x, 121)", write={ - "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + "spark": "TO_TIMESTAMP(x, 'yyyy-M-d HH:mm:ss.SSSSSS')", "tsql": "CONVERT(DATETIME2, x, 121)", }, )