diff --git a/sqlglot/typing/bigquery.py b/sqlglot/typing/bigquery.py index 4026d51729..abf77581e6 100644 --- a/sqlglot/typing/bigquery.py +++ b/sqlglot/typing/bigquery.py @@ -9,6 +9,36 @@ from sqlglot.optimizer.annotate_types import TypeAnnotator +# DATE_ADD / DATE_SUB / *_TRUNC return the type of their first argument. BigQuery +# implicitly casts a string literal first arg to the function's own temporal type, +# so map each to that type (e.g. DATE_ADD('2020-01-01', ...) -> DATE, +# TIMESTAMP_TRUNC('...') -> TIMESTAMP). +_DATE_FUNC_LITERAL_TYPE: dict[type[exp.Expr], exp.DType] = { + exp.DateAdd: exp.DType.DATE, + exp.DateSub: exp.DType.DATE, + exp.DateTrunc: exp.DType.DATE, + exp.DatetimeTrunc: exp.DType.DATETIME, + exp.TimestampTrunc: exp.DType.TIMESTAMPTZ, +} + + +def _annotate_date_func(self: TypeAnnotator, expression: exp.Expr) -> exp.Expr: + """Annotate DATE_ADD / DATE_SUB / *_TRUNC, which return their first arg's type. + + A typed first argument keeps its exact type (e.g. DATE_ADD(DATETIME, ...) -> + DATETIME). For a string literal first argument, BigQuery implicitly casts it to + the function's own temporal type, so the result is that type (e.g. + DATE_ADD('2020-01-01', INTERVAL 1 DAY) -> DATE). + """ + this = expression.this + + # BigQuery rejects expressions like DATE_ADD(c, ...); it requires the first argument to be a literal + if isinstance(this, exp.Literal) and this.is_string: + return self._set_type(expression, _DATE_FUNC_LITERAL_TYPE[type(expression)]) + + return self._annotate_by_args(expression, "this") + + def _annotate_math_functions(self: TypeAnnotator, expression: exp.Expr) -> exp.Expr: """ Many BigQuery math functions such as CEIL, FLOOR etc follow this return type convention: @@ -175,9 +205,6 @@ def _annotate_array(self: TypeAnnotator, expression: exp.Array) -> exp.Array: for expr_type in { exp.ArgMax, exp.ArgMin, - exp.DateAdd, - exp.DateTrunc, - exp.DatetimeTrunc, exp.GroupConcat, exp.IgnoreNulls, exp.JSONExtract, @@ -197,12 +224,15 @@ def _annotate_array(self: TypeAnnotator, expression: exp.Array) -> exp.Array: exp.SafeNegate, exp.Sign, exp.Substring, - exp.TimestampTrunc, exp.Translate, exp.Trim, exp.Upper, } }, + **{ + expr_type: {"annotator": lambda self, e: _annotate_date_func(self, e)} + for expr_type in _DATE_FUNC_LITERAL_TYPE + }, **{ expr_type: {"returns": exp.DType.BIGINT} for expr_type in { diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql index f8a9b9482d..f54ef2397a 100644 --- a/tests/fixtures/optimizer/annotate_functions.sql +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -2448,6 +2448,74 @@ DATETIME; DATE_ADD(DATETIME '2008-12-25 15:30:00', INTERVAL 30 MINUTE); DATETIME; +# dialect: bigquery +DATE_ADD('2008-12-25', INTERVAL 5 DAY); +DATE; + +# dialect: bigquery +DATE_TRUNC('2008-12-25', MONTH); +DATE; + +# dialect: bigquery +DATETIME_TRUNC('2008-12-25', DAY); +DATETIME; + +# dialect: bigquery +DATETIME_TRUNC('2008-12-25 15:30:00', DAY); +DATETIME; + +# dialect: bigquery +TIMESTAMP_TRUNC('2008-12-25 15:30:00', DAY); +TIMESTAMP; + +# dialect: bigquery +TIMESTAMP_TRUNC('2008-12-25', DAY); +TIMESTAMP; + +# dialect: bigquery +DATE_SUB('2008-12-25', INTERVAL 1 MONTH); +DATE; + +# dialect: bigquery +DATE_SUB(DATE '2008-12-25', INTERVAL 1 MONTH); +DATE; + +# dialect: bigquery +DATE_SUB(DATETIME '2008-12-25 15:30:00', INTERVAL 1 DAY); +DATETIME; + +# dialect: bigquery +DATE_SUB(TIMESTAMP '2008-12-25 15:30:00', INTERVAL 1 HOUR); +TIMESTAMP; + +# dialect: bigquery +DATETIME_ADD('2008-12-25 15:30:00', INTERVAL 1 DAY); +DATETIME; + +# dialect: bigquery +DATETIME_SUB('2008-12-25 15:30:00', INTERVAL 1 DAY); +DATETIME; + +# dialect: bigquery +TIMESTAMP_ADD('2008-12-25 15:30:00', INTERVAL 1 HOUR); +TIMESTAMP; + +# dialect: bigquery +TIMESTAMP_SUB('2008-12-25 15:30:00', INTERVAL 1 HOUR); +TIMESTAMP; + +# dialect: bigquery +TIME_ADD('08:50:48', INTERVAL 1 HOUR); +TIME; + +# dialect: bigquery +TIME_SUB('08:50:48', INTERVAL 1 HOUR); +TIME; + +# dialect: bigquery +TIME_TRUNC('08:50:48', HOUR); +TIME; + # dialect: bigquery UNIX_DATE(tbl.date_col); BIGINT;