Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 19cb613

Browse files
committed
Merge branch 'main' into shuowei-feat-persist-obj-ref
2 parents 179bde3 + 11fd2f1 commit 19cb613

File tree

67 files changed

+10267
-9590
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+10267
-9590
lines changed

.github/workflows/docs.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,3 @@ jobs:
3636
run: |
3737
python -m pip install --upgrade setuptools pip wheel
3838
python -m pip install nox
39-
- name: Run docfx
40-
run: |
41-
nox -s docfx

bigframes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
)
3333
import bigframes.enums as enums # noqa: E402
3434
import bigframes.exceptions as exceptions # noqa: E402
35+
36+
# Register pandas extensions
37+
import bigframes.extensions.pandas.dataframe_accessor # noqa: F401, E402
3538
from bigframes.session import connect, Session # noqa: E402
3639
from bigframes.version import __version__ # noqa: E402
3740

bigframes/bigquery/_operations/ai.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def if_(
745745
or pandas Series.
746746
connection_id (str, optional):
747747
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
748-
If not provided, the connection from the current session will be used.
748+
If not provided, the query uses your end-user credential.
749749
750750
Returns:
751751
bigframes.series.Series: A new series of bools.
@@ -756,7 +756,7 @@ def if_(
756756

757757
operator = ai_ops.AIIf(
758758
prompt_context=tuple(prompt_context),
759-
connection_id=_resolve_connection_id(series_list[0], connection_id),
759+
connection_id=connection_id,
760760
)
761761

762762
return series_list[0]._apply_nary_op(operator, series_list[1:])
@@ -800,7 +800,7 @@ def classify(
800800
Categories to classify the input into.
801801
connection_id (str, optional):
802802
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
803-
If not provided, the connection from the current session will be used.
803+
If not provided, the query uses your end-user credential.
804804
805805
Returns:
806806
bigframes.series.Series: A new series of strings.
@@ -812,7 +812,7 @@ def classify(
812812
operator = ai_ops.AIClassify(
813813
prompt_context=tuple(prompt_context),
814814
categories=tuple(categories),
815-
connection_id=_resolve_connection_id(series_list[0], connection_id),
815+
connection_id=connection_id,
816816
)
817817

818818
return series_list[0]._apply_nary_op(operator, series_list[1:])
@@ -853,7 +853,7 @@ def score(
853853
or pandas Series.
854854
connection_id (str, optional):
855855
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
856-
If not provided, the connection from the current session will be used.
856+
If not provided, the query uses your end-user credential.
857857
858858
Returns:
859859
bigframes.series.Series: A new series of double (float) values.
@@ -864,7 +864,7 @@ def score(
864864

865865
operator = ai_ops.AIScore(
866866
prompt_context=tuple(prompt_context),
867-
connection_id=_resolve_connection_id(series_list[0], connection_id),
867+
connection_id=connection_id,
868868
)
869869

870870
return series_list[0]._apply_nary_op(operator, series_list[1:])

bigframes/bigquery/_operations/io.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import pandas as pd
2020

2121
from bigframes.bigquery._operations.table import _get_table_metadata
22+
import bigframes.core.compile.sqlglot.sql as sql
2223
import bigframes.core.logging.log_adapter as log_adapter
23-
import bigframes.core.sql.io
2424
import bigframes.session
2525

2626

@@ -73,7 +73,7 @@ def load_data(
7373
"""
7474
import bigframes.pandas as bpd
7575

76-
sql = bigframes.core.sql.io.load_data_ddl(
76+
load_data_expr = sql.load_data(
7777
table_name=table_name,
7878
write_disposition=write_disposition,
7979
columns=columns,
@@ -84,11 +84,12 @@ def load_data(
8484
with_partition_columns=with_partition_columns,
8585
connection_name=connection_name,
8686
)
87+
sql_text = sql.to_sql(load_data_expr)
8788

8889
if session is None:
89-
bpd.read_gbq_query(sql)
90+
bpd.read_gbq_query(sql_text)
9091
session = bpd.get_global_session()
9192
else:
92-
session.read_gbq_query(sql)
93+
session.read_gbq_query(sql_text)
9394

9495
return _get_table_metadata(bqclient=session.bqclient, table_name=table_name)

bigframes/bigquery/_operations/sql.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,31 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import Sequence
19+
from typing import cast, Optional, Sequence, Union
2020

2121
import google.cloud.bigquery
2222

2323
from bigframes.core.compile.sqlglot import sql
24+
import bigframes.dataframe
2425
import bigframes.dtypes
2526
import bigframes.operations
2627
import bigframes.series
2728

2829

30+
def _format_names(sql_template: str, dataframe: bigframes.dataframe.DataFrame):
31+
"""Turn sql_template from a template that uses names to one that uses
32+
numbers.
33+
"""
34+
names_to_numbers = {name: f"{{{i}}}" for i, name in enumerate(dataframe.columns)}
35+
numbers = [f"{{{i}}}" for i in range(len(dataframe.columns))]
36+
return sql_template.format(*numbers, **names_to_numbers)
37+
38+
2939
def sql_scalar(
3040
sql_template: str,
31-
columns: Sequence[bigframes.series.Series],
41+
columns: Union[bigframes.dataframe.DataFrame, Sequence[bigframes.series.Series]],
42+
*,
43+
output_dtype: Optional[bigframes.dtypes.Dtype] = None,
3244
) -> bigframes.series.Series:
3345
"""Create a Series from a SQL template.
3446
@@ -37,6 +49,9 @@ def sql_scalar(
3749
>>> import bigframes.pandas as bpd
3850
>>> import bigframes.bigquery as bbq
3951
52+
Either pass in a sequence of series, in which case use integers in the
53+
format strings.
54+
4055
>>> s = bpd.Series(["1.5", "2.5", "3.5"])
4156
>>> s = s.astype(pd.ArrowDtype(pa.decimal128(38, 9)))
4257
>>> bbq.sql_scalar("ROUND({0}, 0, 'ROUND_HALF_EVEN')", [s])
@@ -45,13 +60,29 @@ def sql_scalar(
4560
2 4.000000000
4661
dtype: decimal128(38, 9)[pyarrow]
4762
63+
Or pass in a DataFrame, in which case use the column names in the format
64+
strings.
65+
66+
>>> df = bpd.DataFrame({"a": ["1.5", "2.5", "3.5"]})
67+
>>> df = df.astype({"a": pd.ArrowDtype(pa.decimal128(38, 9))})
68+
>>> bbq.sql_scalar("ROUND({a}, 0, 'ROUND_HALF_EVEN')", df)
69+
0 2.000000000
70+
1 2.000000000
71+
2 4.000000000
72+
dtype: decimal128(38, 9)[pyarrow]
73+
4874
Args:
4975
sql_template (str):
5076
A SQL format string with Python-style {0} placeholders for each of
5177
the Series objects in ``columns``.
52-
columns (Sequence[bigframes.pandas.Series]):
78+
columns (
79+
Sequence[bigframes.pandas.Series] | bigframes.pandas.DataFrame
80+
):
5381
Series objects representing the column inputs to the
5482
``sql_template``. Must contain at least one Series.
83+
output_dtype (a BigQuery DataFrames compatible dtype, optional):
84+
If provided, BigQuery DataFrames uses this to determine the output
85+
of the returned Series. This avoids a dry run query.
5586
5687
Returns:
5788
bigframes.pandas.Series:
@@ -60,28 +91,38 @@ def sql_scalar(
6091
Raises:
6192
ValueError: If ``columns`` is empty.
6293
"""
94+
if isinstance(columns, bigframes.dataframe.DataFrame):
95+
sql_template = _format_names(sql_template, columns)
96+
columns = [
97+
cast(bigframes.series.Series, columns[column]) for column in columns.columns
98+
]
99+
63100
if len(columns) == 0:
64101
raise ValueError("Must provide at least one column in columns")
65102

103+
base_series = columns[0]
104+
66105
# To integrate this into our expression trees, we need to get the output
67106
# type, so we do some manual compilation and a dry run query to get that.
68107
# Another benefit of this is that if there is a syntax error in the SQL
69108
# template, then this will fail with an error earlier in the process,
70109
# aiding users in debugging.
71-
literals_sql = [sql.to_sql(sql.literal(None, column.dtype)) for column in columns]
72-
select_sql = sql_template.format(*literals_sql)
73-
dry_run_sql = f"SELECT {select_sql}"
74-
75-
# Use the executor directly, because we want the original column IDs, not
76-
# the user-friendly column names that block.to_sql_query() would produce.
77-
base_series = columns[0]
78-
bqclient = base_series._session.bqclient
79-
job = bqclient.query(
80-
dry_run_sql, job_config=google.cloud.bigquery.QueryJobConfig(dry_run=True)
81-
)
82-
_, output_type = bigframes.dtypes.convert_schema_field(job.schema[0])
110+
if output_dtype is None:
111+
literals_sql = [
112+
sql.to_sql(sql.literal(None, column.dtype)) for column in columns
113+
]
114+
select_sql = sql_template.format(*literals_sql)
115+
dry_run_sql = f"SELECT {select_sql}"
116+
117+
# Use the executor directly, because we want the original column IDs, not
118+
# the user-friendly column names that block.to_sql_query() would produce.
119+
bqclient = base_series._session.bqclient
120+
job = bqclient.query(
121+
dry_run_sql, job_config=google.cloud.bigquery.QueryJobConfig(dry_run=True)
122+
)
123+
_, output_dtype = bigframes.dtypes.convert_schema_field(job.schema[0])
83124

84125
op = bigframes.operations.SqlScalarOp(
85-
_output_type=output_type, sql_template=sql_template
126+
_output_type=output_dtype, sql_template=sql_template
86127
)
87128
return base_series._apply_nary_op(op, columns[1:])

bigframes/core/bigframe_node.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,12 +330,32 @@ def top_down(
330330
"""
331331
Perform a top-down transformation of the BigFrameNode tree.
332332
"""
333+
results: Dict[BigFrameNode, BigFrameNode] = {}
334+
# Each stack entry is (node, t_node). t_node is None until transform(node) is called.
335+
stack: list[tuple[BigFrameNode, typing.Optional[BigFrameNode]]] = [(self, None)]
333336

334-
@functools.cache
335-
def recursive_transform(node: BigFrameNode) -> BigFrameNode:
336-
return transform(node).transform_children(recursive_transform)
337+
while stack:
338+
node, t_node = stack[-1]
339+
340+
if t_node is None:
341+
if node in results:
342+
stack.pop()
343+
continue
344+
t_node = transform(node)
345+
stack[-1] = (node, t_node)
346+
347+
all_done = True
348+
for child in reversed(t_node.child_nodes):
349+
if child not in results:
350+
stack.append((child, None))
351+
all_done = False
352+
break
353+
354+
if all_done:
355+
results[node] = t_node.transform_children(lambda x: results[x])
356+
stack.pop()
337357

338-
return recursive_transform(self)
358+
return results[self]
339359

340360
def bottom_up(
341361
self: BigFrameNode,

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
113113
)
114114
)
115115

116-
endpoit = op_args.get("endpoint", None)
117-
if endpoit is not None:
118-
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit)))
116+
endpoint = op_args.get("endpoint", None)
117+
if endpoint is not None:
118+
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoint)))
119119

120120
request_type = op_args.get("request_type", None)
121121
if request_type is not None:

bigframes/core/compile/sqlglot/expressions/datetime_ops.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,12 @@ def _(expr: TypedExpr, op: ops.ToDatetimeOp) -> sge.Expression:
371371
)
372372
return sge.Cast(this=result, to="DATETIME")
373373

374-
if expr.dtype in (dtypes.STRING_DTYPE, dtypes.TIMESTAMP_DTYPE):
374+
if expr.dtype in (
375+
dtypes.STRING_DTYPE,
376+
dtypes.TIMESTAMP_DTYPE,
377+
dtypes.DATETIME_DTYPE,
378+
dtypes.DATE_DTYPE,
379+
):
375380
return sge.TryCast(this=expr.expr, to="DATETIME")
376381

377382
value = expr.expr
@@ -396,7 +401,12 @@ def _(expr: TypedExpr, op: ops.ToTimestampOp) -> sge.Expression:
396401
"PARSE_TIMESTAMP", sge.convert(op.format), expr.expr, sge.convert("UTC")
397402
)
398403

399-
if expr.dtype in (dtypes.STRING_DTYPE, dtypes.DATETIME_DTYPE):
404+
if expr.dtype in (
405+
dtypes.STRING_DTYPE,
406+
dtypes.DATETIME_DTYPE,
407+
dtypes.TIMESTAMP_DTYPE,
408+
dtypes.DATE_DTYPE,
409+
):
400410
return sge.func("TIMESTAMP", expr.expr)
401411

402412
value = expr.expr

bigframes/core/compile/sqlglot/sql/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
table,
2323
to_sql,
2424
)
25+
from bigframes.core.compile.sqlglot.sql.ddl import load_data
2526
from bigframes.core.compile.sqlglot.sql.dml import insert, replace
2627

2728
__all__ = [
@@ -33,6 +34,8 @@
3334
"literal",
3435
"table",
3536
"to_sql",
37+
# From ddl.py
38+
"load_data",
3639
# From dml.py
3740
"insert",
3841
"replace",

0 commit comments

Comments
 (0)