Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 442654d

Browse files
committed
fix stuffs
1 parent d250e14 commit 442654d

File tree

4 files changed

+19
-19
lines changed

4 files changed

+19
-19
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
def ai_generate_bool(
26-
prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series],
26+
prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...],
2727
*,
2828
connection_id: str | None = None,
2929
endpoint: str | None = None,
@@ -71,7 +71,7 @@ def ai_generate_bool(
7171
Name: result, dtype: boolean
7272
7373
Args:
74-
prompt (series.Series | List[str|series.Series] | Tuple[str|series.Series]):
74+
prompt (series.Series | List[str|series.Series] | Tuple[str|series.Series, ...]):
7575
A mixture of Series and string literals that specifies the prompt to send to the model.
7676
connection_id (str, optional):
7777
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
@@ -87,15 +87,15 @@ def ai_generate_bool(
8787
* "dedicated": function only uses Provisioned Throughput quota. The AI.GENERATE function returns the error Provisioned throughput is not purchased or is not active if Provisioned Throughput quota isn't available.
8888
* "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota.
8989
* "unspecified":
90-
* If you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
91-
* If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first. If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
90+
If you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
91+
If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first. If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
9292
model_params (Mapping[Any, Any]):
9393
Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.
9494
9595
Returns:
9696
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
9797
* "result": a BOOL value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
98-
* "full_resposne": a STRING value containing the JSON response from the projects.locations.endpoints.generateContent call to the model. The generated text is in the text element.
98+
* "full_response": a STRING value containing the JSON response from the projects.locations.endpoints.generateContent call to the model. The generated text is in the text element.
9999
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
100100
"""
101101

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1968,26 +1968,24 @@ def struct_op_impl(
19681968
@scalar_op_compiler.register_nary_op(ops.AIGenerateBool, pass_op=True)
19691969
def ai_generate_bool(
19701970
*values: ibis_types.Value, op: ops.AIGenerateBool
1971-
) -> ibis_dtypes.StructValue:
1971+
) -> ibis_types.StructValue:
19721972

1973-
prompt = {}
1973+
prompt: dict[str, ibis_types.Value | str] = {}
19741974
column_ref_idx = 0
19751975

19761976
for idx, elem in enumerate(op.prompt_context):
19771977
if elem is None:
1978-
value = values[column_ref_idx]
1978+
prompt[f"_field_{idx + 1}"] = values[column_ref_idx]
19791979
column_ref_idx += 1
19801980
else:
1981-
value = elem
1982-
1983-
prompt[f"_field_{idx + 1}"] = value
1981+
prompt[f"_field_{idx + 1}"] = elem
19841982

19851983
return ai_ops.AIGenerateBool(
1986-
ibis.struct(prompt),
1987-
op.connection_id,
1988-
op.endpoint,
1989-
op.request_type.upper(),
1990-
op.model_params,
1984+
ibis.struct(prompt), # type: ignore
1985+
op.connection_id,# type: ignore
1986+
op.endpoint,# type: ignore
1987+
op.request_type.upper(),# type: ignore
1988+
op.model_params,# type: ignore
19911989
).to_expr()
19921990

19931991

bigframes/operations/ai_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class AIGenerateBool(base_ops.NaryOp):
2929
name: ClassVar[str] = "ai_generate_bool"
3030

3131
# None are the placeholders for column references.
32-
prompt_context: Tuple[str | None]
32+
prompt_context: Tuple[str | None, ...]
3333
connection_id: str
3434
endpoint: str | None
3535
request_type: Literal["dedicated", "shared", "unspecified"]

third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from __future__ import annotations
66

7+
from typing import Optional
8+
79
from bigframes_vendored.ibis.common.annotations import attribute
810
import bigframes_vendored.ibis.expr.datatypes as dt
911
from bigframes_vendored.ibis.expr.operations.core import Value
@@ -17,9 +19,9 @@ class AIGenerateBool(Value):
1719

1820
prompt: Value
1921
connection_id: Value[dt.String]
20-
endpoint: Value[dt.String] | None
22+
endpoint: Optional[Value[dt.String]]
2123
request_type: Value[dt.String]
22-
model_params: Value[dt.String] | None
24+
model_params: Optional[Value[dt.String]]
2325

2426
shape = rlz.shape_like("prompt")
2527

0 commit comments

Comments
 (0)