diff --git a/sqlglot/expressions/functions.py b/sqlglot/expressions/functions.py index 2cfeb2bbcf..4f52069098 100644 --- a/sqlglot/expressions/functions.py +++ b/sqlglot/expressions/functions.py @@ -333,6 +333,21 @@ class MLForecast(Expression, Func): arg_types = {"this": True, "expression": False, "params_struct": False} +class AIForecast(Expression, Func): + arg_types = { + "this": True, + "data_col": False, + "timestamp_col": False, + "model": False, + "id_cols": False, + "horizon": False, + "forecast_end_timestamp": False, + "confidence_level": False, + "output_historical_time_series": False, + "context_window": False, + } + + class MLTranslate(Expression, Func): arg_types = {"this": True, "expression": True, "params_struct": True} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 709ca80e68..cc3aa4bff8 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -4668,6 +4668,25 @@ def mltranslate_sql(self, expression: exp.MLTranslate) -> str: def mlforecast_sql(self, expression: exp.MLForecast) -> str: return self._ml_sql(expression, "FORECAST") + def aiforecast_sql(self, expression: exp.AIForecast) -> str: + this_sql = self.sql(expression, "this") + if isinstance(expression.this, exp.Table): + this_sql = f"TABLE {this_sql}" + + return self.func( + "FORECAST", + this_sql, + expression.args.get("data_col"), + expression.args.get("timestamp_col"), + expression.args.get("model"), + expression.args.get("id_cols"), + expression.args.get("horizon"), + expression.args.get("forecast_end_timestamp"), + expression.args.get("confidence_level"), + expression.args.get("output_historical_time_series"), + expression.args.get("context_window"), + ) + def featuresattime_sql(self, expression: exp.FeaturesAtTime) -> str: this_sql = self.sql(expression, "this") if isinstance(expression.this, exp.Table): diff --git a/sqlglot/parsers/bigquery.py b/sqlglot/parsers/bigquery.py index f0b99ffda5..e25451342c 100644 --- a/sqlglot/parsers/bigquery.py +++ b/sqlglot/parsers/bigquery.py @@ -285,7 +285,7 @@ class BigQueryParser(parser.Parser): "GENERATE_EMBEDDING": lambda self: self._parse_ml(exp.GenerateEmbedding), "GENERATE_TEXT_EMBEDDING": lambda self: self._parse_ml(exp.GenerateEmbedding, is_text=True), "VECTOR_SEARCH": lambda self: self._parse_vector_search(), - "FORECAST": lambda self: self._parse_ml(exp.MLForecast), + "FORECAST": lambda self: self._parse_forecast(), } NO_PAREN_FUNCTIONS: t.ClassVar = { @@ -603,6 +603,31 @@ def _parse_translate(self) -> exp.Translate | exp.MLTranslate: return exp.Translate.from_arg_list(self._parse_function_args()) + def _parse_forecast(self) -> exp.AIForecast | exp.MLForecast: + # Check if this is ML.FORECAST by looking at previous tokens. + token = seq_get(self._tokens, self._index - 4) + if token and token.text.upper() == "ML": + return self._parse_ml(exp.MLForecast) + + # AI.FORECAST is a TVF, where the first argument is either TABLE