Skip to content

Commit eb9dbac

Browse files
feat(bigquery): support AI.FORECAST function (#7457)
Signed-off-by: Mridankan Mandal <xerontitan90@gmail.com> Co-authored-by: Mridankan Mandal <xerontitan90@gmail.com>
1 parent e8bbc80 commit eb9dbac

File tree

4 files changed

+67
-1
lines changed

4 files changed

+67
-1
lines changed

sqlglot/expressions/functions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,21 @@ class MLForecast(Expression, Func):
333333
arg_types = {"this": True, "expression": False, "params_struct": False}
334334

335335

336+
class AIForecast(Expression, Func):
    """Expression node for BigQuery's `AI.FORECAST` function.

    `this` is the only required argument; every other entry is optional and is
    stored under its own argument name.
    """

    # Key order is kept exactly as declared: `this` first, then the optional
    # named arguments.
    arg_types = {
        "this": True,
        **dict.fromkeys(
            (
                "data_col",
                "timestamp_col",
                "model",
                "id_cols",
                "horizon",
                "forecast_end_timestamp",
                "confidence_level",
                "output_historical_time_series",
                "context_window",
            ),
            False,
        ),
    }
349+
350+
336351
class MLTranslate(Expression, Func):
337352
arg_types = {"this": True, "expression": True, "params_struct": True}
338353

sqlglot/generator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4668,6 +4668,25 @@ def mltranslate_sql(self, expression: exp.MLTranslate) -> str:
46684668
def mlforecast_sql(self, expression: exp.MLForecast) -> str:
    """Render an `exp.MLForecast` node by delegating to `_ml_sql` with the name FORECAST."""
    function_name = "FORECAST"
    return self._ml_sql(expression, function_name)
46704670

4671+
def aiforecast_sql(self, expression: exp.AIForecast) -> str:
    """Render an `exp.AIForecast` node as a `FORECAST(...)` call.

    The first argument is the SQL of `this`, prefixed with `TABLE ` when `this`
    is an `exp.Table`. The remaining arguments are looked up by name and passed
    to `self.func` in a fixed order; a missing one is passed as `None`.
    """
    source_sql = self.sql(expression, "this")
    if isinstance(expression.this, exp.Table):
        source_sql = f"TABLE {source_sql}"

    # Emission order of the optional arguments; it is fixed here and does not
    # depend on the order in which they were set on the node.
    ordered_keys = (
        "data_col",
        "timestamp_col",
        "model",
        "id_cols",
        "horizon",
        "forecast_end_timestamp",
        "confidence_level",
        "output_historical_time_series",
        "context_window",
    )
    optional_args = [expression.args.get(key) for key in ordered_keys]

    return self.func("FORECAST", source_sql, *optional_args)
4689+
46714690
def featuresattime_sql(self, expression: exp.FeaturesAtTime) -> str:
46724691
this_sql = self.sql(expression, "this")
46734692
if isinstance(expression.this, exp.Table):

sqlglot/parsers/bigquery.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ class BigQueryParser(parser.Parser):
285285
"GENERATE_EMBEDDING": lambda self: self._parse_ml(exp.GenerateEmbedding),
286286
"GENERATE_TEXT_EMBEDDING": lambda self: self._parse_ml(exp.GenerateEmbedding, is_text=True),
287287
"VECTOR_SEARCH": lambda self: self._parse_vector_search(),
288-
"FORECAST": lambda self: self._parse_ml(exp.MLForecast),
288+
"FORECAST": lambda self: self._parse_forecast(),
289289
}
290290

291291
NO_PAREN_FUNCTIONS: t.ClassVar = {
@@ -603,6 +603,31 @@ def _parse_translate(self) -> exp.Translate | exp.MLTranslate:
603603

604604
return exp.Translate.from_arg_list(self._parse_function_args())
605605

606+
def _parse_forecast(self) -> exp.AIForecast | exp.MLForecast:
    """Parse the arguments of a `FORECAST(...)` call as ML.FORECAST or AI.FORECAST.

    The call is treated as ML.FORECAST when the token four positions behind the
    cursor reads `ML` (case-insensitive). Otherwise it is parsed as the
    AI.FORECAST table-valued function: an optional `TABLE` keyword and a table
    or parenthesized query, followed by comma-separated `name => value`
    arguments.

    Returns:
        The result of `_parse_ml(exp.MLForecast)` for ML.FORECAST; otherwise an
        `exp.AIForecast` whose named arguments are stored under their own names.

    Raises:
        Whatever `raise_error` raises when the first argument is missing or a
        later argument is not written as `name => value`.
    """
    # Check if this is ML.FORECAST by looking at previous tokens. Presumably the
    # cursor is on the first argument here, so the tokens behind it would be
    # `ML`, `.`, `FORECAST` and `(` -- TODO confirm against the caller.
    qualifier_index = self._index - 4

    # Never hand a negative index to seq_get: plain sequence indexing would
    # count from the end of the token list and could match an unrelated `ML`.
    token = seq_get(self._tokens, qualifier_index) if qualifier_index >= 0 else None
    if token and token.text.upper() == "ML":
        return self._parse_ml(exp.MLForecast)

    # AI.FORECAST is a TVF, where the first argument is either TABLE <table>
    # or a parenthesized query statement, followed by named arguments.
    self._match(TokenType.TABLE)
    this = self._parse_table()
    if not this:
        self.raise_error("Expected table or query statement")

    expr = self.expression(exp.AIForecast(this=this))
    while self._match(TokenType.COMMA):
        arg = self._parse_lambda()
        if isinstance(arg, exp.Kwarg):
            # Store the whole `name => value` node under its argument name.
            expr.set(arg.this.name, arg)
        elif arg:
            # A positional argument has no usable name: report it instead of
            # failing on `.name` or storing it under an arbitrary key.
            self.raise_error("Expected a named argument of the form name => value")

    return expr
630+
606631
def _parse_features_at_time(self) -> exp.FeaturesAtTime:
607632
self._match(TokenType.TABLE)
608633
this = self._parse_table()

tests/dialects/test_bigquery.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2414,6 +2414,13 @@ def test_ml_functions(self):
24142414
"SELECT * FROM ML.FORECAST(MODEL `mydataset.mymodel`, (SELECT * FROM mydataset.query_table), STRUCT())"
24152415
)
24162416

2417+
self.validate_identity(
2418+
"SELECT * FROM AI.FORECAST(TABLE citibike_trips, data_col => 'num_trips', timestamp_col => 'date', horizon => 30)"
2419+
)
2420+
self.validate_identity(
2421+
"SELECT * FROM AI.FORECAST((SELECT * FROM citibike_trips), data_col => 'num_trips', timestamp_col => 'date', horizon => 30)"
2422+
)
2423+
24172424
for name in ("GENERATE_EMBEDDING", "GENERATE_TEXT_EMBEDDING"):
24182425
with self.subTest(f"Testing BigQuery's ML function {name}"):
24192426
ast = self.validate_identity(

0 commit comments

Comments
 (0)