Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit a0a1d83

Browse files
committed
Merge remote-tracking branch 'origin/main' into feat-inplace-param-for-drop
2 parents d271d0c + 60056ca commit a0a1d83

File tree

11 files changed

+221
-23
lines changed

11 files changed

+221
-23
lines changed

bigframes/core/compile/sqlglot/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
from bigframes.core.compile.sqlglot.compiler import SQLGlotCompiler
17+
import bigframes.core.compile.sqlglot.expressions.ai_ops # noqa: F401
1718
import bigframes.core.compile.sqlglot.expressions.array_ops # noqa: F401
1819
import bigframes.core.compile.sqlglot.expressions.blob_ops # noqa: F401
1920
import bigframes.core.compile.sqlglot.expressions.comparison_ops # noqa: F401
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import sqlglot.expressions as sge
18+
19+
from bigframes import operations as ops
20+
from bigframes.core.compile.sqlglot import scalar_compiler
21+
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
22+
23+
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
24+
25+
26+
@register_nary_op(ops.AIGenerateBool, pass_op=True)
27+
def _(*exprs: TypedExpr, op: ops.AIGenerateBool) -> sge.Expression:
28+
29+
prompt: list[str | sge.Expression] = []
30+
column_ref_idx = 0
31+
32+
for elem in op.prompt_context:
33+
if elem is None:
34+
prompt.append(exprs[column_ref_idx].expr)
35+
else:
36+
prompt.append(sge.Literal.string(elem))
37+
38+
args = [sge.Kwarg(this="prompt", expression=sge.Tuple(expressions=prompt))]
39+
40+
args.append(
41+
sge.Kwarg(this="connection_id", expression=sge.Literal.string(op.connection_id))
42+
)
43+
44+
if op.endpoint is not None:
45+
args.append(
46+
sge.Kwarg(this="endpoint", expression=sge.Literal.string(op.endpoint))
47+
)
48+
49+
args.append(
50+
sge.Kwarg(
51+
this="request_type", expression=sge.Literal.string(op.request_type.upper())
52+
)
53+
)
54+
55+
if op.model_params is not None:
56+
args.append(
57+
sge.Kwarg(
58+
this="model_params",
59+
# sge.JSON requires a newer SQLGlot version than 23.6.3.
60+
# PARSE_JSON won't work as the function requires a JSON literal.
61+
expression=sge.JSON(this=sge.Literal.string(op.model_params)),
62+
)
63+
)
64+
65+
return sge.func("AI.GENERATE_BOOL", *args)

bigframes/core/indexes/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,9 @@ def to_numpy(self, dtype=None, *, allow_large_results=None, **kwargs) -> np.ndar
740740

741741
__array__ = to_numpy
742742

743+
def to_list(self, *, allow_large_results: Optional[bool] = None) -> list:
744+
return self.to_pandas(allow_large_results=allow_large_results).to_list()
745+
743746
def __len__(self):
744747
return self.shape[0]
745748

tests/system/small/operations/test_strings.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,20 @@ def test_reverse(scalars_dfs):
236236

237237

238238
@pytest.mark.parametrize(
239-
["start", "stop"], [(0, 1), (3, 5), (100, 101), (None, 1), (0, 12), (0, None)]
239+
["start", "stop"],
240+
[
241+
(0, 1),
242+
(3, 5),
243+
(100, 101),
244+
(None, 1),
245+
(0, 12),
246+
(0, None),
247+
(None, -1),
248+
(-1, None),
249+
(-5, -1),
250+
(1, -1),
251+
(-10, 10),
252+
],
240253
)
241254
def test_slice(scalars_dfs, start, stop):
242255
scalars_df, scalars_pandas_df = scalars_dfs

tests/system/small/test_index.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,12 @@ def test_index_item_with_empty(session):
638638
bf_idx_empty.item()
639639

640640

641+
def test_index_to_list(scalars_df_index, scalars_pandas_df_index):
642+
bf_result = scalars_df_index.index.to_list()
643+
pd_result = scalars_pandas_df_index.index.to_list()
644+
assert bf_result == pd_result
645+
646+
641647
@pytest.mark.parametrize(
642648
("key", "value"),
643649
[
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.GENERATE_BOOL(
9+
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
10+
connection_id => 'test_connection_id',
11+
endpoint => 'gemini-2.5-flash',
12+
request_type => 'SHARED'
13+
) AS `bfcol_1`
14+
FROM `bfcte_0`
15+
)
16+
SELECT
17+
`bfcol_1` AS `result`
18+
FROM `bfcte_1`
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.GENERATE_BOOL(
9+
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
10+
connection_id => 'test_connection_id',
11+
request_type => 'SHARED',
12+
model_params => JSON '{}'
13+
) AS `bfcol_1`
14+
FROM `bfcte_0`
15+
)
16+
SELECT
17+
`bfcol_1` AS `result`
18+
FROM `bfcte_1`
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import sys
17+
18+
import pytest
19+
20+
from bigframes import dataframe
21+
from bigframes import operations as ops
22+
from bigframes.testing import utils
23+
24+
pytest.importorskip("pytest_snapshot")
25+
26+
27+
def test_ai_generate_bool(scalar_types_df: dataframe.DataFrame, snapshot):
28+
col_name = "string_col"
29+
30+
op = ops.AIGenerateBool(
31+
prompt_context=(None, " is the same as ", None),
32+
connection_id="test_connection_id",
33+
endpoint="gemini-2.5-flash",
34+
request_type="shared",
35+
model_params=None,
36+
)
37+
38+
sql = utils._apply_unary_ops(
39+
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
40+
)
41+
42+
snapshot.assert_match(sql, "out.sql")
43+
44+
45+
def test_ai_generate_bool_with_model_param(
46+
scalar_types_df: dataframe.DataFrame, snapshot
47+
):
48+
if sys.version_info < (3, 10):
49+
pytest.skip(
50+
"Skip test because SQLGLot cannot compile model params to JSON at this env."
51+
)
52+
53+
col_name = "string_col"
54+
55+
op = ops.AIGenerateBool(
56+
prompt_context=(None, " is the same as ", None),
57+
connection_id="test_connection_id",
58+
endpoint=None,
59+
request_type="shared",
60+
model_params=json.dumps(dict()),
61+
)
62+
63+
sql = utils._apply_unary_ops(
64+
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
65+
)
66+
67+
snapshot.assert_match(sql, "out.sql")

tests/unit/test_index.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import pandas as pd
1516
import pytest
1617

1718
from bigframes.testing import mocks
@@ -38,3 +39,13 @@ def test_index_rename_inplace_returns_none(monkeypatch: pytest.MonkeyPatch):
3839
# Make sure the linked DataFrame is updated, too.
3940
assert dataframe.index.name == "my_index_name"
4041
assert index.name == "my_index_name"
42+
43+
44+
def test_index_to_list(monkeypatch: pytest.MonkeyPatch):
45+
pd_index = pd.Index([1, 2, 3], name="my_index")
46+
df = mocks.create_dataframe(
47+
monkeypatch,
48+
data={"my_index": [1, 2, 3]},
49+
).set_index("my_index")
50+
bf_index = df.index
51+
assert bf_index.to_list() == pd_index.to_list()

third_party/bigframes_vendored/ibis/expr/rewrites.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -206,21 +206,26 @@ def replace_parameter(_, params, **kwargs):
206206
@replace(p.StringSlice)
207207
def lower_stringslice(_, **kwargs):
208208
"""Rewrite StringSlice in terms of Substring."""
209-
if _.end is None:
210-
return ops.Substring(_.arg, start=_.start)
211209
if _.start is None:
212-
return ops.Substring(_.arg, start=0, length=_.end)
213-
if (
214-
isinstance(_.start, ops.Literal)
215-
and isinstance(_.start.value, int)
216-
and isinstance(_.end, ops.Literal)
217-
and isinstance(_.end.value, int)
218-
):
219-
# optimization for constant values
220-
length = _.end.value - _.start.value
210+
real_start = 0
221211
else:
222-
length = ops.Subtract(_.end, _.start)
223-
return ops.Substring(_.arg, start=_.start, length=length)
212+
real_start = ops.IfElse(
213+
ops.GreaterEqual(_.start, 0),
214+
_.start,
215+
ops.Greatest((0, ops.Add(ops.StringLength(_.arg), _.start))),
216+
)
217+
218+
if _.end is None:
219+
real_end = ops.StringLength(_.arg)
220+
else:
221+
real_end = ops.IfElse(
222+
ops.GreaterEqual(_.end, 0),
223+
_.end,
224+
ops.Greatest((0, ops.Add(ops.StringLength(_.arg), _.end))),
225+
)
226+
227+
length = ops.Greatest((0, ops.Subtract(real_end, real_start)))
228+
return ops.Substring(_.arg, start=real_start, length=length)
224229

225230

226231
@replace(p.Analytic)

0 commit comments

Comments
 (0)