Skip to content

Commit 58aa74e

Browse files
feat(bigframes): Support Expression objects in create_model options
This change allows the `options` parameter of `bigframes.bigquery._operations.ml.create_model` to accept BigFrames `col.Expression` objects. These expressions are compiled to SQL scalar expressions and included in the generated `CREATE MODEL` DDL statement. - Added `bigframes.core.col.Expression` type support in the `options` dict. - Updated `create_model_ddl` to handle compiling expressions using `expression_compiler`. - Added `test_create_model_expression_option` snapshot test to verify the generated "golden SQL", using an expression that calls a function on a literal value (e.g. 0.1 * 10). - Moved test imports to the top level to adhere to PEP 8 and ran `ruff format`. Co-authored-by: tswast <247555+tswast@users.noreply.github.com>
1 parent 5d874d0 commit 58aa74e

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

packages/bigframes/bigframes/bigquery/_operations/ml.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import bigframes.dataframe as dataframe
2626
import bigframes.ml.base
2727
import bigframes.session
28-
import bigframes.core.expression as ex
28+
import bigframes.core.col as col
2929
from bigframes.bigquery._operations import utils
3030

3131

@@ -52,7 +52,7 @@ def create_model(
5252
output_schema: Optional[Mapping[str, str]] = None,
5353
connection_name: Optional[str] = None,
5454
options: Optional[
55-
Mapping[str, Union[str, int, float, bool, list, "ex.Expression"]]
55+
Mapping[str, Union[str, int, float, bool, list, "col.Expression"]]
5656
] = None,
5757
training_data: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
5858
custom_holiday: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
@@ -81,7 +81,7 @@ def create_model(
8181
The OUTPUT clause, which specifies the schema of the output data.
8282
connection_name (str, optional):
8383
The connection to use for the model.
84-
options (Mapping[str, Union[str, int, float, bool, list, bigframes.core.expression.Expression]], optional):
84+
options (Mapping[str, Union[str, int, float, bool, list, bigframes.core.col.Expression]], optional):
8585
The OPTIONS clause, which specifies the model options.
8686
training_data (Union[bigframes.pandas.DataFrame, str], optional):
8787
The query or DataFrame to use for training the model.

packages/bigframes/bigframes/core/sql/ml.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from typing import Any, Dict, List, Mapping, Optional, Union
1818

19-
import bigframes.core.expression as ex
19+
import bigframes.core.col as col
2020
from bigframes.core.compile.sqlglot import sql as sg_sql
2121
from bigframes.core.compile.sqlglot.expression_compiler import expression_compiler
2222

@@ -31,7 +31,7 @@ def create_model_ddl(
3131
output_schema: Optional[Mapping[str, str]] = None,
3232
connection_name: Optional[str] = None,
3333
options: Optional[
34-
Mapping[str, Union[str, int, float, bool, list, "ex.Expression"]]
34+
Mapping[str, Union[str, int, float, bool, list, "col.Expression"]]
3535
] = None,
3636
training_data: Optional[str] = None,
3737
custom_holiday: Optional[str] = None,
@@ -74,8 +74,8 @@ def create_model_ddl(
7474
if options:
7575
rendered_options = []
7676
for option_name, option_value in options.items():
77-
if isinstance(option_value, ex.Expression):
78-
sg_expr = expression_compiler.compile_expression(option_value)
77+
if isinstance(option_value, col.Expression):
78+
sg_expr = expression_compiler.compile_expression(option_value._value)
7979
rendered_val = sg_sql.to_sql(sg_expr)
8080
elif isinstance(option_value, (list, tuple)):
8181
# Handle list options like model_registry="vertex_ai"

packages/bigframes/tests/unit/core/sql/test_ml.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import pytest
1616

17+
import bigframes.core.col as col
1718
import bigframes.core.expression as ex
1819
import bigframes.core.sql.ml
1920
import bigframes.dtypes as dtypes
@@ -105,8 +106,8 @@ def test_create_model_expression_option(snapshot):
105106
# e.g. 0.1 * 10
106107
literal_expr = ex.ScalarConstantExpression(0.1, dtypes.FLOAT_DTYPE)
107108
multiplier_expr = ex.ScalarConstantExpression(10, dtypes.INT_DTYPE)
108-
math_expr = ex.OpExpression(
109-
op=numeric_ops.mul_op, inputs=(literal_expr, multiplier_expr)
109+
math_expr = col.Expression(
110+
ex.OpExpression(op=numeric_ops.mul_op, inputs=(literal_expr, multiplier_expr))
110111
)
111112

112113
sql = bigframes.core.sql.ml.create_model_ddl(

0 commit comments

Comments
 (0)