Skip to content

Commit 7ea8bf5

Browse files
feat(bigframes): Support Expression objects in create_model options
This change allows the `options` parameter of `bigframes.bigquery._operations.ml.create_model` to accept BigFrames `Expression` objects. These expressions are compiled to SQL scalar expressions and included in the generated `CREATE MODEL` DDL statement. - Added `bigframes.core.expression.Expression` type support in the `options` dict. - Updated `create_model_ddl` to handle compiling expressions using `expression_compiler`. - Added `test_create_model_expression_option` snapshot test to verify the generated "golden SQL", using an expression that calls a function on a literal value (e.g. 0.1 * 10). Co-authored-by: tswast <247555+tswast@users.noreply.github.com>
1 parent b5e8a06 commit 7ea8bf5

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
CREATE MODEL `my_model`
2-
OPTIONS(l2_reg = 0.1, booster_type = 'gbtree')
2+
OPTIONS(l2_reg = 0.1 * 10, booster_type = 'gbtree')
33
AS SELECT * FROM t

packages/bigframes/tests/unit/core/sql/test_ml.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,19 @@ def test_create_model_list_option(snapshot):
9999

100100
def test_create_model_expression_option(snapshot):
101101
import bigframes.core.expression as ex
102+
import bigframes.operations.numeric_ops as numeric_ops
103+
import bigframes.dtypes as dtypes
104+
105+
# An expression that calls a function on a literal value
106+
# e.g. 0.1 * 10
107+
literal_expr = ex.ScalarConstantExpression(0.1, dtypes.FLOAT_DTYPE)
108+
multiplier_expr = ex.ScalarConstantExpression(10, dtypes.INT_DTYPE)
109+
math_expr = ex.OpExpression(op=numeric_ops.mul_op, inputs=(literal_expr, multiplier_expr))
102110

103111
sql = bigframes.core.sql.ml.create_model_ddl(
104112
model_name="my_model",
105113
options={
106-
"l2_reg": ex.ScalarConstantExpression(0.1, None),
114+
"l2_reg": math_expr,
107115
"booster_type": "gbtree",
108116
},
109117
training_data="SELECT * FROM t",

0 commit comments

Comments
 (0)