Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 9ad9011

Browse files
committed
add more functions
1 parent c86e15a commit 9ad9011

File tree

14 files changed

+1089
-438
lines changed

14 files changed

+1089
-438
lines changed

bigframes/bigquery/_operations/ml.py

Lines changed: 173 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,13 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Mapping, Optional, TYPE_CHECKING, Union
17+
from typing import Mapping, Optional, Union
1818

1919
import bigframes.core.log_adapter as log_adapter
2020
import bigframes.core.sql.ml
2121
import bigframes.dataframe as dataframe
22-
23-
if TYPE_CHECKING:
24-
import bigframes.ml.base
25-
import bigframes.session
22+
import bigframes.ml.base
23+
import bigframes.session
2624

2725

2826
# Helper to convert DataFrame to SQL string
@@ -34,6 +32,29 @@ def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
3432
return sql
3533

3634

35+
def _get_model_name_and_session(
    model: Union[bigframes.ml.base.BaseEstimator, str],
    # Additional DataFrame/SQL arguments that may carry a session.
    *dataframes: Optional[Union[dataframe.DataFrame, str]],
) -> tuple[str, bigframes.session.Session]:
    """Resolve the model name and the session to run an ML query with.

    Args:
        model:
            Either a model name string or an estimator object.
        *dataframes:
            Other arguments of the calling operation. Only DataFrame
            instances are inspected; strings and ``None`` are ignored.

    Returns:
        A ``(model_name, session)`` pair. For an estimator, both values come
        from its ``_bqml_model``. For a string, the session is taken from the
        first DataFrame in ``dataframes``, falling back to the global session.

    Raises:
        ValueError: If ``model`` is an estimator whose ``_bqml_model`` is None.
    """
    import bigframes.pandas as bpd

    # Estimator path: the underlying BQML model knows its own name and session.
    if not isinstance(model, str):
        bqml_model = model._bqml_model
        if bqml_model is None:
            raise ValueError("Model must be fitted to be used in ML operations.")
        return bqml_model.model_name, bqml_model.session

    # String path: borrow the session of the first DataFrame argument, if any.
    frame_sessions = (
        candidate._session
        for candidate in dataframes
        if isinstance(candidate, dataframe.DataFrame)
    )
    session = next(frame_sessions, None)
    if session is None:
        session = bpd.get_global_session()
    return model, session
56+
57+
3758
@log_adapter.method_logger(custom_base_name="bigquery_ml")
3859
def create_model(
3960
model_name: str,
@@ -123,3 +144,150 @@ def create_model(
123144
session._start_query_ml_ddl(sql)
124145

125146
return session.read_gbq_model(model_name)
147+
148+
149+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def evaluate(
    model: Union[bigframes.ml.base.BaseEstimator, str],
    input_: Optional[Union[dataframe.DataFrame, str]] = None,
    *,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
) -> dataframe.DataFrame:
    """
    Evaluates a BigQuery ML model.

    See the `BigQuery ML EVALUATE function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-evaluate>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator or str):
            The model to evaluate, given as an estimator or a model name.
        input_ (Union[bigframes.pandas.DataFrame, str], optional):
            The DataFrame or query holding the evaluation data. When
            omitted, no data argument is sent and the evaluation data from
            training is used.
        options (Mapping[str, Union[str, int, float, bool, list]], optional):
            Named settings forwarded unchanged to the SQL builder for
            ``ML.EVALUATE``.

    Returns:
        bigframes.pandas.DataFrame:
            The evaluation results, read back through the resolved session.
    """
    model_name, session = _get_model_name_and_session(model, input_)

    # Only convert the input when the caller actually supplied one.
    if input_ is None:
        table_sql = None
    else:
        table_sql = _to_sql(input_)

    query = bigframes.core.sql.ml.evaluate(
        model_name=model_name,
        table=table_sql,
        options=options,
    )
    return session.read_gbq(query)
186+
187+
188+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def predict(
    model: Union[bigframes.ml.base.BaseEstimator, str],
    input_: Union[dataframe.DataFrame, str],
    *,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
) -> dataframe.DataFrame:
    """
    Runs prediction on a BigQuery ML model.

    See the `BigQuery ML PREDICT function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator or str):
            The model to predict with, given as an estimator or a model name.
        input_ (Union[bigframes.pandas.DataFrame, str]):
            The DataFrame or query supplying the rows to predict on.
        options (Mapping[str, Union[str, int, float, bool, list]], optional):
            Named settings forwarded unchanged to the SQL builder for
            ``ML.PREDICT``.

    Returns:
        bigframes.pandas.DataFrame:
            The prediction results, read back through the resolved session.
    """
    model_name, session = _get_model_name_and_session(model, input_)
    query = bigframes.core.sql.ml.predict(
        model_name=model_name,
        table=_to_sql(input_),
        options=options,
    )
    return session.read_gbq(query)
224+
225+
226+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def explain_predict(
    model: Union[bigframes.ml.base.BaseEstimator, str],
    input_: Union[dataframe.DataFrame, str],
    *,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
) -> dataframe.DataFrame:
    """
    Runs explainable prediction on a BigQuery ML model.

    See the `BigQuery ML EXPLAIN_PREDICT function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator or str):
            The model to predict with, given as an estimator or a model name.
        input_ (Union[bigframes.pandas.DataFrame, str]):
            The DataFrame or query supplying the rows to predict on.
        options (Mapping[str, Union[str, int, float, bool, list]], optional):
            Named settings forwarded unchanged to the SQL builder for
            ``ML.EXPLAIN_PREDICT``.

    Returns:
        bigframes.pandas.DataFrame:
            The prediction results with explanations, read back through the
            resolved session.
    """
    model_name, session = _get_model_name_and_session(model, input_)
    query = bigframes.core.sql.ml.explain_predict(
        model_name=model_name,
        table=_to_sql(input_),
        options=options,
    )
    return session.read_gbq(query)
262+
263+
264+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def global_explain(
    model: Union[bigframes.ml.base.BaseEstimator, str],
    *,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
) -> dataframe.DataFrame:
    """
    Gets global explanations for a BigQuery ML model.

    See the `BigQuery ML GLOBAL_EXPLAIN function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator or str):
            The model to explain, given as an estimator or a model name.
        options (Mapping[str, Union[str, int, float, bool, list]], optional):
            Named settings forwarded unchanged to the SQL builder for
            ``ML.GLOBAL_EXPLAIN``.

    Returns:
        bigframes.pandas.DataFrame:
            The global explanation results, read back through the resolved
            session.
    """
    # No DataFrame argument exists here, so a string model name always
    # resolves to the global session.
    model_name, session = _get_model_name_and_session(model)
    return session.read_gbq(
        bigframes.core.sql.ml.global_explain(
            model_name=model_name,
            options=options,
        )
    )

bigframes/bigquery/ml.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,18 @@
1919
For an interface more familiar to Scikit-Learn users, see :mod:`bigframes.ml`.
2020
"""
2121

22-
from bigframes.bigquery._operations.ml import create_model
22+
from bigframes.bigquery._operations.ml import (
23+
create_model,
24+
evaluate,
25+
explain_predict,
26+
global_explain,
27+
predict,
28+
)
2329

2430
__all__ = [
2531
"create_model",
32+
"evaluate",
33+
"predict",
34+
"explain_predict",
35+
"global_explain",
2636
]

bigframes/core/sql/ml.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,102 @@ def create_model_ddl(
9797
ddl += f"AS {training_data}"
9898

9999
return ddl
100+
101+
102+
def evaluate(
    model_name: str,
    *,
    table: Optional[str] = None,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
) -> str:
    """Encode the ML.EVALUATE statement.

    See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-evaluate for reference.

    Args:
        model_name:
            Name of the model; rendered through ``googlesql.identifier``.
        table:
            SQL text for the evaluation data. It is wrapped in parentheses
            and passed as the second argument; omitted when empty or None.
        options:
            Named function arguments. Each entry is rendered as
            ``<literal> AS <name>`` inside a trailing ``STRUCT(...)``.
            Omitted when empty or None.

    Returns:
        The ``SELECT * FROM ML.EVALUATE(...)`` query text.
    """
    call_args = [f"MODEL {googlesql.identifier(model_name)}"]
    if table:
        call_args.append(f"({table})")
    if options:
        # The ML.EVALUATE function syntax takes its settings as a trailing
        # STRUCT(value AS name, ...) argument; it has no OPTIONS(...) clause.
        struct_fields = []
        for option_name, option_value in options.items():
            if isinstance(option_value, (list, tuple)):
                option_value = list(option_value)
            rendered_val = bigframes.core.sql.simple_literal(option_value)
            struct_fields.append(f"{rendered_val} AS {option_name}")
        call_args.append(f"STRUCT({', '.join(struct_fields)})")
    return f"SELECT * FROM ML.EVALUATE({', '.join(call_args)})"
126+
127+
128+
def predict(
    model_name: str,
    table: str,
    *,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
) -> str:
    """Encode the ML.PREDICT statement.

    See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict for reference.

    Args:
        model_name:
            Name of the model; rendered through ``googlesql.identifier``.
        table:
            SQL text for the input rows. It is wrapped in parentheses and
            passed as the second argument.
        options:
            Named function arguments. Each entry is rendered as
            ``<literal> AS <name>`` inside a trailing ``STRUCT(...)``.
            Omitted when empty or None.

    Returns:
        The ``SELECT * FROM ML.PREDICT(...)`` query text.
    """
    call_args = [f"MODEL {googlesql.identifier(model_name)}", f"({table})"]
    if options:
        # The ML.PREDICT function syntax takes its settings as a trailing
        # STRUCT(value AS name, ...) argument; it has no OPTIONS(...) clause.
        struct_fields = []
        for option_name, option_value in options.items():
            if isinstance(option_value, (list, tuple)):
                option_value = list(option_value)
            rendered_val = bigframes.core.sql.simple_literal(option_value)
            struct_fields.append(f"{rendered_val} AS {option_name}")
        call_args.append(f"STRUCT({', '.join(struct_fields)})")
    return f"SELECT * FROM ML.PREDICT({', '.join(call_args)})"
152+
153+
154+
def explain_predict(
    model_name: str,
    table: str,
    *,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
) -> str:
    """Encode the ML.EXPLAIN_PREDICT statement.

    See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict for reference.

    Args:
        model_name:
            Name of the model; rendered through ``googlesql.identifier``.
        table:
            SQL text for the input rows. It is wrapped in parentheses and
            passed as the second argument.
        options:
            Named function arguments. Each entry is rendered as
            ``<literal> AS <name>`` inside a trailing ``STRUCT(...)``.
            Omitted when empty or None.

    Returns:
        The ``SELECT * FROM ML.EXPLAIN_PREDICT(...)`` query text.
    """
    call_args = [f"MODEL {googlesql.identifier(model_name)}", f"({table})"]
    if options:
        # The ML.EXPLAIN_PREDICT function syntax takes its settings as a
        # trailing STRUCT(value AS name, ...) argument; it has no
        # OPTIONS(...) clause.
        struct_fields = []
        for option_name, option_value in options.items():
            if isinstance(option_value, (list, tuple)):
                option_value = list(option_value)
            rendered_val = bigframes.core.sql.simple_literal(option_value)
            struct_fields.append(f"{rendered_val} AS {option_name}")
        call_args.append(f"STRUCT({', '.join(struct_fields)})")
    return f"SELECT * FROM ML.EXPLAIN_PREDICT({', '.join(call_args)})"
176+
177+
178+
def global_explain(
    model_name: str,
    *,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
) -> str:
    """Encode the ML.GLOBAL_EXPLAIN statement.

    See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain for reference.

    Args:
        model_name:
            Name of the model; rendered through ``googlesql.identifier``.
        options:
            Named function arguments. Each entry is rendered as
            ``<literal> AS <name>`` inside a trailing ``STRUCT(...)``.
            Omitted when empty or None.

    Returns:
        The ``SELECT * FROM ML.GLOBAL_EXPLAIN(...)`` query text.
    """
    call_args = [f"MODEL {googlesql.identifier(model_name)}"]
    if options:
        # The ML.GLOBAL_EXPLAIN function syntax takes its settings as a
        # trailing STRUCT(value AS name, ...) argument; it has no
        # OPTIONS(...) clause.
        struct_fields = []
        for option_name, option_value in options.items():
            if isinstance(option_value, (list, tuple)):
                option_value = list(option_value)
            rendered_val = bigframes.core.sql.simple_literal(option_value)
            struct_fields.append(f"{rendered_val} AS {option_name}")
        call_args.append(f"STRUCT({', '.join(struct_fields)})")
    return f"SELECT * FROM ML.GLOBAL_EXPLAIN({', '.join(call_args)})"

0 commit comments

Comments
 (0)