2626from bigframes import clients , dataframe , dtypes
2727from bigframes import pandas as bpd
2828from bigframes import series , session
29+ from bigframes .bigquery ._operations import utils as ml_utils
2930from bigframes .core import convert
3031from bigframes .core .logging import log_adapter
3132import bigframes .core .sql .literals
@@ -391,7 +392,7 @@ def generate_double(
391392
392393@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
393394def generate_embedding (
394- model_name : str ,
395+ model : Union [ bigframes . ml . base . BaseEstimator , str , pd . Series ] ,
395396 data : Union [dataframe .DataFrame , series .Series , pd .DataFrame , pd .Series ],
396397 * ,
397398 output_dimensionality : Optional [int ] = None ,
@@ -415,9 +416,8 @@ def generate_embedding(
415416 ... ) # doctest: +SKIP
416417
417418 Args:
418- model_name (str):
419- The name of a remote model from Vertex AI, such as the
420- multimodalembedding@001 model.
419+ model (bigframes.ml.base.BaseEstimator or str):
420+ The model to use for text embedding.
421421 data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
422422 The data to generate embeddings for. If a Series is provided, it is
423423 treated as the 'content' column. If a DataFrame is provided, it
@@ -454,20 +454,8 @@ def generate_embedding(
454454 <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-generate-embedding#output>`_
455455 for details.
456456 """
457- if isinstance (data , (pd .DataFrame , pd .Series )):
458- data = bpd .read_pandas (data )
459-
460- if isinstance (data , series .Series ):
461- data = data .copy ()
462- data .name = "content"
463- data_df = data .to_frame ()
464- elif isinstance (data , dataframe .DataFrame ):
465- data_df = data
466- else :
467- raise ValueError (f"Unsupported data type: { type (data )} " )
468-
469- # We need to get the SQL for the input data to pass as a subquery to the TVF
470- source_sql = data_df .sql
457+ model_name , session = ml_utils .get_model_name_and_session (model , data )
458+ table_sql = ml_utils .to_sql (data )
471459
472460 struct_fields : Dict [str , bigframes .core .sql .literals .STRUCT_VALUES ] = {}
473461 if output_dimensionality is not None :
@@ -488,12 +476,127 @@ def generate_embedding(
488476 SELECT *
489477 FROM AI.GENERATE_EMBEDDING(
490478 MODEL `{ model_name } `,
491- ({ source_sql } ),
492- { bigframes .core .sql .literals .struct_literal (struct_fields )} )
479+ ({ table_sql } ),
480+ { bigframes .core .sql .literals .struct_literal (struct_fields )}
481+ )
482+ """
483+
484+ if session is None :
485+ return bpd .read_gbq_query (query )
486+ else :
487+ return session .read_gbq_query (query )
488+
489+
490+ @log_adapter .method_logger (custom_base_name = "bigquery_ai" )
491+ def generate_text (
492+ model : Union [bigframes .ml .base .BaseEstimator , str , pd .Series ],
493+ data : Union [dataframe .DataFrame , series .Series , pd .DataFrame , pd .Series ],
494+ * ,
495+ temperature : Optional [float ] = None ,
496+ max_output_tokens : Optional [int ] = None ,
497+ top_k : Optional [int ] = None ,
498+ top_p : Optional [float ] = None ,
499+ stop_sequences : Optional [List [str ]] = None ,
500+ ground_with_google_search : Optional [bool ] = None ,
501+ request_type : Optional [str ] = None ,
502+ ) -> dataframe .DataFrame :
503+ """
504+ Generates text using a BigQuery ML model.
505+
506+ See the `BigQuery ML GENERATE_TEXT function syntax
507+ <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text>`_
508+ for additional reference.
509+
510+ **Examples:**
511+
512+ >>> import bigframes.pandas as bpd
513+ >>> import bigframes.bigquery as bbq
514+ >>> df = bpd.DataFrame({"prompt": ["write a poem about apples"]})
515+ >>> bbq.ai.generate_text(
516+ ... "project.dataset.model_name",
517+ ... df
518+ ... ) # doctest: +SKIP
519+
520+ Args:
521+ model (bigframes.ml.base.BaseEstimator or str):
522+ The model to use for text generation.
523+ data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
524+ The data to generate embeddings for. If a Series is provided, it is
525+ treated as the 'content' column. If a DataFrame is provided, it
526+ must contain a 'content' column, or you must rename the column you
527+ wish to embed to 'content'.
528+ temperature (float, optional):
529+ A FLOAT64 value that is used for sampling promiscuity. The value
530+ must be in the range ``[0.0, 1.0]``. A lower temperature works well
531+ for prompts that expect a more deterministic and less open-ended
532+ or creative response, while a higher temperature can lead to more
533+ diverse or creative results. A temperature of ``0`` is
534+ deterministic, meaning that the highest probability response is
535+ always selected.
536+ max_output_tokens (int, optional):
537+ An INT64 value that sets the maximum number of tokens in the
538+ generated text.
539+ top_k (int, optional):
540+ An INT64 value that changes how the model selects tokens for
541+ output. A ``top_k`` of ``1`` means the next selected token is the
542+ most probable among all tokens in the model's vocabulary. A
543+ ``top_k`` of ``3`` means that the next token is selected from
544+ among the three most probable tokens by using temperature. The
545+ default value is ``40``.
546+ top_p (float, optional):
547+ A FLOAT64 value that changes how the model selects tokens for
548+ output. Tokens are selected from most probable to least probable
549+ until the sum of their probabilities equals the ``top_p`` value.
550+ For example, if tokens A, B, and C have a probability of 0.3, 0.2,
551+ and 0.1 and the ``top_p`` value is ``0.5``, then the model will
552+ select either A or B as the next token by using temperature. The
553+ default value is ``0.95``.
554+ stop_sequences (List[str], optional):
555+ An ARRAY<STRING> value that contains the stop sequences for the model.
556+ ground_with_google_search (bool, optional):
557+ A BOOL value that determines whether to ground the model with Google Search.
558+ request_type (str, optional):
559+ A STRING value that contains the request type for the model.
560+
561+ Returns:
562+ bigframes.pandas.DataFrame:
563+ The generated text.
564+ """
565+ model_name , session = ml_utils .get_model_name_and_session (model , data )
566+ table_sql = ml_utils .to_sql (data )
567+
568+ struct_fields : Dict [
569+ str ,
570+ Union [str , int , float , bool , Mapping [str , str ], List [str ], Mapping [str , Any ]],
571+ ] = {}
572+ if temperature is not None :
573+ struct_fields ["TEMPERATURE" ] = temperature
574+ if max_output_tokens is not None :
575+ struct_fields ["MAX_OUTPUT_TOKENS" ] = max_output_tokens
576+ if top_k is not None :
577+ struct_fields ["TOP_K" ] = top_k
578+ if top_p is not None :
579+ struct_fields ["TOP_P" ] = top_p
580+ if stop_sequences is not None :
581+ struct_fields ["STEP_SEQUENCES" ] = stop_sequences
582+ if ground_with_google_search is not None :
583+ struct_fields ["GROUND_WITH_GOOGLE_SEARCH" ] = ground_with_google_search
584+ if request_type is not None :
585+ struct_fields ["REQUEST_TYPE" ] = request_type
586+
587+ query = f"""
588+ SELECT *
589+ FROM AI.GENERATE_TEXT(
590+ MODEL `{ model_name } `,
591+ ({ table_sql } ),
592+ { bigframes .core .sql .literals .struct_literal (struct_fields )}
493593 )
494594 """
495595
496- return data_df ._session .read_gbq (query )
596+ if session is None :
597+ return bpd .read_gbq_query (query )
598+ else :
599+ return session .read_gbq_query (query )
497600
498601
499602@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
0 commit comments