Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 3d74176

Browse files
Merge remote-tracking branch 'github/main' into fix_solo_if_else
2 parents 63a83d7 + 6b8154c commit 3d74176

File tree

23 files changed

+509
-13
lines changed

23 files changed

+509
-13
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,81 @@ def generate_int(
188188
return series_list[0]._apply_nary_op(operator, series_list[1:])
189189

190190

191+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
192+
def generate_double(
193+
prompt: PROMPT_TYPE,
194+
*,
195+
connection_id: str | None = None,
196+
endpoint: str | None = None,
197+
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
198+
model_params: Mapping[Any, Any] | None = None,
199+
) -> series.Series:
200+
"""
201+
Returns the AI analysis of the prompt as a floating-point (DOUBLE) value. The prompt can be any combination of text and unstructured data.
202+
203+
**Examples:**
204+
205+
>>> import bigframes.pandas as bpd
206+
>>> import bigframes.bigquery as bbq
207+
>>> bpd.options.display.progress_bar = None
208+
>>> animal = bpd.Series(["Kangaroo", "Rabbit", "Spider"])
209+
>>> bbq.ai.generate_double(("How many legs does a ", animal, " have?"))
210+
0 {'result': 2.0, 'full_response': '{"candidates...
211+
1 {'result': 4.0, 'full_response': '{"candidates...
212+
2 {'result': 8.0, 'full_response': '{"candidates...
213+
dtype: struct<result: double, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
214+
215+
>>> bbq.ai.generate_double(("How many legs does a ", animal, " have?")).struct.field("result")
216+
0 2.0
217+
1 4.0
218+
2 8.0
219+
Name: result, dtype: Float64
220+
221+
Args:
222+
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
223+
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
224+
or pandas Series.
225+
connection_id (str, optional):
226+
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
227+
If not provided, the connection from the current session will be used.
228+
endpoint (str, optional):
229+
Specifies the Vertex AI endpoint to use for the model. For example, `"gemini-2.5-flash"`. You can specify any
230+
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
231+
uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML selects a recent stable
232+
version of Gemini to use.
233+
request_type (Literal["dedicated", "shared", "unspecified"]):
234+
Specifies the type of inference request to send to the Gemini model. The request type determines what quota the request uses.
235+
* "dedicated": the function only uses Provisioned Throughput quota. The function returns the error "Provisioned throughput is not
236+
purchased or is not active" if Provisioned Throughput quota isn't available.
237+
* "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota.
238+
* "unspecified": If you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
239+
If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first.
240+
If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
241+
model_params (Mapping[Any, Any], optional):
242+
Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.
243+
244+
Returns:
245+
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
246+
* "result": a DOUBLE value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
247+
* "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
248+
The generated text is in the text element.
249+
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
250+
"""
251+
252+
prompt_context, series_list = _separate_context_and_series(prompt)
253+
assert len(series_list) > 0
254+
255+
operator = ai_ops.AIGenerateDouble(
256+
prompt_context=tuple(prompt_context),
257+
connection_id=_resolve_connection_id(series_list[0], connection_id),
258+
endpoint=endpoint,
259+
request_type=request_type,
260+
model_params=json.dumps(model_params) if model_params else None,
261+
)
262+
263+
return series_list[0]._apply_nary_op(operator, series_list[1:])
264+
265+
191266
def _separate_context_and_series(
192267
prompt: PROMPT_TYPE,
193268
) -> Tuple[List[str | None], List[series.Series]]:

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1990,7 +1990,7 @@ def ai_generate_bool(
19901990

19911991
@scalar_op_compiler.register_nary_op(ops.AIGenerateInt, pass_op=True)
19921992
def ai_generate_int(
1993-
*values: ibis_types.Value, op: ops.AIGenerateBool
1993+
*values: ibis_types.Value, op: ops.AIGenerateInt
19941994
) -> ibis_types.StructValue:
19951995

19961996
return ai_ops.AIGenerateInt(
@@ -2002,6 +2002,20 @@ def ai_generate_int(
20022002
).to_expr()
20032003

20042004

2005+
@scalar_op_compiler.register_nary_op(ops.AIGenerateDouble, pass_op=True)
2006+
def ai_generate_double(
2007+
*values: ibis_types.Value, op: ops.AIGenerateDouble
2008+
) -> ibis_types.StructValue:
2009+
2010+
return ai_ops.AIGenerateDouble(
2011+
_construct_prompt(values, op.prompt_context), # type: ignore
2012+
op.connection_id, # type: ignore
2013+
op.endpoint, # type: ignore
2014+
op.request_type.upper(), # type: ignore
2015+
op.model_params, # type: ignore
2016+
).to_expr()
2017+
2018+
20052019
def _construct_prompt(
20062020
col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None]
20072021
) -> ibis_types.StructValue:

bigframes/core/compile/sqlglot/aggregate_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def compile_analytic(
6363
window: window_spec.WindowSpec,
6464
) -> sge.Expression:
6565
if isinstance(aggregate, agg_expressions.NullaryAggregation):
66-
return nullary_compiler.compile(aggregate.op)
66+
return nullary_compiler.compile(aggregate.op, window)
6767
if isinstance(aggregate, agg_expressions.UnaryAggregation):
6868
column = typed_expr.TypedExpr(
6969
scalar_compiler.scalar_op_compiler.compile_expression(aggregate.arg),

bigframes/core/compile/sqlglot/aggregations/binary_compiler.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from bigframes.core import window_spec
2222
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23+
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
2324
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
2425
from bigframes.operations import aggregations as agg_ops
2526

@@ -33,3 +34,25 @@ def compile(
3334
window: typing.Optional[window_spec.WindowSpec] = None,
3435
) -> sge.Expression:
3536
return BINARY_OP_REGISTRATION[op](op, left, right, window=window)
37+
38+
39+
@BINARY_OP_REGISTRATION.register(agg_ops.CorrOp)
40+
def _(
41+
op: agg_ops.CorrOp,
42+
left: typed_expr.TypedExpr,
43+
right: typed_expr.TypedExpr,
44+
window: typing.Optional[window_spec.WindowSpec] = None,
45+
) -> sge.Expression:
46+
result = sge.func("CORR", left.expr, right.expr)
47+
return apply_window_if_present(result, window)
48+
49+
50+
@BINARY_OP_REGISTRATION.register(agg_ops.CovOp)
51+
def _(
52+
op: agg_ops.CovOp,
53+
left: typed_expr.TypedExpr,
54+
right: typed_expr.TypedExpr,
55+
window: typing.Optional[window_spec.WindowSpec] = None,
56+
) -> sge.Expression:
57+
result = sge.func("COVAR_SAMP", left.expr, right.expr)
58+
return apply_window_if_present(result, window)

bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,15 @@ def _(
3939
window: typing.Optional[window_spec.WindowSpec] = None,
4040
) -> sge.Expression:
4141
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
42+
43+
44+
@NULLARY_OP_REGISTRATION.register(agg_ops.RowNumberOp)
45+
def _(
46+
op: agg_ops.RowNumberOp,
47+
window: typing.Optional[window_spec.WindowSpec] = None,
48+
) -> sge.Expression:
49+
result: sge.Expression = sge.func("ROW_NUMBER")
50+
if window is None:
51+
# ROW_NUMBER always needs an OVER clause.
52+
return sge.Window(this=result)
53+
return apply_window_if_present(result, window)

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,7 @@ def _(
8484
column: typed_expr.TypedExpr,
8585
window: typing.Optional[window_spec.WindowSpec] = None,
8686
) -> sge.Expression:
87-
# Ranking functions do not support window framing clauses.
88-
return apply_window_if_present(
89-
sge.func("DENSE_RANK"), window, include_framing_clauses=False
90-
)
87+
return apply_window_if_present(sge.func("DENSE_RANK"), window)
9188

9289

9390
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
@@ -165,10 +162,7 @@ def _(
165162
column: typed_expr.TypedExpr,
166163
window: typing.Optional[window_spec.WindowSpec] = None,
167164
) -> sge.Expression:
168-
# Ranking functions do not support window framing clauses.
169-
return apply_window_if_present(
170-
sge.func("RANK"), window, include_framing_clauses=False
171-
)
165+
return apply_window_if_present(sge.func("RANK"), window)
172166

173167

174168
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
def apply_window_if_present(
2626
value: sge.Expression,
2727
window: typing.Optional[window_spec.WindowSpec] = None,
28-
include_framing_clauses: bool = True,
2928
) -> sge.Expression:
3029
if window is None:
3130
return value
@@ -65,7 +64,7 @@ def apply_window_if_present(
6564
if not window.bounds and not order:
6665
return sge.Window(this=value, partition_by=group_by)
6766

68-
if not window.bounds and not include_framing_clauses:
67+
if not window.bounds:
6968
return sge.Window(this=value, partition_by=group_by, order=order)
7069

7170
kind = (

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateInt) -> sge.Expression:
4040
return sge.func("AI.GENERATE_INT", *args)
4141

4242

43+
@register_nary_op(ops.AIGenerateDouble, pass_op=True)
44+
def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression:
45+
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
46+
47+
return sge.func("AI.GENERATE_DOUBLE", *args)
48+
49+
4350
def _construct_prompt(
4451
exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
4552
) -> sge.Kwarg:

bigframes/operations/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateInt
17+
from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateDouble, AIGenerateInt
1818
from bigframes.operations.array_ops import (
1919
ArrayIndexOp,
2020
ArrayReduceOp,
@@ -413,6 +413,7 @@
413413
"GeoStDistanceOp",
414414
# AI ops
415415
"AIGenerateBool",
416+
"AIGenerateDouble",
416417
"AIGenerateInt",
417418
# Numpy ops mapping
418419
"NUMPY_TO_BINOP",

bigframes/operations/ai_ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,25 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
6666
)
6767
)
6868
)
69+
70+
71+
@dataclasses.dataclass(frozen=True)
72+
class AIGenerateDouble(base_ops.NaryOp):
73+
name: ClassVar[str] = "ai_generate_double"
74+
75+
prompt_context: Tuple[str | None, ...]
76+
connection_id: str
77+
endpoint: str | None
78+
request_type: Literal["dedicated", "shared", "unspecified"]
79+
model_params: str | None
80+
81+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
82+
return pd.ArrowDtype(
83+
pa.struct(
84+
(
85+
pa.field("result", pa.float64()),
86+
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
87+
pa.field("status", pa.string()),
88+
)
89+
)
90+
)

0 commit comments

Comments
 (0)