Skip to content

Commit b5e8a06

Browse files
feat(bigframes): Support Expression objects in create_model options
This change allows the `options` parameter of `bigframes.bigquery._operations.ml.create_model` to accept BigFrames `Expression` objects. These expressions are compiled to SQL scalar expressions and included in the generated `CREATE MODEL` DDL statement.

- Added `bigframes.core.expression.Expression` type support in the `options` dict.
- Updated `create_model_ddl` to handle compiling expressions using `expression_compiler`.
- Added `test_create_model_expression_option` snapshot test to verify the generated "golden SQL".

Co-authored-by: tswast <247555+tswast@users.noreply.github.com>
1 parent d3d6840 commit b5e8a06

File tree

4 files changed

+31
-4
lines changed

4 files changed

+31
-4
lines changed

packages/bigframes/bigframes/bigquery/_operations/ml.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import bigframes.dataframe as dataframe
2626
import bigframes.ml.base
2727
import bigframes.session
28+
import bigframes.core.expression as ex
2829
from bigframes.bigquery._operations import utils
2930

3031

@@ -50,7 +51,9 @@ def create_model(
5051
input_schema: Optional[Mapping[str, str]] = None,
5152
output_schema: Optional[Mapping[str, str]] = None,
5253
connection_name: Optional[str] = None,
53-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
54+
options: Optional[
55+
Mapping[str, Union[str, int, float, bool, list, "ex.Expression"]]
56+
] = None,
5457
training_data: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
5558
custom_holiday: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
5659
session: Optional[bigframes.session.Session] = None,
@@ -78,7 +81,7 @@ def create_model(
7881
The OUTPUT clause, which specifies the schema of the output data.
7982
connection_name (str, optional):
8083
The connection to use for the model.
81-
options (Mapping[str, Union[str, int, float, bool, list]], optional):
84+
options (Mapping[str, Union[str, int, float, bool, list, bigframes.core.expression.Expression]], optional):
8285
The OPTIONS clause, which specifies the model options.
8386
training_data (Union[bigframes.pandas.DataFrame, str], optional):
8487
The query or DataFrame to use for training the model.

packages/bigframes/bigframes/core/sql/ml.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
from typing import Any, Dict, List, Mapping, Optional, Union
1818

19+
import bigframes.core.expression as ex
1920
from bigframes.core.compile.sqlglot import sql as sg_sql
21+
from bigframes.core.compile.sqlglot.expression_compiler import expression_compiler
2022

2123

2224
def create_model_ddl(
@@ -28,7 +30,9 @@ def create_model_ddl(
2830
input_schema: Optional[Mapping[str, str]] = None,
2931
output_schema: Optional[Mapping[str, str]] = None,
3032
connection_name: Optional[str] = None,
31-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
33+
options: Optional[
34+
Mapping[str, Union[str, int, float, bool, list, "ex.Expression"]]
35+
] = None,
3236
training_data: Optional[str] = None,
3337
custom_holiday: Optional[str] = None,
3438
) -> str:
@@ -70,7 +74,10 @@ def create_model_ddl(
7074
if options:
7175
rendered_options = []
7276
for option_name, option_value in options.items():
73-
if isinstance(option_value, (list, tuple)):
77+
if isinstance(option_value, ex.Expression):
78+
sg_expr = expression_compiler.compile_expression(option_value)
79+
rendered_val = sg_sql.to_sql(sg_expr)
80+
elif isinstance(option_value, (list, tuple)):
7481
# Handle list options like model_registry="vertex_ai"
7582
# wait, usually options are key=value.
7683
# if value is list, it is [val1, val2]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CREATE MODEL `my_model`
2+
OPTIONS(l2_reg = 0.1, booster_type = 'gbtree')
3+
AS SELECT * FROM t

packages/bigframes/tests/unit/core/sql/test_ml.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,20 @@ def test_create_model_list_option(snapshot):
9797
snapshot.assert_match(sql, "create_model_list_option.sql")
9898

9999

100+
def test_create_model_expression_option(snapshot):
    """Snapshot-test CREATE MODEL DDL when an option value is an Expression.

    The options mapping mixes a scalar-constant expression (``l2_reg``) with
    a plain string value (``booster_type``); the rendered statement is
    compared against the stored ``create_model_expression_option.sql``
    snapshot.
    """
    import bigframes.core.expression as ex

    model_options = {
        "l2_reg": ex.ScalarConstantExpression(0.1, None),
        "booster_type": "gbtree",
    }
    ddl = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_model",
        options=model_options,
        training_data="SELECT * FROM t",
    )

    snapshot.assert_match(ddl, "create_model_expression_option.sql")
112+
113+
100114
def test_evaluate_model_basic(snapshot):
101115
sql = bigframes.core.sql.ml.evaluate(
102116
model_name="my_project.my_dataset.my_model",

0 commit comments

Comments
 (0)