diff --git a/sqlglot/expressions/functions.py b/sqlglot/expressions/functions.py
index 2cfeb2bbcf..4f52069098 100644
--- a/sqlglot/expressions/functions.py
+++ b/sqlglot/expressions/functions.py
@@ -333,6 +333,23 @@ class MLForecast(Expression, Func):
     arg_types = {"this": True, "expression": False, "params_struct": False}
 
 
+class AIForecast(Expression, Func):
+    # `this` is the forecast input (required); the remaining args are optional
+    # and are filled in from the call's `name => value` arguments by the parser.
+    arg_types = {
+        "this": True,
+        "data_col": False,
+        "timestamp_col": False,
+        "model": False,
+        "id_cols": False,
+        "horizon": False,
+        "forecast_end_timestamp": False,
+        "confidence_level": False,
+        "output_historical_time_series": False,
+        "context_window": False,
+    }
+
+
 class MLTranslate(Expression, Func):
     arg_types = {"this": True, "expression": True, "params_struct": True}
 
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 709ca80e68..cc3aa4bff8 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -4668,6 +4668,27 @@ def mltranslate_sql(self, expression: exp.MLTranslate) -> str:
     def mlforecast_sql(self, expression: exp.MLForecast) -> str:
         return self._ml_sql(expression, "FORECAST")
 
+    def aiforecast_sql(self, expression: exp.AIForecast) -> str:
+        # Renders a bare FORECAST(...) call; the `AI.` qualifier is presumably
+        # emitted by the enclosing node -- NOTE(review): confirm.
+        this_sql = self.sql(expression, "this")
+        if isinstance(expression.this, exp.Table):
+            this_sql = f"TABLE {this_sql}"
+
+        return self.func(
+            "FORECAST",
+            this_sql,
+            expression.args.get("data_col"),
+            expression.args.get("timestamp_col"),
+            expression.args.get("model"),
+            expression.args.get("id_cols"),
+            expression.args.get("horizon"),
+            expression.args.get("forecast_end_timestamp"),
+            expression.args.get("confidence_level"),
+            expression.args.get("output_historical_time_series"),
+            expression.args.get("context_window"),
+        )
+
     def featuresattime_sql(self, expression: exp.FeaturesAtTime) -> str:
         this_sql = self.sql(expression, "this")
         if isinstance(expression.this, exp.Table):
diff --git a/sqlglot/parsers/bigquery.py b/sqlglot/parsers/bigquery.py
index f0b99ffda5..e25451342c 100644
--- a/sqlglot/parsers/bigquery.py
+++ b/sqlglot/parsers/bigquery.py
@@ -285,7 +285,7 @@ class BigQueryParser(parser.Parser):
         "GENERATE_EMBEDDING": lambda self: self._parse_ml(exp.GenerateEmbedding),
         "GENERATE_TEXT_EMBEDDING": lambda self: self._parse_ml(exp.GenerateEmbedding, is_text=True),
         "VECTOR_SEARCH": lambda self: self._parse_vector_search(),
-        "FORECAST": lambda self: self._parse_ml(exp.MLForecast),
+        "FORECAST": lambda self: self._parse_forecast(),
     }
 
     NO_PAREN_FUNCTIONS: t.ClassVar = {
@@ -603,6 +603,36 @@ def _parse_translate(self) -> exp.Translate | exp.MLTranslate:
 
         return exp.Translate.from_arg_list(self._parse_function_args())
 
+    def _parse_forecast(self) -> exp.AIForecast | exp.MLForecast:
+        """Parse a FORECAST call as either ML.FORECAST or AI.FORECAST.
+
+        AI.FORECAST is a TVF, where the first argument is either TABLE
+        or a parenthesized query statement, followed by named arguments.
+        """
+        # Check if this is ML.FORECAST by looking at previous tokens. The offset
+        # presumably skips `(`, `FORECAST` and `.` -- NOTE(review): confirm.
+        token = seq_get(self._tokens, self._index - 4)
+        if token and token.text.upper() == "ML":
+            return self._parse_ml(exp.MLForecast)
+
+        self._match(TokenType.TABLE)
+        this = self._parse_table()
+        if not this:
+            self.raise_error("Expected table or query statement")
+
+        expr = self.expression(exp.AIForecast(this=this))
+        while self._match(TokenType.COMMA):
+            arg = self._parse_lambda()
+            if isinstance(arg, exp.Kwarg):
+                # Store each `name => value` pair under its own name, e.g.
+                # `horizon => 30` becomes the node's `horizon` arg.
+                expr.set(arg.this.name, arg)
+            elif arg:
+                # A positional argument has no name to store it under.
+                self.raise_error("Expected a named argument (name => value)")
+
+        return expr
+
     def _parse_features_at_time(self) -> exp.FeaturesAtTime:
         self._match(TokenType.TABLE)
         this = self._parse_table()
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index af43a0bce8..ae93bbbb9e 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -2414,6 +2414,13 @@ def test_ml_functions(self):
             "SELECT * FROM ML.FORECAST(MODEL `mydataset.mymodel`, (SELECT * FROM mydataset.query_table), STRUCT())"
         )
 
+        self.validate_identity(
+            "SELECT * FROM AI.FORECAST(TABLE citibike_trips, data_col => 'num_trips', timestamp_col => 'date', horizon => 30)"
+        )
+        self.validate_identity(
+            "SELECT * FROM AI.FORECAST((SELECT * FROM citibike_trips), data_col => 'num_trips', timestamp_col => 'date', horizon => 30)"
+        )
+
         for name in ("GENERATE_EMBEDDING", "GENERATE_TEXT_EMBEDDING"):
             with self.subTest(f"Testing BigQuery's ML function {name}"):
                 ast = self.validate_identity(