Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit a005129

Browse files
Improve data source handling; compile predicates again
1 parent d76ffb1 commit a005129

File tree

4 files changed

+98
-48
lines changed

4 files changed

+98
-48
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def _compile_node(
141141

142142

143143
@_compile_node.register
144-
def compile_sql_select(node: sql_nodes.SelectNode, child: ir.SQLGlotIR):
144+
def compile_sql_select(node: sql_nodes.SqlSelectNode, child: ir.SQLGlotIR):
145145
sqlglot_ir = child
146146
if node.sorting is not None:
147147
ordering_cols = tuple(
@@ -165,6 +165,12 @@ def compile_sql_select(node: sql_nodes.SelectNode, child: ir.SQLGlotIR):
165165
)
166166
sqlglot_ir = sqlglot_ir.select(projected_cols)
167167

168+
if len(node.predicates) > 0:
169+
sge_predicates = tuple(
170+
scalar_compiler.scalar_op_compiler.compile_expression(expression)
171+
for expression in node.predicates
172+
)
173+
sqlglot_ir = sqlglot_ir.filter(sge_predicates)
168174
if node.limit is not None:
169175
sqlglot_ir = sqlglot_ir.limit(node.limit)
170176

@@ -185,14 +191,12 @@ def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLG
185191

186192

187193
@_compile_node.register
188-
def compile_readtable(node: nodes.ReadTableNode, child: ir.SQLGlotIR):
194+
def compile_readtable(node: sql_nodes.SqlDataSource, child: ir.SQLGlotIR):
189195
table = node.source.table
190196
return ir.SQLGlotIR.from_table(
191197
table.project_id,
192198
table.dataset_id,
193199
table.table_id,
194-
col_names=[col.source_id for col in node.scan_list.items],
195-
alias_names=[col.id.sql for col in node.scan_list.items],
196200
uid_gen=child.uid_gen,
197201
sql_predicate=node.source.sql_predicate,
198202
system_time=node.source.at_time,

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,6 @@ def from_table(
116116
project_id: str,
117117
dataset_id: str,
118118
table_id: str,
119-
col_names: typing.Sequence[str],
120-
alias_names: typing.Sequence[str],
121119
uid_gen: guid.SequentialUIDGenerator,
122120
sql_predicate: typing.Optional[str] = None,
123121
system_time: typing.Optional[datetime.datetime] = None,
@@ -134,15 +132,6 @@ def from_table(
134132
sql_predicate (typing.Optional[str]): An optional SQL predicate for filtering.
135133
system_time (typing.Optional[str]): An optional system time for time-travel queries.
136134
"""
137-
selections = [
138-
sge.Alias(
139-
this=sge.to_identifier(col_name, quoted=cls.quoted),
140-
alias=sge.to_identifier(alias_name, quoted=cls.quoted),
141-
)
142-
if col_name != alias_name
143-
else sge.to_identifier(col_name, quoted=cls.quoted)
144-
for col_name, alias_name in zip(col_names, alias_names)
145-
]
146135
version = (
147136
sge.Version(
148137
this="TIMESTAMP",
@@ -158,12 +147,14 @@ def from_table(
158147
catalog=sg.to_identifier(project_id, quoted=cls.quoted),
159148
version=version,
160149
)
161-
select_expr = sge.Select().select(*selections).from_(table_expr)
162150
if sql_predicate:
151+
select_expr = sge.Select().select(sge.Star()).from_(table_expr)
163152
select_expr = select_expr.where(
164153
sg.parse_one(sql_predicate, dialect="bigquery"), append=False
165154
)
166-
return cls(expr=select_expr, uid_gen=uid_gen)
155+
return cls(expr=select_expr, uid_gen=uid_gen)
156+
157+
return cls(expr=table_expr, uid_gen=uid_gen)
167158

168159
@classmethod
169160
def from_query_string(

bigframes/core/rewrite/as_sql.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Google LLC
1+
# Copyright 2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -27,14 +27,14 @@
2727
import bigframes.core.rewrite
2828

2929

30-
def _limit(select: sql_nodes.SelectNode, limit: int) -> sql_nodes.SelectNode:
30+
def _limit(select: sql_nodes.SqlSelectNode, limit: int) -> sql_nodes.SqlSelectNode:
3131
new_limit = limit if select.limit is None else min([select.limit, limit])
3232
return dataclasses.replace(select, limit=new_limit)
3333

3434

3535
def _try_sort(
36-
select: sql_nodes.SelectNode, sort_by: Sequence[ordering.OrderingExpression]
37-
) -> Optional[sql_nodes.SelectNode]:
36+
select: sql_nodes.SqlSelectNode, sort_by: Sequence[ordering.OrderingExpression]
37+
) -> Optional[sql_nodes.SqlSelectNode]:
3838
new_order_exprs = []
3939
for sort_expr in sort_by:
4040
new_expr = _try_bind(
@@ -50,8 +50,8 @@ def _try_sort(
5050

5151
def _sort(
5252
node: nodes.BigFrameNode, sort_by: Sequence[ordering.OrderingExpression]
53-
) -> sql_nodes.SelectNode:
54-
if isinstance(node, sql_nodes.SelectNode):
53+
) -> sql_nodes.SqlSelectNode:
54+
if isinstance(node, sql_nodes.SqlSelectNode):
5555
merged = _try_sort(node, sort_by)
5656
if merged:
5757
return merged
@@ -73,8 +73,8 @@ def _try_bind(
7373

7474

7575
def _try_add_cdefs(
76-
select: sql_nodes.SelectNode, cdefs: Sequence[nodes.ColumnDef]
77-
) -> Optional[sql_nodes.SelectNode]:
76+
select: sql_nodes.SqlSelectNode, cdefs: Sequence[nodes.ColumnDef]
77+
) -> Optional[sql_nodes.SqlSelectNode]:
7878
# TODO: add up complexity measure while inlining refs
7979
new_defs = []
8080
for cdef in cdefs:
@@ -91,8 +91,8 @@ def _try_add_cdefs(
9191

9292
def _add_cdefs(
9393
node: nodes.BigFrameNode, cdefs: Sequence[nodes.ColumnDef]
94-
) -> sql_nodes.SelectNode:
95-
if isinstance(node, sql_nodes.SelectNode):
94+
) -> sql_nodes.SqlSelectNode:
95+
if isinstance(node, sql_nodes.SqlSelectNode):
9696
merged = _try_add_cdefs(node, cdefs)
9797
if merged:
9898
return merged
@@ -103,8 +103,8 @@ def _add_cdefs(
103103

104104

105105
def _try_add_filter(
106-
select: sql_nodes.SelectNode, predicates: Sequence[expression.Expression]
107-
) -> Optional[sql_nodes.SelectNode]:
106+
select: sql_nodes.SqlSelectNode, predicates: Sequence[expression.Expression]
107+
) -> Optional[sql_nodes.SqlSelectNode]:
108108
# Constraint: filters can only be merged if they are scalar expression after binding
109109
new_predicates = []
110110
# bind variables, merge predicates
@@ -118,8 +118,8 @@ def _try_add_filter(
118118

119119
def _add_filter(
120120
node: nodes.BigFrameNode, predicates: Sequence[expression.Expression]
121-
) -> sql_nodes.SelectNode:
122-
if isinstance(node, sql_nodes.SelectNode):
121+
) -> sql_nodes.SqlSelectNode:
122+
if isinstance(node, sql_nodes.SqlSelectNode):
123123
result = _try_add_filter(node, predicates)
124124
if result:
125125
return result
@@ -128,8 +128,8 @@ def _add_filter(
128128
return new_node
129129

130130

131-
def _create_noop_select(node: nodes.BigFrameNode) -> sql_nodes.SelectNode:
132-
return sql_nodes.SelectNode(
131+
def _create_noop_select(node: nodes.BigFrameNode) -> sql_nodes.SqlSelectNode:
132+
return sql_nodes.SqlSelectNode(
133133
node,
134134
selections=tuple(
135135
nodes.ColumnDef(expression.ResolvedDerefOp.from_field(field), field.id)
@@ -139,7 +139,7 @@ def _create_noop_select(node: nodes.BigFrameNode) -> sql_nodes.SelectNode:
139139

140140

141141
def _try_remap_select_cols(
142-
select: sql_nodes.SelectNode, cols: Sequence[nodes.AliasedRef]
142+
select: sql_nodes.SqlSelectNode, cols: Sequence[nodes.AliasedRef]
143143
):
144144
new_defs = []
145145
for aliased_ref in cols:
@@ -151,7 +151,7 @@ def _try_remap_select_cols(
151151

152152

153153
def _remap_select_cols(node: nodes.BigFrameNode, cols: Sequence[nodes.AliasedRef]):
154-
if isinstance(node, sql_nodes.SelectNode):
154+
if isinstance(node, sql_nodes.SqlSelectNode):
155155
result = _try_remap_select_cols(node, cols)
156156
if result:
157157
return result
@@ -183,7 +183,14 @@ def _get_added_cdefs(node: Union[nodes.ProjectionNode, nodes.WindowOpNode]):
183183

184184
def _as_sql_node(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
185185
# case one, can be converted to select
186-
if isinstance(node, (nodes.ProjectionNode, nodes.WindowOpNode)):
186+
if isinstance(node, nodes.ReadTableNode):
187+
leaf = sql_nodes.SqlDataSource(source=node.source)
188+
mappings = [
189+
nodes.AliasedRef(expression.deref(scan_item.source_id), scan_item.id)
190+
for scan_item in node.scan_list.items
191+
]
192+
return _remap_select_cols(leaf, mappings)
193+
elif isinstance(node, (nodes.ProjectionNode, nodes.WindowOpNode)):
187194
cdefs = _get_added_cdefs(node)
188195
return _add_cdefs(node.child, cdefs)
189196
elif isinstance(node, (nodes.SelectionNode)):

bigframes/core/sql_nodes.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 Google LLC
1+
# Copyright 2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -18,26 +18,74 @@
1818
import functools
1919
from typing import Mapping, Optional, Sequence, Tuple
2020

21-
from bigframes.core import identifiers, nodes
21+
from bigframes.core import bq_data, identifiers, nodes
2222
import bigframes.core.expression as ex
2323
from bigframes.core.ordering import OrderingExpression
2424
import bigframes.dtypes
2525

26-
# A fixed number of variable to assume for overhead on some operations
27-
OVERHEAD_VARIABLES = 5
2826

27+
# TODO: Join node, union node
28+
@dataclasses.dataclass(frozen=True)
29+
class SqlDataSource(nodes.LeafNode):
30+
source: bq_data.BigqueryDataSource
31+
32+
@functools.cached_property
33+
def fields(self) -> Sequence[nodes.Field]:
34+
return tuple(
35+
nodes.Field(
36+
identifiers.ColumnId(source_id),
37+
self.source.schema.get_type(source_id),
38+
self.source.table.schema_by_id[source_id].is_nullable,
39+
)
40+
for source_id in self.source.schema.names
41+
)
42+
43+
@property
44+
def variables_introduced(self) -> int:
45+
# This operation only renames variables, doesn't actually create new ones
46+
return 0
47+
48+
@property
49+
def defines_namespace(self) -> bool:
50+
return True
51+
52+
@property
53+
def explicitly_ordered(self) -> bool:
54+
return False
55+
56+
@property
57+
def order_ambiguous(self) -> bool:
58+
return True
59+
60+
@property
61+
def row_count(self) -> Optional[int]:
62+
return self.source.n_rows
63+
64+
@property
65+
def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
66+
return tuple(self.ids)
67+
68+
@property
69+
def consumed_ids(self):
70+
return ()
2971

30-
@dataclasses.dataclass(frozen=True, eq=True)
31-
class ColumnDef:
32-
expression: ex.Expression
33-
id: identifiers.ColumnId
72+
@property
73+
def _node_expressions(self):
74+
return ()
3475

76+
def remap_vars(
77+
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
78+
) -> SqlSelectNode:
79+
raise NotImplementedError()
3580

36-
# TODO: Raw data source node, join node, union node
81+
def remap_refs(
82+
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
83+
) -> SqlSelectNode:
84+
raise NotImplementedError() # type: ignore
3785

3886

3987
@dataclasses.dataclass(frozen=True)
40-
class SelectNode(nodes.UnaryNode):
88+
class SqlSelectNode(nodes.UnaryNode):
4189
selections: tuple[nodes.ColumnDef, ...] = ()
4290
predicates: tuple[ex.Expression, ...] = ()
4391
sorting: tuple[OrderingExpression, ...] = ()
@@ -106,10 +154,10 @@ def get_id_mapping(self) -> dict[identifiers.ColumnId, ex.Expression]:
106154

107155
def remap_vars(
108156
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
109-
) -> SelectNode:
157+
) -> SqlSelectNode:
110158
raise NotImplementedError()
111159

112160
def remap_refs(
113161
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
114-
) -> SelectNode:
162+
) -> SqlSelectNode:
115163
raise NotImplementedError() # type: ignore

0 commit comments

Comments
 (0)