Skip to content

Commit bc8b76b

Browse files
refactor(bigframes): reduce code size of AI functions and tests (#17055)
The changes include:
* Make all optional parameters default to None, which reduces the length of the compiled SQL.
* "Upper" the parameters at the API layer if necessary, so that this logic does not need to be repeated in both compilers later.
* Make all optional parameters default to None in the operator definitions too, saving some lines in the unit tests.
--------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 7f0b54d commit bc8b76b

18 files changed

Lines changed: 74 additions & 161 deletions

File tree

  • packages/bigframes

packages/bigframes/bigframes/bigquery/_operations/ai.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def generate(
4848
*,
4949
connection_id: str | None = None,
5050
endpoint: str | None = None,
51-
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
51+
request_type: Literal["dedicated", "shared", "unspecified"] | None = None,
5252
model_params: Mapping[Any, Any] | None = None,
5353
output_schema: Mapping[str, str] | None = None,
5454
) -> series.Series:
@@ -129,7 +129,7 @@ def generate(
129129
prompt_context=tuple(prompt_context),
130130
connection_id=connection_id,
131131
endpoint=endpoint,
132-
request_type=request_type,
132+
request_type=_upper_optional(request_type),
133133
model_params=json.dumps(model_params) if model_params else None,
134134
output_schema=output_schema_str,
135135
)
@@ -143,7 +143,7 @@ def generate_bool(
143143
*,
144144
connection_id: str | None = None,
145145
endpoint: str | None = None,
146-
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
146+
request_type: Literal["dedicated", "shared", "unspecified"] | None = None,
147147
model_params: Mapping[Any, Any] | None = None,
148148
) -> series.Series:
149149
"""
@@ -207,7 +207,7 @@ def generate_bool(
207207
prompt_context=tuple(prompt_context),
208208
connection_id=connection_id,
209209
endpoint=endpoint,
210-
request_type=request_type,
210+
request_type=_upper_optional(request_type),
211211
model_params=json.dumps(model_params) if model_params else None,
212212
)
213213

@@ -220,7 +220,7 @@ def generate_int(
220220
*,
221221
connection_id: str | None = None,
222222
endpoint: str | None = None,
223-
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
223+
request_type: Literal["dedicated", "shared", "unspecified"] | None = None,
224224
model_params: Mapping[Any, Any] | None = None,
225225
) -> series.Series:
226226
"""
@@ -281,7 +281,7 @@ def generate_int(
281281
prompt_context=tuple(prompt_context),
282282
connection_id=connection_id,
283283
endpoint=endpoint,
284-
request_type=request_type,
284+
request_type=_upper_optional(request_type),
285285
model_params=json.dumps(model_params) if model_params else None,
286286
)
287287

@@ -294,7 +294,7 @@ def generate_double(
294294
*,
295295
connection_id: str | None = None,
296296
endpoint: str | None = None,
297-
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
297+
request_type: Literal["dedicated", "shared", "unspecified"] | None = None,
298298
model_params: Mapping[Any, Any] | None = None,
299299
) -> series.Series:
300300
"""
@@ -355,7 +355,7 @@ def generate_double(
355355
prompt_context=tuple(prompt_context),
356356
connection_id=connection_id,
357357
endpoint=endpoint,
358-
request_type=request_type,
358+
request_type=_upper_optional(request_type),
359359
model_params=json.dumps(model_params) if model_params else None,
360360
)
361361

@@ -753,7 +753,7 @@ def embed(
753753
operator = ai_ops.AIEmbed(
754754
endpoint=endpoint,
755755
model=model,
756-
task_type=task_type,
756+
task_type=_upper_optional(task_type),
757757
title=title,
758758
model_params=json.dumps(model_params) if model_params else None,
759759
connection_id=connection_id,
@@ -775,7 +775,7 @@ def if_(
775775
*,
776776
connection_id: str | None = None,
777777
endpoint: str | None = None,
778-
optimization_mode: Literal["minimize_cost", "maximize_quality"] = "minimize_cost",
778+
optimization_mode: Literal["minimize_cost", "maximize_quality"] | None = None,
779779
max_error_ratio: float | None = None,
780780
) -> series.Series:
781781
"""
@@ -830,7 +830,7 @@ def if_(
830830
prompt_context=tuple(prompt_context),
831831
connection_id=connection_id,
832832
endpoint=endpoint,
833-
optimization_mode=optimization_mode,
833+
optimization_mode=_upper_optional(optimization_mode),
834834
max_error_ratio=max_error_ratio,
835835
)
836836

@@ -904,7 +904,7 @@ def classify(
904904
examples=example_tuples,
905905
connection_id=connection_id,
906906
endpoint=endpoint,
907-
optimization_mode=optimization_mode,
907+
optimization_mode=_upper_optional(optimization_mode),
908908
max_error_ratio=max_error_ratio,
909909
)
910910

@@ -1225,3 +1225,9 @@ def _to_dataframe(
12251225
return data
12261226

12271227
raise ValueError(f"Unsupported data type: {type(data)}")
1228+
1229+
1230+
def _upper_optional(value: str | None) -> str | None:
1231+
if value is None:
1232+
return None
1233+
return value.upper()

packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,7 +1920,7 @@ def ai_generate(
19201920
_construct_prompt(values, op.prompt_context), # type: ignore
19211921
op.connection_id, # type: ignore
19221922
op.endpoint, # type: ignore
1923-
op.request_type.upper(), # type: ignore
1923+
op.request_type, # type: ignore
19241924
op.model_params, # type: ignore
19251925
op.output_schema, # type: ignore
19261926
).to_expr()
@@ -1934,7 +1934,7 @@ def ai_generate_bool(
19341934
_construct_prompt(values, op.prompt_context), # type: ignore
19351935
op.connection_id, # type: ignore
19361936
op.endpoint, # type: ignore
1937-
op.request_type.upper(), # type: ignore
1937+
op.request_type, # type: ignore
19381938
op.model_params, # type: ignore
19391939
).to_expr()
19401940

@@ -1947,7 +1947,7 @@ def ai_generate_int(
19471947
_construct_prompt(values, op.prompt_context), # type: ignore
19481948
op.connection_id, # type: ignore
19491949
op.endpoint, # type: ignore
1950-
op.request_type.upper(), # type: ignore
1950+
op.request_type, # type: ignore
19511951
op.model_params, # type: ignore
19521952
).to_expr()
19531953

@@ -1960,7 +1960,7 @@ def ai_generate_double(
19601960
_construct_prompt(values, op.prompt_context), # type: ignore
19611961
op.connection_id, # type: ignore
19621962
op.endpoint, # type: ignore
1963-
op.request_type.upper(), # type: ignore
1963+
op.request_type, # type: ignore
19641964
op.model_params, # type: ignore
19651965
).to_expr()
19661966

@@ -1972,7 +1972,7 @@ def ai_embed(value: ibis_types.Value, op: ops.AIEmbed) -> ibis_types.StructValue
19721972
connection_id=op.connection_id, # type: ignore
19731973
endpoint=op.endpoint, # type: ignore
19741974
model=op.model, # type: ignore
1975-
task_type=op.task_type.upper() if op.task_type is not None else None, # type: ignore
1975+
task_type=op.task_type, # type: ignore
19761976
title=op.title, # type: ignore
19771977
model_params=op.model_params, # type: ignore
19781978
).to_expr()
@@ -1984,7 +1984,7 @@ def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue:
19841984
_construct_prompt(values, op.prompt_context), # type: ignore
19851985
op.connection_id, # type: ignore
19861986
op.endpoint, # type: ignore
1987-
op.optimization_mode.upper() if op.optimization_mode is not None else None, # type: ignore
1987+
op.optimization_mode, # type: ignore
19881988
op.max_error_ratio, # type: ignore
19891989
).to_expr()
19901990

@@ -1999,7 +1999,7 @@ def ai_classify(
19991999
_construct_examples(op.examples), # type: ignore
20002000
op.connection_id, # type: ignore
20012001
op.endpoint, # type: ignore
2002-
op.optimization_mode.upper() if op.optimization_mode is not None else None, # type: ignore
2002+
op.optimization_mode, # type: ignore
20032003
op.max_error_ratio, # type: ignore
20042004
).to_expr()
20052005

packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,6 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
139139
expression=sge.JSON(this=sge.Literal.string(value)),
140140
)
141141
)
142-
elif field == "optimization_mode":
143-
args.append(
144-
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))
145-
)
146-
elif field == "max_error_ratio":
147-
args.append(sge.Kwarg(this=field, expression=sge.Literal.number(value)))
148-
elif field == "request_type":
149-
args.append(
150-
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))
151-
)
152142
elif field == "examples":
153143
example_expressions = [
154144
sge.Tuple(
@@ -160,8 +150,6 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
160150
sge.Kwarg(this=field, expression=sge.array(*example_expressions))
161151
)
162152
else:
163-
args.append(
164-
sge.Kwarg(this=field, expression=sge.Literal.string(str(value)))
165-
)
153+
args.append(sge.Kwarg(this=field, expression=sge.convert(value)))
166154

167155
return args

packages/bigframes/bigframes/operations/ai_ops.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import dataclasses
18-
from typing import ClassVar, Literal, Tuple
18+
from typing import ClassVar, Tuple
1919

2020
import pandas as pd
2121
import pyarrow as pa
@@ -29,11 +29,11 @@ class AIGenerate(base_ops.NaryOp):
2929
name: ClassVar[str] = "ai_generate"
3030

3131
prompt_context: Tuple[str | None, ...]
32-
connection_id: str | None
33-
endpoint: str | None
34-
request_type: Literal["dedicated", "shared", "unspecified"]
35-
model_params: str | None
36-
output_schema: str | None
32+
connection_id: str | None = None
33+
endpoint: str | None = None
34+
request_type: str | None = None
35+
model_params: str | None = None
36+
output_schema: str | None = None
3737

3838
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
3939
if self.output_schema is None:
@@ -57,10 +57,10 @@ class AIGenerateBool(base_ops.NaryOp):
5757
name: ClassVar[str] = "ai_generate_bool"
5858

5959
prompt_context: Tuple[str | None, ...]
60-
connection_id: str | None
61-
endpoint: str | None
62-
request_type: Literal["dedicated", "shared", "unspecified"]
63-
model_params: str | None
60+
connection_id: str | None = None
61+
endpoint: str | None = None
62+
request_type: str | None = None
63+
model_params: str | None = None
6464

6565
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
6666
return pd.ArrowDtype(
@@ -79,10 +79,10 @@ class AIGenerateInt(base_ops.NaryOp):
7979
name: ClassVar[str] = "ai_generate_int"
8080

8181
prompt_context: Tuple[str | None, ...]
82-
connection_id: str | None
83-
endpoint: str | None
84-
request_type: Literal["dedicated", "shared", "unspecified"]
85-
model_params: str | None
82+
connection_id: str | None = None
83+
endpoint: str | None = None
84+
request_type: str | None = None
85+
model_params: str | None = None
8686

8787
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
8888
return pd.ArrowDtype(
@@ -101,10 +101,10 @@ class AIGenerateDouble(base_ops.NaryOp):
101101
name: ClassVar[str] = "ai_generate_double"
102102

103103
prompt_context: Tuple[str | None, ...]
104-
connection_id: str | None
105-
endpoint: str | None
106-
request_type: Literal["dedicated", "shared", "unspecified"]
107-
model_params: str | None
104+
connection_id: str | None = None
105+
endpoint: str | None = None
106+
request_type: str | None = None
107+
model_params: str | None = None
108108

109109
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
110110
return pd.ArrowDtype(
@@ -122,12 +122,12 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
122122
class AIEmbed(base_ops.UnaryOp):
123123
name: ClassVar[str] = "ai_embed"
124124

125-
endpoint: str | None
126-
model: str | None
127-
task_type: str | None
128-
title: str | None
129-
model_params: str | None
130-
connection_id: str | None
125+
endpoint: str | None = None
126+
model: str | None = None
127+
task_type: str | None = None
128+
title: str | None = None
129+
model_params: str | None = None
130+
connection_id: str | None = None
131131

132132
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
133133
return pd.ArrowDtype(
@@ -145,7 +145,7 @@ class AIIf(base_ops.NaryOp):
145145
name: ClassVar[str] = "ai_if"
146146

147147
prompt_context: Tuple[str | None, ...]
148-
connection_id: str | None
148+
connection_id: str | None = None
149149
endpoint: str | None = None
150150
optimization_mode: str | None = None
151151
max_error_ratio: float | None = None
@@ -160,11 +160,11 @@ class AIClassify(base_ops.NaryOp):
160160

161161
prompt_context: Tuple[str | None, ...]
162162
categories: tuple[str, ...]
163-
examples: tuple[tuple[str, str], ...] | None
164-
connection_id: str | None
165-
endpoint: str | None
166-
optimization_mode: str | None
167-
max_error_ratio: float | None
163+
examples: tuple[tuple[str, str], ...] | None = None
164+
connection_id: str | None = None
165+
endpoint: str | None = None
166+
optimization_mode: str | None = None
167+
max_error_ratio: float | None = None
168168

169169
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
170170
return dtypes.STRING_DTYPE
@@ -175,9 +175,9 @@ class AIScore(base_ops.NaryOp):
175175
name: ClassVar[str] = "ai_score"
176176

177177
prompt_context: Tuple[str | None, ...]
178-
connection_id: str | None
179-
endpoint: str | None
180-
max_error_ratio: float | None
178+
connection_id: str | None = None
179+
endpoint: str | None = None
180+
max_error_ratio: float | None = None
181181

182182
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
183183
return dtypes.FLOAT_DTYPE
@@ -187,10 +187,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
187187
class AISimilarity(base_ops.BinaryOp):
188188
name: ClassVar[str] = "ai_similarity"
189189

190-
endpoint: str | None
191-
model: str | None
192-
model_params: str | None
193-
connection_id: str | None
190+
endpoint: str | None = None
191+
model: str | None = None
192+
model_params: str | None = None
193+
connection_id: str | None = None
194194

195195
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
196196
return dtypes.FLOAT_DTYPE

packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_task_type_and_title/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ SELECT
22
AI.EMBED(
33
`string_col`,
44
endpoint => 'text-embedding-005',
5-
task_type => 'retrieval_document',
5+
task_type => 'RETRIEVAL_DOCUMENT',
66
title => 'My Document',
77
model_params => JSON '{"outputDimensionality": 256}'
88
) AS `result`
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
SELECT
22
AI.GENERATE_BOOL(
33
prompt => (`string_col`, ' is the same as ', `string_col`),
4-
endpoint => 'gemini-2.5-flash',
5-
request_type => 'SHARED'
4+
endpoint => 'gemini-2.5-flash'
65
) AS `result`
76
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_connection_id/out.sql

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ SELECT
22
AI.GENERATE_BOOL(
33
prompt => (`string_col`, ' is the same as ', `string_col`),
44
connection_id => 'bigframes-dev.us.bigframes-default-connection',
5-
endpoint => 'gemini-2.5-flash',
6-
request_type => 'SHARED'
5+
endpoint => 'gemini-2.5-flash'
76
) AS `result`
87
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
SELECT
22
AI.GENERATE_BOOL(
33
prompt => (`string_col`, ' is the same as ', `string_col`),
4-
request_type => 'SHARED',
54
model_params => JSON '{}'
65
) AS `result`
76
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
SELECT
22
AI.GENERATE_DOUBLE(
33
prompt => (`string_col`, ' is the same as ', `string_col`),
4-
endpoint => 'gemini-2.5-flash',
5-
request_type => 'SHARED'
4+
endpoint => 'gemini-2.5-flash'
65
) AS `result`
76
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

0 commit comments

Comments (0)