1414
1515from __future__ import annotations
1616
17- from typing import Mapping , Optional , TYPE_CHECKING , Union
17+ from typing import Mapping , Optional , Union
1818
1919import bigframes .core .log_adapter as log_adapter
2020import bigframes .core .sql .ml
2121import bigframes .dataframe as dataframe
22-
23- if TYPE_CHECKING :
24- import bigframes .ml .base
25- import bigframes .session
22+ import bigframes .ml .base
23+ import bigframes .session
2624
2725
2826# Helper to convert DataFrame to SQL string
@@ -34,6 +32,29 @@ def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
3432 return sql
3533
3634
35+ def _get_model_name_and_session (
36+ model : Union [bigframes .ml .base .BaseEstimator , str ],
37+ # Other dataframe arguments to extract session from
38+ * dataframes : Optional [Union [dataframe .DataFrame , str ]],
39+ ) -> tuple [str , bigframes .session .Session ]:
40+ import bigframes .pandas as bpd
41+
42+ if isinstance (model , str ):
43+ model_name = model
44+ session = None
45+ for df in dataframes :
46+ if isinstance (df , dataframe .DataFrame ):
47+ session = df ._session
48+ break
49+ if session is None :
50+ session = bpd .get_global_session ()
51+ return model_name , session
52+ else :
53+ if model ._bqml_model is None :
54+ raise ValueError ("Model must be fitted to be used in ML operations." )
55+ return model ._bqml_model .model_name , model ._bqml_model .session
56+
57+
3758@log_adapter .method_logger (custom_base_name = "bigquery_ml" )
3859def create_model (
3960 model_name : str ,
@@ -123,3 +144,150 @@ def create_model(
123144 session ._start_query_ml_ddl (sql )
124145
125146 return session .read_gbq_model (model_name )
147+
148+
149+ @log_adapter .method_logger (custom_base_name = "bigquery_ml" )
150+ def evaluate (
151+ model : Union [bigframes .ml .base .BaseEstimator , str ],
152+ input_ : Optional [Union [dataframe .DataFrame , str ]] = None ,
153+ * ,
154+ options : Optional [Mapping [str , Union [str , int , float , bool , list ]]] = None ,
155+ ) -> dataframe .DataFrame :
156+ """
157+ Evaluates a BigQuery ML model.
158+
159+ See the `BigQuery ML EVALUATE function syntax
160+ <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-evaluate>`_
161+ for additional reference.
162+
163+ Args:
164+ model (bigframes.ml.base.BaseEstimator or str):
165+ The model to evaluate.
166+ input_ (Union[bigframes.pandas.DataFrame, str], optional):
167+ The DataFrame or query to use for evaluation. If not provided, the
168+ evaluation data from training is used.
169+ options (Mapping[str, Union[str, int, float, bool, list]], optional):
170+ The OPTIONS clause, which specifies the model options.
171+
172+ Returns:
173+ bigframes.pandas.DataFrame:
174+ The evaluation results.
175+ """
176+ model_name , session = _get_model_name_and_session (model , input_ )
177+ table_sql = _to_sql (input_ ) if input_ is not None else None
178+
179+ sql = bigframes .core .sql .ml .evaluate (
180+ model_name = model_name ,
181+ table = table_sql ,
182+ options = options ,
183+ )
184+
185+ return session .read_gbq (sql )
186+
187+
188+ @log_adapter .method_logger (custom_base_name = "bigquery_ml" )
189+ def predict (
190+ model : Union [bigframes .ml .base .BaseEstimator , str ],
191+ input_ : Union [dataframe .DataFrame , str ],
192+ * ,
193+ options : Optional [Mapping [str , Union [str , int , float , bool , list ]]] = None ,
194+ ) -> dataframe .DataFrame :
195+ """
196+ Runs prediction on a BigQuery ML model.
197+
198+ See the `BigQuery ML PREDICT function syntax
199+ <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict>`_
200+ for additional reference.
201+
202+ Args:
203+ model (bigframes.ml.base.BaseEstimator or str):
204+ The model to use for prediction.
205+ input_ (Union[bigframes.pandas.DataFrame, str]):
206+ The DataFrame or query to use for prediction.
207+ options (Mapping[str, Union[str, int, float, bool, list]], optional):
208+ The OPTIONS clause, which specifies the model options.
209+
210+ Returns:
211+ bigframes.pandas.DataFrame:
212+ The prediction results.
213+ """
214+ model_name , session = _get_model_name_and_session (model , input_ )
215+ table_sql = _to_sql (input_ )
216+
217+ sql = bigframes .core .sql .ml .predict (
218+ model_name = model_name ,
219+ table = table_sql ,
220+ options = options ,
221+ )
222+
223+ return session .read_gbq (sql )
224+
225+
226+ @log_adapter .method_logger (custom_base_name = "bigquery_ml" )
227+ def explain_predict (
228+ model : Union [bigframes .ml .base .BaseEstimator , str ],
229+ input_ : Union [dataframe .DataFrame , str ],
230+ * ,
231+ options : Optional [Mapping [str , Union [str , int , float , bool , list ]]] = None ,
232+ ) -> dataframe .DataFrame :
233+ """
234+ Runs explainable prediction on a BigQuery ML model.
235+
236+ See the `BigQuery ML EXPLAIN_PREDICT function syntax
237+ <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict>`_
238+ for additional reference.
239+
240+ Args:
241+ model (bigframes.ml.base.BaseEstimator or str):
242+ The model to use for prediction.
243+ input_ (Union[bigframes.pandas.DataFrame, str]):
244+ The DataFrame or query to use for prediction.
245+ options (Mapping[str, Union[str, int, float, bool, list]], optional):
246+ The OPTIONS clause, which specifies the model options.
247+
248+ Returns:
249+ bigframes.pandas.DataFrame:
250+ The prediction results with explanations.
251+ """
252+ model_name , session = _get_model_name_and_session (model , input_ )
253+ table_sql = _to_sql (input_ )
254+
255+ sql = bigframes .core .sql .ml .explain_predict (
256+ model_name = model_name ,
257+ table = table_sql ,
258+ options = options ,
259+ )
260+
261+ return session .read_gbq (sql )
262+
263+
264+ @log_adapter .method_logger (custom_base_name = "bigquery_ml" )
265+ def global_explain (
266+ model : Union [bigframes .ml .base .BaseEstimator , str ],
267+ * ,
268+ options : Optional [Mapping [str , Union [str , int , float , bool , list ]]] = None ,
269+ ) -> dataframe .DataFrame :
270+ """
271+ Gets global explanations for a BigQuery ML model.
272+
273+ See the `BigQuery ML GLOBAL_EXPLAIN function syntax
274+ <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain>`_
275+ for additional reference.
276+
277+ Args:
278+ model (bigframes.ml.base.BaseEstimator or str):
279+ The model to get explanations from.
280+ options (Mapping[str, Union[str, int, float, bool, list]], optional):
281+ The OPTIONS clause, which specifies the model options.
282+
283+ Returns:
284+ bigframes.pandas.DataFrame:
285+ The global explanation results.
286+ """
287+ model_name , session = _get_model_name_and_session (model )
288+ sql = bigframes .core .sql .ml .global_explain (
289+ model_name = model_name ,
290+ options = options ,
291+ )
292+
293+ return session .read_gbq (sql )
0 commit comments