Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 38c4d6f

Browse files
committed
fix unit tests
1 parent 2505296 commit 38c4d6f

File tree

10 files changed

+179
-102
lines changed

10 files changed

+179
-102
lines changed

bigframes/bigquery/_operations/sql.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,31 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import Sequence
19+
from typing import cast, Optional, Sequence, Union
2020

2121
import google.cloud.bigquery
2222

2323
from bigframes.core.compile.sqlglot import sql
24+
import bigframes.dataframe
2425
import bigframes.dtypes
2526
import bigframes.operations
2627
import bigframes.series
2728

2829

30+
def _format_names(sql_template: str, dataframe: bigframes.dataframe.DataFrame):
31+
"""Turn sql_template from a template that uses names to one that uses
32+
numbers.
33+
"""
34+
names_to_numbers = {name: f"{{{i}}}" for i, name in enumerate(dataframe.columns)}
35+
numbers = [f"{{{i}}}" for i in range(len(dataframe.columns))]
36+
return sql_template.format(*numbers, **names_to_numbers)
37+
38+
2939
def sql_scalar(
3040
sql_template: str,
31-
columns: Sequence[bigframes.series.Series],
41+
columns: Union[bigframes.dataframe.DataFrame, Sequence[bigframes.series.Series]],
42+
*,
43+
output_dtype: Optional[bigframes.dtypes.Dtype] = None,
3244
) -> bigframes.series.Series:
3345
"""Create a Series from a SQL template.
3446
@@ -37,6 +49,9 @@ def sql_scalar(
3749
>>> import bigframes.pandas as bpd
3850
>>> import bigframes.bigquery as bbq
3951
52+
Either pass in a sequence of series, in which case use integers in the
53+
format strings.
54+
4055
>>> s = bpd.Series(["1.5", "2.5", "3.5"])
4156
>>> s = s.astype(pd.ArrowDtype(pa.decimal128(38, 9)))
4257
>>> bbq.sql_scalar("ROUND({0}, 0, 'ROUND_HALF_EVEN')", [s])
@@ -45,13 +60,29 @@ def sql_scalar(
4560
2 4.000000000
4661
dtype: decimal128(38, 9)[pyarrow]
4762
63+
Or pass in a DataFrame, in which case use the column names in the format
64+
strings.
65+
66+
>>> df = bpd.DataFrame({"a": ["1.5", "2.5", "3.5"]})
67+
>>> df = df.astype({"a": pd.ArrowDtype(pa.decimal128(38, 9))})
68+
>>> bbq.sql_scalar("ROUND({a}, 0, 'ROUND_HALF_EVEN')", df)
69+
0 2.000000000
70+
1 2.000000000
71+
2 4.000000000
72+
dtype: decimal128(38, 9)[pyarrow]
73+
4874
Args:
4975
sql_template (str):
5076
A SQL format string with Python-style {0} placeholders for each of
5177
the Series objects in ``columns``.
52-
columns (Sequence[bigframes.pandas.Series]):
78+
columns (
79+
Sequence[bigframes.pandas.Series] | bigframes.pandas.DataFrame
80+
):
5381
Series objects representing the column inputs to the
5482
``sql_template``. Must contain at least one Series.
83+
output_dtype (a BigQuery DataFrames compatible dtype, optional):
84+
If provided, BigQuery DataFrames uses this to determine the output
85+
of the returned Series. This avoids a dry run query.
5586
5687
Returns:
5788
bigframes.pandas.Series:
@@ -60,28 +91,38 @@ def sql_scalar(
6091
Raises:
6192
ValueError: If ``columns`` is empty.
6293
"""
94+
if isinstance(columns, bigframes.dataframe.DataFrame):
95+
sql_template = _format_names(sql_template, columns)
96+
columns = [
97+
cast(bigframes.series.Series, columns[column]) for column in columns.columns
98+
]
99+
63100
if len(columns) == 0:
64101
raise ValueError("Must provide at least one column in columns")
65102

103+
base_series = columns[0]
104+
66105
# To integrate this into our expression trees, we need to get the output
67106
# type, so we do some manual compilation and a dry run query to get that.
68107
# Another benefit of this is that if there is a syntax error in the SQL
69108
# template, then this will fail with an error earlier in the process,
70109
# aiding users in debugging.
71-
literals_sql = [sql.to_sql(sql.literal(None, column.dtype)) for column in columns]
72-
select_sql = sql_template.format(*literals_sql)
73-
dry_run_sql = f"SELECT {select_sql}"
74-
75-
# Use the executor directly, because we want the original column IDs, not
76-
# the user-friendly column names that block.to_sql_query() would produce.
77-
base_series = columns[0]
78-
bqclient = base_series._session.bqclient
79-
job = bqclient.query(
80-
dry_run_sql, job_config=google.cloud.bigquery.QueryJobConfig(dry_run=True)
81-
)
82-
_, output_type = bigframes.dtypes.convert_schema_field(job.schema[0])
110+
if output_dtype is None:
111+
literals_sql = [
112+
sql.to_sql(sql.literal(None, column.dtype)) for column in columns
113+
]
114+
select_sql = sql_template.format(*literals_sql)
115+
dry_run_sql = f"SELECT {select_sql}"
116+
117+
# Use the executor directly, because we want the original column IDs, not
118+
# the user-friendly column names that block.to_sql_query() would produce.
119+
bqclient = base_series._session.bqclient
120+
job = bqclient.query(
121+
dry_run_sql, job_config=google.cloud.bigquery.QueryJobConfig(dry_run=True)
122+
)
123+
_, output_dtype = bigframes.dtypes.convert_schema_field(job.schema[0])
83124

84125
op = bigframes.operations.SqlScalarOp(
85-
_output_type=output_type, sql_template=sql_template
126+
_output_type=output_dtype, sql_template=sql_template
86127
)
87128
return base_series._apply_nary_op(op, columns[1:])

bigframes/core/blocks.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2666,7 +2666,11 @@ def _array_value_for_output(
26662666
)
26672667

26682668
def to_sql_query(
2669-
self, include_index: bool, enable_cache: bool = True
2669+
self,
2670+
include_index: bool,
2671+
enable_cache: bool = True,
2672+
*,
2673+
ordered=False,
26702674
) -> Tuple[str, list[str], list[Label]]:
26712675
"""
26722676
Compiles this DataFrame's expression tree to SQL, optionally
@@ -2688,7 +2692,9 @@ def to_sql_query(
26882692
# Note: this uses the sql from the executor, so is coupled tightly to execution
26892693
# implementation. It will reference cached tables instead of original data sources.
26902694
# Maybe should just compile raw BFET? Depends on user intent.
2691-
sql = self.session._executor.to_sql(array_value, enable_cache=enable_cache)
2695+
sql = self.session._executor.to_sql(
2696+
array_value, enable_cache=enable_cache, ordered=ordered
2697+
)
26922698
return (
26932699
sql,
26942700
idx_ids,

bigframes/dataframe.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def _to_placeholder_table(self, dry_run: bool = False) -> bigquery.TableReferenc
447447
)
448448

449449
def _to_sql_query(
450-
self, include_index: bool, enable_cache: bool = True
450+
self, include_index: bool, enable_cache: bool = True, *, ordered: bool = False
451451
) -> Tuple[str, list[str], list[blocks.Label]]:
452452
"""Compiles this DataFrame's expression tree to SQL, optionally
453453
including index columns.
@@ -461,7 +461,9 @@ def _to_sql_query(
461461
If include_index is set to False, index_column_id_list and index_column_label_list
462462
return empty lists.
463463
"""
464-
return self._block.to_sql_query(include_index, enable_cache=enable_cache)
464+
return self._block.to_sql_query(
465+
include_index, enable_cache=enable_cache, ordered=ordered
466+
)
465467

466468
@property
467469
def sql(self) -> str:

bigframes/extensions/pandas/dataframe_accessor.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class BigQueryDataFrameAccessor:
3232
def __init__(self, pandas_obj: pandas.DataFrame):
3333
self._obj = pandas_obj
3434

35-
def sql_scalar(self, sql_template: str, session=None):
35+
def sql_scalar(self, sql_template: str, *, output_dtype=None, session=None):
3636
"""
3737
Compute a new pandas Series by applying a SQL scalar function to the DataFrame.
3838
@@ -44,22 +44,24 @@ def sql_scalar(self, sql_template: str, session=None):
4444
sql_template (str):
4545
A SQL format string with Python-style {0}, {1}, etc. placeholders for each of
4646
the columns in the DataFrame (in the order they appear in ``df.columns``).
47+
output_dtype (a BigQuery DataFrames compatible dtype, optional):
48+
If provided, BigQuery DataFrames uses this to determine the output
49+
of the returned Series. This avoids a dry run query.
4750
session (bigframes.session.Session, optional):
4851
The BigFrames session to use. If not provided, the default global session is used.
4952
5053
Returns:
5154
pandas.Series:
5255
The result of the SQL scalar function as a pandas Series.
5356
"""
54-
if session is None:
55-
session = bf_session.get_global_session()
56-
57-
bf_df = cast(bpd.DataFrame, session.read_pandas(self._obj))
58-
5957
# Import bigframes.bigquery here to avoid circular imports
6058
import bigframes.bigquery
6159

62-
columns = [cast(bpd.Series, bf_df[col]) for col in bf_df.columns]
63-
result = bigframes.bigquery.sql_scalar(sql_template, columns)
60+
if session is None:
61+
session = bf_session.get_global_session()
6462

65-
return result.to_pandas()
63+
bf_df = cast(bpd.DataFrame, session.read_pandas(self._obj))
64+
result = bigframes.bigquery.sql_scalar(
65+
sql_template, bf_df, output_dtype=output_dtype
66+
)
67+
return result.to_pandas(ordered=True)

notebooks/getting_started/pandas_extensions.ipynb

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,36 @@
1111
},
1212
{
1313
"cell_type": "code",
14-
"execution_count": 4,
14+
"execution_count": 1,
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
1818
"import pandas as pd\n",
1919
"import bigframes # This import registers the bigquery accessor."
2020
]
2121
},
22+
{
23+
"cell_type": "markdown",
24+
"metadata": {},
25+
"source": [
26+
"By default, BigQuery DataFrames selects a location to process data based on the\n",
27+
"data location, but using a pandas object doesn't provide such informat. If\n",
28+
"processing location is important to you, configure the location before using the\n",
29+
"accessor."
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": 2,
35+
"metadata": {},
36+
"outputs": [],
37+
"source": [
38+
"import bigframes.pandas as bpd\n",
39+
"\n",
40+
"bpd.reset_session()\n",
41+
"bpd.options.bigquery.location = \"US\""
42+
]
43+
},
2244
{
2345
"cell_type": "markdown",
2446
"metadata": {},
@@ -30,7 +52,7 @@
3052
},
3153
{
3254
"cell_type": "code",
33-
"execution_count": 5,
55+
"execution_count": 3,
3456
"metadata": {},
3557
"outputs": [
3658
{
@@ -56,7 +78,7 @@
5678
"dtype: Float64"
5779
]
5880
},
59-
"execution_count": 5,
81+
"execution_count": 3,
6082
"metadata": {},
6183
"output_type": "execute_result"
6284
}
@@ -76,7 +98,7 @@
7698
},
7799
{
78100
"cell_type": "code",
79-
"execution_count": 6,
101+
"execution_count": 4,
80102
"metadata": {},
81103
"outputs": [
82104
{
@@ -102,14 +124,14 @@
102124
"dtype: Int64"
103125
]
104126
},
105-
"execution_count": 6,
127+
"execution_count": 4,
106128
"metadata": {},
107129
"output_type": "execute_result"
108130
}
109131
],
110132
"source": [
111133
"df = pd.DataFrame({\"a\": [1, 2, 3], \"b\": [10, 20, 30]})\n",
112-
"result = df.bigquery.sql_scalar(\"{0} + {1}\")\n",
134+
"result = df.bigquery.sql_scalar(\"{a} + {b}\")\n",
113135
"result"
114136
]
115137
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
SELECT
2+
`rowindex`,
3+
ROUND(`int64_col` + `int64_too`) AS `0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest.mock as mock
16+
17+
import pandas as pd
18+
19+
import bigframes.pandas as bpd
20+
import bigframes.session
21+
22+
23+
def test_sql_scalar(scalar_types_df: bpd.DataFrame, snapshot, monkeypatch):
24+
session = mock.create_autospec(bigframes.session.Session)
25+
session.read_pandas.return_value = scalar_types_df
26+
27+
def to_pandas(series, ordered=True):
28+
sql, _, _ = series.to_frame()._to_sql_query(include_index=True, ordered=ordered)
29+
return sql
30+
31+
monkeypatch.setattr(bpd.Series, "to_pandas", to_pandas)
32+
33+
df = pd.DataFrame({"int64_col": [1, 2], "int64_too": [3, 4]})
34+
result = df.bigquery.sql_scalar(
35+
"ROUND({int64_col} + {int64_too})",
36+
output_dtype=pd.Int64Dtype(),
37+
session=session,
38+
)
39+
40+
session.read_pandas.assert_called_once()
41+
snapshot.assert_match(result, "out.sql")

0 commit comments

Comments
 (0)