@@ -606,7 +606,7 @@ def generate_table(
606606 model : Union [bigframes .ml .base .BaseEstimator , str , pd .Series ],
607607 data : Union [dataframe .DataFrame , series .Series , pd .DataFrame , pd .Series ],
608608 * ,
609- output_schema : str ,
609+ output_schema : Union [ str , Mapping [ str , str ]] ,
610610 temperature : Optional [float ] = None ,
611611 top_p : Optional [float ] = None ,
612612 max_output_tokens : Optional [int ] = None ,
@@ -642,8 +642,10 @@ def generate_table(
642642 treated as the 'prompt' column. If a DataFrame is provided, it
643643 must contain a 'prompt' column, or you must rename the column you
644644 wish to generate table to 'prompt'.
645- output_schema (str):
646- A string defining the output schema (e.g., "col1 STRING, col2 INT64").
645+ output_schema (str | Mapping[str, str]):
646+ A string defining the output schema (e.g., "col1 STRING, col2 INT64"),
647+ or a mapping of the form {field_name: data_type} that specifies the output schema (e.g., {"col1": "STRING", "col2": "INT64"}).
648+ Supported data types include `STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`.
647649 temperature (float, optional):
648650 A FLOAT64 value that is used for sampling promiscuity. The value
649651 must be in the range ``[0.0, 1.0]``.
@@ -666,8 +668,17 @@ def generate_table(
666668 model_name , session = bq_utils .get_model_name_and_session (model , data )
667669 table_sql = bq_utils .to_sql (data )
668670
671+ if isinstance (output_schema , Mapping ):
672+ output_schema_str = ", " .join (
673+ [f"{ name } { sql_type } " for name , sql_type in output_schema .items ()]
674+ )
675+ # Validate user input
676+ output_schemas .parse_sql_fields (output_schema_str )
677+ else :
678+ output_schema_str = output_schema
679+
669680 struct_fields_bq : Dict [str , bigframes .core .sql .literals .STRUCT_VALUES ] = {
670- "output_schema" : output_schema
681+ "output_schema" : output_schema_str
671682 }
672683 if temperature is not None :
673684 struct_fields_bq ["temperature" ] = temperature
0 commit comments