Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit a7dcb83

Browse files
committed
Merge remote-tracking branch 'origin/main' into tswast-notebooks
2 parents 06e684b + 43353e2 commit a7dcb83

File tree

74 files changed

+2039
-1593
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

74 files changed

+2039
-1593
lines changed

bigframes/_magics.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,4 @@ def _cell_magic(line, cell):
4848
if args.destination_var:
4949
ipython.push({args.destination_var: dataframe})
5050

51-
with bigframes.option_context(
52-
"display.repr_mode",
53-
"anywidget",
54-
):
55-
display(dataframe)
51+
display(dataframe)

bigframes/bigquery/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
to_json,
5959
to_json_string,
6060
)
61+
from bigframes.bigquery._operations.mathematical import rand
6162
from bigframes.bigquery._operations.search import create_vector_index, vector_search
6263
from bigframes.bigquery._operations.sql import sql_scalar
6364
from bigframes.bigquery._operations.struct import struct
@@ -99,6 +100,8 @@
99100
parse_json,
100101
to_json,
101102
to_json_string,
103+
# mathematical ops
104+
rand,
102105
# search ops
103106
create_vector_index,
104107
vector_search,
@@ -154,6 +157,8 @@
154157
"parse_json",
155158
"to_json",
156159
"to_json_string",
160+
# mathematical ops
161+
"rand",
157162
# search ops
158163
"create_vector_index",
159164
"vector_search",

bigframes/bigquery/_operations/ai.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@
2828
from bigframes import series, session
2929
from bigframes.bigquery._operations import utils as bq_utils
3030
from bigframes.core import convert
31+
from bigframes.core.compile.sqlglot import sql as sg_sql
3132
from bigframes.core.logging import log_adapter
32-
import bigframes.core.sql.literals
33+
from bigframes.ml import base as ml_base
3334
from bigframes.ml import core as ml_core
3435
from bigframes.operations import ai_ops, output_schemas
3536

@@ -392,7 +393,7 @@ def generate_double(
392393

393394
@log_adapter.method_logger(custom_base_name="bigquery_ai")
394395
def generate_embedding(
395-
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
396+
model: Union[ml_base.BaseEstimator, str, pd.Series],
396397
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
397398
*,
398399
output_dimensionality: Optional[int] = None,
@@ -416,7 +417,7 @@ def generate_embedding(
416417
... ) # doctest: +SKIP
417418
418419
Args:
419-
model (bigframes.ml.base.BaseEstimator or str):
420+
model (bigframes.ml.base.BaseEstimator or str):
420421
The model to use for text embedding.
421422
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
422423
The data to generate embeddings for. If a Series is provided, it is
@@ -458,7 +459,7 @@ def generate_embedding(
458459
model_name, session = bq_utils.get_model_name_and_session(model, data)
459460
table_sql = bq_utils.to_sql(data)
460461

461-
struct_fields: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {}
462+
struct_fields: Dict[str, Any] = {}
462463
if output_dimensionality is not None:
463464
struct_fields["OUTPUT_DIMENSIONALITY"] = output_dimensionality
464465
if task_type is not None:
@@ -478,7 +479,7 @@ def generate_embedding(
478479
FROM AI.GENERATE_EMBEDDING(
479480
MODEL `{model_name}`,
480481
({table_sql}),
481-
{bigframes.core.sql.literals.struct_literal(struct_fields)}
482+
{sg_sql.to_sql(sg_sql.literal(struct_fields))}
482483
)
483484
"""
484485

@@ -490,7 +491,7 @@ def generate_embedding(
490491

491492
@log_adapter.method_logger(custom_base_name="bigquery_ai")
492493
def generate_text(
493-
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
494+
model: Union[ml_base.BaseEstimator, str, pd.Series],
494495
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
495496
*,
496497
temperature: Optional[float] = None,
@@ -519,7 +520,7 @@ def generate_text(
519520
... ) # doctest: +SKIP
520521
521522
Args:
522-
model (bigframes.ml.base.BaseEstimator or str):
523+
model (bigframes.ml.base.BaseEstimator or str):
523524
The model to use for text generation.
524525
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
525526
The data to generate text for. If a Series is provided, it is
@@ -591,7 +592,7 @@ def generate_text(
591592
FROM AI.GENERATE_TEXT(
592593
MODEL `{model_name}`,
593594
({table_sql}),
594-
{bigframes.core.sql.literals.struct_literal(struct_fields)}
595+
{sg_sql.to_sql(sg_sql.literal(struct_fields))}
595596
)
596597
"""
597598

@@ -603,7 +604,7 @@ def generate_text(
603604

604605
@log_adapter.method_logger(custom_base_name="bigquery_ai")
605606
def generate_table(
606-
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
607+
model: Union[ml_base.BaseEstimator, str, pd.Series],
607608
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
608609
*,
609610
output_schema: Union[str, Mapping[str, str]],
@@ -635,7 +636,7 @@ def generate_table(
635636
... ) # doctest: +SKIP
636637
637638
Args:
638-
model (bigframes.ml.base.BaseEstimator or str):
639+
model (bigframes.ml.base.BaseEstimator or str):
639640
The model to use for table generation.
640641
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
641642
The data to generate table for. If a Series is provided, it is
@@ -677,9 +678,7 @@ def generate_table(
677678
else:
678679
output_schema_str = output_schema
679680

680-
struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {
681-
"output_schema": output_schema_str
682-
}
681+
struct_fields_bq: Dict[str, Any] = {"output_schema": output_schema_str}
683682
if temperature is not None:
684683
struct_fields_bq["temperature"] = temperature
685684
if top_p is not None:
@@ -691,7 +690,7 @@ def generate_table(
691690
if request_type is not None:
692691
struct_fields_bq["request_type"] = request_type
693692

694-
struct_sql = bigframes.core.sql.literals.struct_literal(struct_fields_bq)
693+
struct_sql = sg_sql.to_sql(sg_sql.literal(struct_fields_bq))
695694
query = f"""
696695
SELECT *
697696
FROM AI.GENERATE_TABLE(
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from bigframes import dtypes
18+
from bigframes import operations as ops
19+
import bigframes.core.col
20+
import bigframes.core.expression
21+
22+
23+
def rand() -> bigframes.core.col.Expression:
24+
"""
25+
Generates a pseudo-random value of type FLOAT64 in the range of [0, 1),
26+
inclusive of 0 and exclusive of 1.
27+
28+
.. warning::
29+
This method introduces non-determinism to the expression. Reading the
30+
same column twice may yield different values. The value might
31+
change. Do not use this value or any value derived from it as a join
32+
key.
33+
34+
**Examples:**
35+
36+
>>> import bigframes.pandas as bpd
37+
>>> import bigframes.bigquery as bbq
38+
>>> df = bpd.DataFrame({"a": [1, 2, 3]})
39+
>>> df['random'] = bbq.rand()
40+
>>> # Resulting column 'random' will contain random floats between 0 and 1.
41+
42+
Returns:
43+
bigframes.pandas.api.typing.Expression:
44+
An expression that can be used in
45+
:func:`~bigframes.pandas.DataFrame.assign` and other methods. See
46+
:func:`bigframes.pandas.col`.
47+
"""
48+
op = ops.SqlScalarOp(
49+
_output_type=dtypes.FLOAT_DTYPE,
50+
sql_template="RAND()",
51+
is_deterministic=False,
52+
)
53+
return bigframes.core.col.Expression(bigframes.core.expression.OpExpression(op, ()))

bigframes/core/bigframe_node.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -330,22 +330,12 @@ def top_down(
330330
"""
331331
Perform a top-down transformation of the BigFrameNode tree.
332332
"""
333-
to_process = [self]
334-
results: Dict[BigFrameNode, BigFrameNode] = {}
335333

336-
while to_process:
337-
item = to_process.pop()
338-
if item not in results.keys():
339-
item_result = transform(item)
340-
results[item] = item_result
341-
to_process.extend(item_result.child_nodes)
334+
@functools.cache
335+
def recursive_transform(node: BigFrameNode) -> BigFrameNode:
336+
return transform(node).transform_children(recursive_transform)
342337

343-
to_process = [self]
344-
# for each processed item, replace its children
345-
for item in reversed(list(results.keys())):
346-
results[item] = results[item].transform_children(lambda x: results[x])
347-
348-
return results[self]
338+
return recursive_transform(self)
349339

350340
def bottom_up(
351341
self: BigFrameNode,

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
6262
if request.sort_rows:
6363
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
6464
encoded_type_refs = data_type_logger.encode_type_refs(result_node)
65+
# TODO: Extract CTEs earlier
66+
result_node = typing.cast(nodes.ResultNode, rewrite.extract_ctes(result_node))
6567
sql = _compile_result_node(result_node)
6668
return configs.CompileResult(
6769
sql,
@@ -74,6 +76,8 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
7476
result_node = dataclasses.replace(result_node, order_by=None)
7577
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
7678
encoded_type_refs = data_type_logger.encode_type_refs(result_node)
79+
# TODO: Extract CTEs earlier
80+
result_node = typing.cast(nodes.ResultNode, rewrite.extract_ctes(result_node))
7781
sql = _compile_result_node(result_node)
7882
# Return the ordering iff no extra columns are needed to define the row order
7983
if ordering is not None:
@@ -94,6 +98,7 @@ def _remap_variables(
9498
result_node, _ = rewrite.remap_variables(
9599
node, map(identifiers.ColumnId, uid_gen.get_uid_stream("bfcol_"))
96100
)
101+
result_node.validate_tree()
97102
return typing.cast(nodes.ResultNode, result_node)
98103

99104

@@ -102,13 +107,16 @@ def _compile_result_node(root: nodes.ResultNode) -> str:
102107
# of nodes using the same generator.
103108
uid_gen = guid.SequentialUIDGenerator()
104109
root = _remap_variables(root, uid_gen)
110+
# Remap variables creates too many new variables
111+
# root = rewrite.select_pullup(root, prefer_source_names=False)
105112
root = typing.cast(nodes.ResultNode, rewrite.defer_selection(root))
106113

107114
# Have to bind schema as the final step before compilation.
108115
# Probably, should defer even further
109116
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
110117

111-
sqlglot_ir_obj = compile_node(rewrite.as_sql_nodes(root), uid_gen)
118+
# TODO: Bake all IDs in tree, stop passing uid_gen to emitters
119+
sqlglot_ir_obj = compile_node(rewrite.as_sql_nodes(root, uid_gen), uid_gen)
112120
return sqlglot_ir_obj.sql
113121

114122

@@ -121,7 +129,7 @@ def compile_node(
121129
for current_node in list(node.iter_nodes_topo()):
122130
if current_node.child_nodes == ():
123131
# For a leaf node, generate a dummy child to pass the UID generator.
124-
child_results = tuple([sqlglot_ir.SQLGlotIR(uid_gen=uid_gen)])
132+
child_results = tuple([sqlglot_ir.SQLGlotIR.empty(uid_gen=uid_gen)])
125133
else:
126134
# Child nodes should have been compiled in the reverse topological order.
127135
child_results = tuple(
@@ -256,6 +264,23 @@ def compile_isin_join(
256264
)
257265

258266

267+
@_compile_node.register
268+
def compile_cte_ref_node(node: sql_nodes.SqlCteRefNode, child: sqlglot_ir.SQLGlotIR):
269+
return sqlglot_ir.SQLGlotIR.from_cte_ref(
270+
node.cte_name,
271+
uid_gen=child.uid_gen,
272+
)
273+
274+
275+
@_compile_node.register
276+
def compile_with_ctes_node(
277+
node: sql_nodes.SqlWithCtesNode,
278+
child: sqlglot_ir.SQLGlotIR,
279+
*ctes: sqlglot_ir.SQLGlotIR,
280+
):
281+
return child.with_ctes(tuple(zip(node.cte_names, ctes)))
282+
283+
259284
@_compile_node.register
260285
def compile_concat(
261286
node: nodes.ConcatNode, *children: sqlglot_ir.SQLGlotIR
@@ -271,7 +296,7 @@ def compile_concat(
271296
]
272297

273298
return sqlglot_ir.SQLGlotIR.from_union(
274-
[child._as_select() for child in children],
299+
[child.expr.as_select_all() for child in children],
275300
output_aliases=output_aliases,
276301
uid_gen=uid_gen,
277302
)

bigframes/core/compile/sqlglot/expressions/comparison_ops.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,39 @@
3333
@register_unary_op(ops.IsInOp, pass_op=True)
3434
def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression:
3535
values = []
36+
# Bools are not comparable to non-bool numerics in SQL, so when the expression or any non-null value is a non-bool numeric, bool values (and a bool expression) are upcast to INT64.
37+
must_upcast_bools = dtypes.is_numeric(expr.dtype, include_bool=False) or any(
38+
dtypes.is_numeric(dtypes.bigframes_type(type(value)), include_bool=False)
39+
for value in op.values
40+
if not _is_null(value)
41+
)
3642
for value in op.values:
3743
if _is_null(value):
3844
continue
3945
dtype = dtypes.bigframes_type(type(value))
4046
if dtypes.can_compare(expr.dtype, dtype):
47+
if must_upcast_bools and dtype == dtypes.BOOL_DTYPE:
48+
value = int(value)
4149
values.append(sge.convert(value))
4250

51+
sg_lexpr: sge.Expression = expr.expr
52+
if expr.dtype == dtypes.BOOL_DTYPE and must_upcast_bools:
53+
sg_lexpr = sge.cast(expr.expr, "INT64")
54+
4355
if op.match_nulls:
4456
contains_nulls = any(_is_null(value) for value in op.values)
4557
if contains_nulls:
4658
if len(values) == 0:
47-
return sge.Is(this=expr.expr, expression=sge.Null())
48-
return sge.Is(this=expr.expr, expression=sge.Null()) | sge.In(
49-
this=expr.expr, expressions=values
59+
return sge.Is(this=sg_lexpr, expression=sge.Null())
60+
return sge.Is(this=sg_lexpr, expression=sge.Null()) | sge.In(
61+
this=sg_lexpr, expressions=values
5062
)
5163

5264
if len(values) == 0:
5365
return sge.convert(False)
5466

5567
return sge.func(
56-
"COALESCE", sge.In(this=expr.expr, expressions=values), sge.convert(False)
68+
"COALESCE", sge.In(this=sg_lexpr, expressions=values), sge.convert(False)
5769
)
5870

5971

0 commit comments

Comments
 (0)