Skip to content

Commit e9c52b1

Browse files
authored
feat(bigframes): add more params to ai.classify (#16990)
1 parent cef659d commit e9c52b1

8 files changed

Lines changed: 109 additions & 17 deletions

File tree

packages/bigframes/bigframes/bigquery/_operations/ai.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import pandas as pd
2525

26-
from bigframes import clients, dataframe, dtypes, series, session
26+
from bigframes import dataframe, dtypes, series, session
2727
from bigframes import pandas as bpd
2828
from bigframes.bigquery._operations import utils as bq_utils
2929
from bigframes.core import convert
@@ -885,7 +885,11 @@ def classify(
885885
input: PROMPT_TYPE,
886886
categories: tuple[str, ...] | list[str],
887887
*,
888+
examples: list[tuple[str, str]] | None = None,
888889
connection_id: str | None = None,
890+
endpoint: str | None = None,
891+
optimization_mode: Literal["minimize_cost", "maximize_quality"] | None = None,
892+
max_error_ratio: float | None = None,
889893
) -> series.Series:
890894
"""
891895
Classifies a given input into one of the specified categories. It will always return one of the provided categories best fit the prompt input.
@@ -903,22 +907,30 @@ def classify(
903907
<BLANKLINE>
904908
[2 rows x 2 columns]
905909
906-
.. note::
907-
908-
This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
909-
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
910-
and might have limited support. For more information, see the launch stage descriptions
911-
(https://cloud.google.com/products#product-launch-stages).
912-
913910
Args:
914911
input (str | Series | List[str|Series] | Tuple[str|Series, ...]):
915912
A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
916913
or pandas Series.
917914
categories (tuple[str, ...] | list[str]):
918915
Categories to classify the input into.
916+
examples (list[tuple[str, str]], optional):
917+
An array that contains representative examples of input strings and the output category
918+
that you expect. You can provide examples to help the model understand your
919+
intended threshold for a condition with nuanced or subjective logic. We recommend providing at most 5 examples.
919920
connection_id (str, optional):
920921
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
921922
If not provided, the query uses your end-user credential.
923+
endpoint (str, optional):
924+
A STRING value that specifies the Vertex AI endpoint to use for the model. You can specify any
925+
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically
926+
identifies and uses the full endpoint of the model.
927+
optimization_mode (Literal["minimize_cost", "maximize_quality"], optional):
928+
A STRING value that specifies the optimization strategy to use. Supported values are ``minimize_cost``
929+
and ``maximize_quality``.
930+
max_error_ratio (float, optional):
931+
A value between ``0.0`` and ``1.0`` that contains the maximum acceptable ratio of row-level
932+
inference failures to rows processed on this function. The default value is 1.0.
933+
This argument isn't supported when ``optimization_mode`` is set to ``minimize_cost``.
922934
923935
Returns:
924936
bigframes.series.Series: A new series of strings.
@@ -927,10 +939,16 @@ def classify(
927939
prompt_context, series_list = _separate_context_and_series(input)
928940
assert len(series_list) > 0
929941

942+
example_tuples = tuple(examples) if examples is not None else None
943+
930944
operator = ai_ops.AIClassify(
931945
prompt_context=tuple(prompt_context),
932946
categories=tuple(categories),
947+
examples=example_tuples,
933948
connection_id=connection_id,
949+
endpoint=endpoint,
950+
optimization_mode=optimization_mode,
951+
max_error_ratio=max_error_ratio,
934952
)
935953

936954
return series_list[0]._apply_nary_op(operator, series_list[1:])
@@ -1249,14 +1267,6 @@ def _convert_series(
12491267
return result
12501268

12511269

1252-
def _resolve_connection_id(series: series.Series, connection_id: str | None):
1253-
return clients.get_canonical_bq_connection_id(
1254-
connection_id or series._session.bq_connection,
1255-
series._session._project,
1256-
series._session._location,
1257-
)
1258-
1259-
12601270
def _to_dataframe(
12611271
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
12621272
series_rename: str,

packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1996,7 +1996,11 @@ def ai_classify(
19961996
return ai_ops.AIClassify(
19971997
_construct_prompt(values, op.prompt_context), # type: ignore
19981998
op.categories, # type: ignore
1999+
_construct_examples(op.examples), # type: ignore
19992000
op.connection_id, # type: ignore
2001+
op.endpoint, # type: ignore
2002+
op.optimization_mode.upper() if op.optimization_mode is not None else None, # type: ignore
2003+
op.max_error_ratio, # type: ignore
20002004
).to_expr()
20012005

20022006

@@ -2040,6 +2044,26 @@ def _construct_prompt(
20402044
return ibis.struct(prompt)
20412045

20422046

2047+
def _construct_examples(
2048+
examples: tuple[tuple[str, str]] | None,
2049+
) -> ibis_types.ArrayValue | None:
2050+
if examples is None:
2051+
return None
2052+
2053+
results: list[ibis_types.StructValue] = []
2054+
2055+
for example in examples:
2056+
ibis_example = ibis.struct(
2057+
{
2058+
"_field_1": example[0],
2059+
"_field_2": example[1],
2060+
}
2061+
)
2062+
results.append(ibis_example)
2063+
2064+
return ibis.array(results)
2065+
2066+
20432067
@scalar_op_compiler.register_nary_op(ops.RowKey, pass_op=True)
20442068
def rowkey_op_impl(*values: ibis_types.Value, op: ops.RowKey) -> ibis_types.Value:
20452069
return bigframes.core.compile.ibis_compiler.default_ordering.gen_row_key(values)

packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,16 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
149149
args.append(
150150
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))
151151
)
152+
elif field == "examples":
153+
example_expressions = [
154+
sge.Tuple(
155+
expressions=[sge.Literal.string(key), sge.Literal.string(val)]
156+
)
157+
for key, val in value
158+
]
159+
args.append(
160+
sge.Kwarg(this=field, expression=sge.array(*example_expressions))
161+
)
152162
else:
153163
args.append(
154164
sge.Kwarg(this=field, expression=sge.Literal.string(str(value)))

packages/bigframes/bigframes/operations/ai_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,11 @@ class AIClassify(base_ops.NaryOp):
160160

161161
prompt_context: Tuple[str | None, ...]
162162
categories: tuple[str, ...]
163+
examples: tuple[tuple[str, str], ...] | None
163164
connection_id: str | None
165+
endpoint: str | None
166+
optimization_mode: str | None
167+
max_error_ratio: float | None
164168

165169
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
166170
return dtypes.STRING_DTYPE

packages/bigframes/tests/system/small/bigquery/test_ai.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,15 @@ def test_ai_classify(session):
355355
assert result.dtype == dtypes.STRING_DTYPE
356356

357357

358+
def test_ai_classify_with_examples(session):
359+
s = bpd.Series(["cat", "orchid"], session=session)
360+
361+
result = bbq.ai.classify(s, ["animal", "plant"], examples=[("dog", "animal")])
362+
363+
assert len(result) == len(s)
364+
assert result.dtype == dtypes.STRING_DTYPE
365+
366+
358367
def test_ai_classify_multi_model(session, bq_connection):
359368
df = session.from_glob_path(
360369
"gs://bigframes-dev-testing/a_multimodel/images/*",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
SELECT
2+
AI.CLASSIFY(
3+
input => (`string_col`),
4+
categories => ['greeting', 'rejection'],
5+
examples => [('hi', 'greeting'), ('bye', 'rejection')],
6+
endpoint => 'gemini-2.5-flash',
7+
max_error_ratio => 0.1
8+
) AS `result`
9+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,29 @@ def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot, connection_
392392
op = ops.AIClassify(
393393
prompt_context=(None,),
394394
categories=("greeting", "rejection"),
395+
examples=None,
395396
connection_id=connection_id,
397+
endpoint=None,
398+
optimization_mode=None,
399+
max_error_ratio=None,
400+
)
401+
402+
sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"])
403+
404+
snapshot.assert_match(sql, "out.sql")
405+
406+
407+
def test_ai_classify_with_params(scalar_types_df: dataframe.DataFrame, snapshot):
408+
col_name = "string_col"
409+
410+
op = ops.AIClassify(
411+
prompt_context=(None,),
412+
categories=("greeting", "rejection"),
413+
examples=(("hi", "greeting"), ("bye", "rejection")),
414+
connection_id=None,
415+
endpoint="gemini-2.5-flash",
416+
optimization_mode=None,
417+
max_error_ratio=0.1,
396418
)
397419

398420
sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"])

packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,16 @@ class AIClassify(Value):
155155

156156
input: Value
157157
categories: Value[dt.Array[dt.String]]
158+
examples: Optional[Value]
158159
connection_id: Optional[Value[dt.String]]
160+
endpoint: Optional[Value[dt.String]]
161+
optimization_mode: Optional[Value[dt.String]]
162+
max_error_ratio: Optional[Value[dt.Float64]]
159163

160164
shape = rlz.shape_like("input")
161165

162166
@attribute
163-
def dtype(self) -> dt.Struct:
167+
def dtype(self) -> dt.DataType:
164168
return dt.string
165169

166170

0 commit comments

Comments
 (0)