Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 1e1fae4

Browse files
committed
feat: add ai.generate_text
1 parent e91536c commit 1e1fae4

File tree

5 files changed

+311
-83
lines changed

5 files changed

+311
-83
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 124 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from bigframes import clients, dataframe, dtypes
2727
from bigframes import pandas as bpd
2828
from bigframes import series, session
29+
from bigframes.bigquery._operations import utils as ml_utils
2930
from bigframes.core import convert
3031
from bigframes.core.logging import log_adapter
3132
import bigframes.core.sql.literals
@@ -391,7 +392,7 @@ def generate_double(
391392

392393
@log_adapter.method_logger(custom_base_name="bigquery_ai")
393394
def generate_embedding(
394-
model_name: str,
395+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
395396
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
396397
*,
397398
output_dimensionality: Optional[int] = None,
@@ -415,9 +416,8 @@ def generate_embedding(
415416
... ) # doctest: +SKIP
416417
417418
Args:
418-
model_name (str):
419-
The name of a remote model from Vertex AI, such as the
420-
multimodalembedding@001 model.
419+
model (bigframes.ml.base.BaseEstimator, str, or pandas.Series):
420+
The model to use for text embedding.
421421
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
422422
The data to generate embeddings for. If a Series is provided, it is
423423
treated as the 'content' column. If a DataFrame is provided, it
@@ -454,20 +454,8 @@ def generate_embedding(
454454
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-generate-embedding#output>`_
455455
for details.
456456
"""
457-
if isinstance(data, (pd.DataFrame, pd.Series)):
458-
data = bpd.read_pandas(data)
459-
460-
if isinstance(data, series.Series):
461-
data = data.copy()
462-
data.name = "content"
463-
data_df = data.to_frame()
464-
elif isinstance(data, dataframe.DataFrame):
465-
data_df = data
466-
else:
467-
raise ValueError(f"Unsupported data type: {type(data)}")
468-
469-
# We need to get the SQL for the input data to pass as a subquery to the TVF
470-
source_sql = data_df.sql
457+
model_name, session = ml_utils.get_model_name_and_session(model, data)
458+
table_sql = ml_utils.to_sql(data)
471459

472460
struct_fields: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {}
473461
if output_dimensionality is not None:
@@ -488,12 +476,127 @@ def generate_embedding(
488476
SELECT *
489477
FROM AI.GENERATE_EMBEDDING(
490478
MODEL `{model_name}`,
491-
({source_sql}),
492-
{bigframes.core.sql.literals.struct_literal(struct_fields)})
479+
({table_sql}),
480+
{bigframes.core.sql.literals.struct_literal(struct_fields)}
481+
)
482+
"""
483+
484+
if session is None:
485+
return bpd.read_gbq_query(query)
486+
else:
487+
return session.read_gbq_query(query)
488+
489+
490+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
491+
def generate_text(
492+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
493+
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
494+
*,
495+
temperature: Optional[float] = None,
496+
max_output_tokens: Optional[int] = None,
497+
top_k: Optional[int] = None,
498+
top_p: Optional[float] = None,
499+
stop_sequences: Optional[List[str]] = None,
500+
ground_with_google_search: Optional[bool] = None,
501+
request_type: Optional[str] = None,
502+
) -> dataframe.DataFrame:
503+
"""
504+
Generates text using a BigQuery ML model.
505+
506+
See the `BigQuery ML GENERATE_TEXT function syntax
507+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text>`_
508+
for additional reference.
509+
510+
**Examples:**
511+
512+
>>> import bigframes.pandas as bpd
513+
>>> import bigframes.bigquery as bbq
514+
>>> df = bpd.DataFrame({"prompt": ["write a poem about apples"]})
515+
>>> bbq.ai.generate_text(
516+
... "project.dataset.model_name",
517+
... df
518+
... ) # doctest: +SKIP
519+
520+
Args:
521+
model (bigframes.ml.base.BaseEstimator, str, or pandas.Series):
522+
The model to use for text generation.
523+
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
524+
The data to generate embeddings for. If a Series is provided, it is
525+
treated as the 'content' column. If a DataFrame is provided, it
526+
must contain a 'content' column, or you must rename the column you
527+
wish to embed to 'content'.
528+
temperature (float, optional):
529+
A FLOAT64 value that is used for sampling promiscuity. The value
530+
must be in the range ``[0.0, 1.0]``. A lower temperature works well
531+
for prompts that expect a more deterministic and less open-ended
532+
or creative response, while a higher temperature can lead to more
533+
diverse or creative results. A temperature of ``0`` is
534+
deterministic, meaning that the highest probability response is
535+
always selected.
536+
max_output_tokens (int, optional):
537+
An INT64 value that sets the maximum number of tokens in the
538+
generated text.
539+
top_k (int, optional):
540+
An INT64 value that changes how the model selects tokens for
541+
output. A ``top_k`` of ``1`` means the next selected token is the
542+
most probable among all tokens in the model's vocabulary. A
543+
``top_k`` of ``3`` means that the next token is selected from
544+
among the three most probable tokens by using temperature. The
545+
default value is ``40``.
546+
top_p (float, optional):
547+
A FLOAT64 value that changes how the model selects tokens for
548+
output. Tokens are selected from most probable to least probable
549+
until the sum of their probabilities equals the ``top_p`` value.
550+
For example, if tokens A, B, and C have a probability of 0.3, 0.2,
551+
and 0.1 and the ``top_p`` value is ``0.5``, then the model will
552+
select either A or B as the next token by using temperature. The
553+
default value is ``0.95``.
554+
stop_sequences (List[str], optional):
555+
An ARRAY<STRING> value that contains the stop sequences for the model.
556+
ground_with_google_search (bool, optional):
557+
A BOOL value that determines whether to ground the model with Google Search.
558+
request_type (str, optional):
559+
A STRING value that contains the request type for the model.
560+
561+
Returns:
562+
bigframes.pandas.DataFrame:
563+
The generated text.
564+
"""
565+
model_name, session = ml_utils.get_model_name_and_session(model, data)
566+
table_sql = ml_utils.to_sql(data)
567+
568+
struct_fields: Dict[
569+
str,
570+
Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
571+
] = {}
572+
if temperature is not None:
573+
struct_fields["TEMPERATURE"] = temperature
574+
if max_output_tokens is not None:
575+
struct_fields["MAX_OUTPUT_TOKENS"] = max_output_tokens
576+
if top_k is not None:
577+
struct_fields["TOP_K"] = top_k
578+
if top_p is not None:
579+
struct_fields["TOP_P"] = top_p
580+
if stop_sequences is not None:
581+
struct_fields["STEP_SEQUENCES"] = stop_sequences
582+
if ground_with_google_search is not None:
583+
struct_fields["GROUND_WITH_GOOGLE_SEARCH"] = ground_with_google_search
584+
if request_type is not None:
585+
struct_fields["REQUEST_TYPE"] = request_type
586+
587+
query = f"""
588+
SELECT *
589+
FROM AI.GENERATE_TEXT(
590+
MODEL `{model_name}`,
591+
({table_sql}),
592+
{bigframes.core.sql.literals.struct_literal(struct_fields)}
493593
)
494594
"""
495595

496-
return data_df._session.read_gbq(query)
596+
if session is None:
597+
return bpd.read_gbq_query(query)
598+
else:
599+
return session.read_gbq_query(query)
497600

498601

499602
@log_adapter.method_logger(custom_base_name="bigquery_ai")

bigframes/bigquery/_operations/ml.py

Lines changed: 20 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -20,60 +20,14 @@
2020
import google.cloud.bigquery
2121
import pandas as pd
2222

23+
from bigframes.bigquery._operations import utils
2324
import bigframes.core.logging.log_adapter as log_adapter
2425
import bigframes.core.sql.ml
2526
import bigframes.dataframe as dataframe
2627
import bigframes.ml.base
2728
import bigframes.session
2829

2930

30-
# Helper to convert DataFrame to SQL string
31-
def _to_sql(df_or_sql: Union[pd.DataFrame, dataframe.DataFrame, str]) -> str:
32-
import bigframes.pandas as bpd
33-
34-
if isinstance(df_or_sql, str):
35-
return df_or_sql
36-
37-
if isinstance(df_or_sql, pd.DataFrame):
38-
bf_df = bpd.read_pandas(df_or_sql)
39-
else:
40-
bf_df = cast(dataframe.DataFrame, df_or_sql)
41-
42-
# Cache dataframes to make sure base table is not a snapshot.
43-
# Cached dataframe creates a full copy, never uses snapshot.
44-
# This is a workaround for internal issue b/310266666.
45-
bf_df.cache()
46-
sql, _, _ = bf_df._to_sql_query(include_index=False)
47-
return sql
48-
49-
50-
def _get_model_name_and_session(
51-
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
52-
# Other dataframe arguments to extract session from
53-
*dataframes: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]],
54-
) -> tuple[str, Optional[bigframes.session.Session]]:
55-
if isinstance(model, pd.Series):
56-
try:
57-
model_ref = model["modelReference"]
58-
model_name = f"{model_ref['projectId']}.{model_ref['datasetId']}.{model_ref['modelId']}" # type: ignore
59-
except KeyError:
60-
raise ValueError("modelReference must be present in the pandas Series.")
61-
elif isinstance(model, str):
62-
model_name = model
63-
else:
64-
if model._bqml_model is None:
65-
raise ValueError("Model must be fitted to be used in ML operations.")
66-
return model._bqml_model.model_name, model._bqml_model.session
67-
68-
session = None
69-
for df in dataframes:
70-
if isinstance(df, dataframe.DataFrame):
71-
session = df._session
72-
break
73-
74-
return model_name, session
75-
76-
7731
def _get_model_metadata(
7832
*,
7933
bqclient: google.cloud.bigquery.Client,
@@ -143,8 +97,12 @@ def create_model(
14397
"""
14498
import bigframes.pandas as bpd
14599

146-
training_data_sql = _to_sql(training_data) if training_data is not None else None
147-
custom_holiday_sql = _to_sql(custom_holiday) if custom_holiday is not None else None
100+
training_data_sql = (
101+
utils.to_sql(training_data) if training_data is not None else None
102+
)
103+
custom_holiday_sql = (
104+
utils.to_sql(custom_holiday) if custom_holiday is not None else None
105+
)
148106

149107
# Determine session from DataFrames if not provided
150108
if session is None:
@@ -227,8 +185,8 @@ def evaluate(
227185
"""
228186
import bigframes.pandas as bpd
229187

230-
model_name, session = _get_model_name_and_session(model, input_)
231-
table_sql = _to_sql(input_) if input_ is not None else None
188+
model_name, session = utils.get_model_name_and_session(model, input_)
189+
table_sql = utils.to_sql(input_) if input_ is not None else None
232190

233191
sql = bigframes.core.sql.ml.evaluate(
234192
model_name=model_name,
@@ -281,8 +239,8 @@ def predict(
281239
"""
282240
import bigframes.pandas as bpd
283241

284-
model_name, session = _get_model_name_and_session(model, input_)
285-
table_sql = _to_sql(input_)
242+
model_name, session = utils.get_model_name_and_session(model, input_)
243+
table_sql = utils.to_sql(input_)
286244

287245
sql = bigframes.core.sql.ml.predict(
288246
model_name=model_name,
@@ -340,8 +298,8 @@ def explain_predict(
340298
"""
341299
import bigframes.pandas as bpd
342300

343-
model_name, session = _get_model_name_and_session(model, input_)
344-
table_sql = _to_sql(input_)
301+
model_name, session = utils.get_model_name_and_session(model, input_)
302+
table_sql = utils.to_sql(input_)
345303

346304
sql = bigframes.core.sql.ml.explain_predict(
347305
model_name=model_name,
@@ -383,7 +341,7 @@ def global_explain(
383341
"""
384342
import bigframes.pandas as bpd
385343

386-
model_name, session = _get_model_name_and_session(model)
344+
model_name, session = utils.get_model_name_and_session(model)
387345
sql = bigframes.core.sql.ml.global_explain(
388346
model_name=model_name,
389347
class_level_explain=class_level_explain,
@@ -419,8 +377,8 @@ def transform(
419377
"""
420378
import bigframes.pandas as bpd
421379

422-
model_name, session = _get_model_name_and_session(model, input_)
423-
table_sql = _to_sql(input_)
380+
model_name, session = utils.get_model_name_and_session(model, input_)
381+
table_sql = utils.to_sql(input_)
424382

425383
sql = bigframes.core.sql.ml.transform(
426384
model_name=model_name,
@@ -500,8 +458,8 @@ def generate_text(
500458
"""
501459
import bigframes.pandas as bpd
502460

503-
model_name, session = _get_model_name_and_session(model, input_)
504-
table_sql = _to_sql(input_)
461+
model_name, session = utils.get_model_name_and_session(model, input_)
462+
table_sql = utils.to_sql(input_)
505463

506464
sql = bigframes.core.sql.ml.generate_text(
507465
model_name=model_name,
@@ -565,8 +523,8 @@ def generate_embedding(
565523
"""
566524
import bigframes.pandas as bpd
567525

568-
model_name, session = _get_model_name_and_session(model, input_)
569-
table_sql = _to_sql(input_)
526+
model_name, session = utils.get_model_name_and_session(model, input_)
527+
table_sql = utils.to_sql(input_)
570528

571529
sql = bigframes.core.sql.ml.generate_embedding(
572530
model_name=model_name,

0 commit comments

Comments
 (0)