Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 10a9798

Browse files
separate logical cte nodes from concrete sql ones
1 parent 2d9dad5 commit 10a9798

File tree

7 files changed

+164
-114
lines changed

7 files changed

+164
-114
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,21 +106,11 @@ def _compile_result_node(root: nodes.ResultNode) -> str:
106106
root = _remap_variables(root, uid_gen)
107107
root = typing.cast(nodes.ResultNode, rewrite.defer_selection(root))
108108

109-
# TODO: Extract out CTEs to a with_ctes node?
110-
cte_nodes = _get_ctes(root)
111-
112109
# Have to bind schema as the final step before compilation.
113110
# Probably, should defer even further
114111
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
115112

116113
sqlglot_ir_obj = compile_node(rewrite.as_sql_nodes(root), uid_gen)
117-
sqlglot_ir_obj = sqlglot_ir_obj.with_ctes(
118-
tuple(
119-
(compile_node(cte_node, uid_gen)._as_select(), cte_node.name)
120-
for cte_node in cte_nodes
121-
)
122-
)
123-
124114
return sqlglot_ir_obj.sql
125115

126116

@@ -269,16 +259,20 @@ def compile_isin_join(
269259

270260

271261
@_compile_node.register
272-
def compile_cte_ref_node(node: nodes.CteRefNode, child: sqlglot_ir.SQLGlotIR):
262+
def compile_cte_ref_node(node: sql_nodes.SqlCteRefNode, child: sqlglot_ir.SQLGlotIR):
273263
return sqlglot_ir.SQLGlotIR.from_cte_ref(
274-
node.child.name, # type: ignore
264+
node.cte_name,
275265
uid_gen=child.uid_gen,
276266
)
277267

278268

279269
@_compile_node.register
280-
def compile_cte_node(node: nodes.CteNode, child: sqlglot_ir.SQLGlotIR):
281-
raise ValueError("CTE definitions should not be directly compiled")
270+
def compile_with_ctes_node(
271+
node: sql_nodes.SqlWithCtesNode,
272+
child: sqlglot_ir.SQLGlotIR,
273+
*ctes: sqlglot_ir.SQLGlotIR,
274+
):
275+
return child.with_ctes(tuple(zip(node.cte_names, ctes)))
282276

283277

284278
@_compile_node.register

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,11 +386,11 @@ def aggregate(
386386

387387
def with_ctes(
388388
self,
389-
ctes: tuple[tuple[str, sge.Select], ...],
389+
ctes: tuple[tuple[str, SQLGlotIR], ...],
390390
) -> SQLGlotIR:
391391
sge_ctes = [
392392
sge.CTE(
393-
this=cte,
393+
this=cte._as_select(),
394394
alias=cte_name,
395395
)
396396
for cte_name, cte in ctes

bigframes/core/nodes.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,52 +1713,18 @@ def _node_expressions(self):
17131713
return tuple(ref for ref, _ in self.output_cols)
17141714

17151715

1716-
@dataclasses.dataclass(frozen=True, eq=False)
1717-
class CteRefNode(UnaryNode):
1718-
cols: tuple[ex.DerefOp, ...]
1719-
1720-
@property
1721-
def fields(self) -> Sequence[Field]:
1722-
# Fields property here is for output schema, not to be consumed by a parent node.
1723-
input_fields_by_id = {field.id: field for field in self.child.fields}
1724-
return tuple(input_fields_by_id[ref.id] for ref in self.cols)
1725-
1726-
@property
1727-
def variables_introduced(self) -> int:
1728-
# This operation only renames variables, doesn't actually create new ones
1729-
return 0
1730-
1731-
@property
1732-
def row_count(self) -> Optional[int]:
1733-
return self.child.row_count
1734-
1735-
@property
1736-
def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
1737-
return ()
1738-
1739-
def remap_vars(
1740-
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
1741-
) -> CteRefNode:
1742-
return self
1743-
1744-
def remap_refs(
1745-
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
1746-
) -> CteRefNode:
1747-
new_cols = tuple(id.remap_column_refs(mappings) for id in self.cols)
1748-
return dataclasses.replace(self, cols=new_cols)
1749-
1750-
17511716
@dataclasses.dataclass(frozen=True, eq=False)
17521717
class CteNode(UnaryNode):
1753-
name: str
1718+
"""
1719+
Semantically a no-op, used to indicate shared subtrees and act as optimization boundary.
1720+
"""
17541721

17551722
@property
17561723
def fields(self) -> Sequence[Field]:
17571724
return self.child.fields
17581725

17591726
@property
17601727
def variables_introduced(self) -> int:
1761-
# This operation only renames variables, doesn't actually create new ones
17621728
return 0
17631729

17641730
@property

bigframes/core/rewrite/as_sql.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,26 @@ def _as_sql_node(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
222222
return node
223223

224224

225+
def _extract_ctes(root: nodes.BigFrameNode) -> nodes.BigFrameNode:
226+
topological_ctes = list(
227+
filter(lambda n: isinstance(n, nodes.CteNode), root.iter_nodes_topo())
228+
)
229+
cte_names = tuple(f"cte_{i}" for i in range(len(topological_ctes)))
230+
231+
mapping = {
232+
cte_node: sql_nodes.SqlCteRefNode(cte_name, tuple(cte_node.fields))
233+
for cte_node, cte_name in zip(topological_ctes, cte_names)
234+
}
235+
236+
# Replace all CTEs with CTE references and wrap the new root in a WITH clause
237+
return sql_nodes.SqlWithCtesNode(
238+
root.top_down(lambda x: mapping.get(x, x)),
239+
cte_names,
240+
tuple(cte.top_down(lambda x: mapping.get(x, x)) for cte in topological_ctes),
241+
)
242+
243+
225244
def as_sql_nodes(root: nodes.BigFrameNode) -> nodes.BigFrameNode:
226245
# TODO: Aggregations, Unions, Joins, raw data sources
246+
root = _extract_ctes(root)
227247
return nodes.bottom_up(root, _as_sql_node)

bigframes/core/rewrite/ctes.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,24 @@
1515

1616
from collections import defaultdict
1717

18-
from bigframes.core import expression, nodes
18+
from bigframes.core import nodes
1919

2020

2121
def extract_ctes(root: nodes.BigFrameNode) -> nodes.BigFrameNode:
2222
# identify candidates
23-
# candidates
2423
node_parents: dict[nodes.BigFrameNode, int] = defaultdict(int)
2524
for parent in root.unique_nodes():
2625
for child in parent.child_nodes:
2726
node_parents[child] += 1
2827

2928
counter = 0
30-
# ok time to replace via extract
29+
3130
# we just mark in place, rather than pull out of the tree.
32-
# if we did pull out of tree, we'd want to make sure to extract bottom-up
3331
def insert_cte_markers(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
3432
nonlocal counter
3533
if node_parents[node] > 1:
3634
counter += 1
37-
return nodes.CteRefNode(
38-
nodes.CteNode(node, name=f"cte_{counter})"),
39-
cols=tuple(expression.DerefOp(id) for id in node.ids),
40-
)
35+
return nodes.CteNode(node)
4136
return node
4237

4338
return root.top_down(insert_cte_markers)

bigframes/core/rewrite/pruning.py

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import collections
1514
import dataclasses
1615
import functools
1716
import itertools
@@ -23,39 +22,7 @@
2322
def column_pruning(
2423
root: nodes.BigFrameNode,
2524
) -> nodes.BigFrameNode:
26-
# We wrap the entire process in a fixed-point iteration to ensure
27-
# the global push-down through CTEs fully settles.
28-
@to_fixed(max_iterations=100)
29-
def prune_tree(current_root: nodes.BigFrameNode) -> nodes.BigFrameNode:
30-
# Apply local top-down pruning rules (pushes selections to CteRefNodes)
31-
pushed_root = nodes.top_down(current_root, prune_columns)
32-
33-
# Gather the union of required columns globally across all CTE refs
34-
cte_reqs: typing.DefaultDict[
35-
nodes.BigFrameNode, set[identifiers.ColumnId]
36-
] = collections.defaultdict(set)
37-
38-
def gather_reqs(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
39-
if isinstance(node, nodes.CteRefNode):
40-
# node.child is the referenced CTE definition (CteNode)
41-
for col in node.cols:
42-
cte_reqs[node.child].add(col.id)
43-
return node
44-
45-
nodes.top_down(pushed_root, gather_reqs)
46-
47-
# Apply the unioned required columns to the CTE definitions
48-
def apply_cte_reqs(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
49-
if isinstance(node, nodes.CteNode) and node in cte_reqs:
50-
needed_ids = frozenset(cte_reqs[node])
51-
pruned_child = prune_node(node.child, needed_ids)
52-
if pruned_child is not node.child:
53-
return node.replace_child(pruned_child)
54-
return node
55-
56-
return nodes.top_down(pushed_root, apply_cte_reqs)
57-
58-
return prune_tree(root)
25+
return nodes.top_down(root, prune_columns)
5926

6027

6128
def to_fixed(max_iterations: int = 100):
@@ -110,24 +77,6 @@ def prune_selection_child(
11077
{id: ref.id for ref, id in child.input_output_pairs}
11178
).replace_child(child.child)
11279

113-
elif isinstance(child, nodes.CteRefNode):
114-
# Push selection locally into the CTE Reference
115-
needed_ids = selection.consumed_ids
116-
new_cols = (
117-
tuple(col for col in child.cols if col.id in needed_ids) or child.cols[0:1]
118-
)
119-
120-
if new_cols == child.cols:
121-
return selection
122-
return selection.replace_child(dataclasses.replace(child, cols=new_cols))
123-
124-
elif isinstance(child, nodes.CteNode):
125-
# Pass-through selection to the CTE Definition just in case it's wrapped locally
126-
pruned_child = prune_node(child.child, selection.consumed_ids)
127-
if pruned_child is child.child:
128-
return selection
129-
return selection.replace_child(child.replace_child(pruned_child))
130-
13180
elif isinstance(child, nodes.AdditiveNode):
13281
if not set(field.id for field in child.added_fields) & selection.consumed_ids:
13382
return selection.replace_child(child.additive_base)

bigframes/core/sql_nodes.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,18 @@
1616

1717
import dataclasses
1818
import functools
19-
from typing import Mapping, Optional, Sequence, Tuple
19+
from typing import Callable, Mapping, Optional, Sequence, Tuple
2020

2121
from bigframes.core import bq_data, identifiers, nodes
2222
import bigframes.core.expression as ex
2323
from bigframes.core.ordering import OrderingExpression
2424
import bigframes.dtypes
2525

26+
# SQL Nodes are generally terminal, so don't support rich transformation methods
27+
# like remap_vars, remap_refs, etc.
28+
# Still, fields should be defined on them, as typing info is still used for
29+
# dispatching some operators in the emitter, and for validation.
30+
2631

2732
# TODO: Join node, union node
2833
@dataclasses.dataclass(frozen=True)
@@ -84,6 +89,127 @@ def remap_refs(
8489
raise NotImplementedError() # type: ignore
8590

8691

92+
@dataclasses.dataclass(frozen=True)
93+
class SqlWithCtesNode(nodes.BigFrameNode):
94+
# def, name pairs
95+
child: nodes.BigFrameNode
96+
cte_names: tuple[str, ...]
97+
cte_defs: tuple[nodes.BigFrameNode, ...]
98+
99+
@property
100+
def child_nodes(self) -> Sequence[nodes.BigFrameNode]:
101+
return (self.child, *self.cte_defs)
102+
103+
@property
104+
def fields(self) -> Sequence[nodes.Field]:
105+
return self.child.fields
106+
107+
@property
108+
def variables_introduced(self) -> int:
109+
# This operation only renames variables, doesn't actually create new ones
110+
return 0
111+
112+
@property
113+
def defines_namespace(self) -> bool:
114+
return True
115+
116+
@property
117+
def explicitly_ordered(self) -> bool:
118+
return False
119+
120+
@property
121+
def order_ambiguous(self) -> bool:
122+
return True
123+
124+
@property
125+
def row_count(self) -> Optional[int]:
126+
return self.child.row_count
127+
128+
@property
129+
def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
130+
return tuple(self.ids)
131+
132+
@property
133+
def consumed_ids(self):
134+
return ()
135+
136+
@property
137+
def _node_expressions(self):
138+
return ()
139+
140+
def remap_vars(
141+
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
142+
) -> SqlWithCtesNode:
143+
raise NotImplementedError()
144+
145+
def remap_refs(
146+
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
147+
) -> SqlWithCtesNode:
148+
raise NotImplementedError() # type: ignore
149+
150+
def transform_children(
151+
self, transform: Callable[[nodes.BigFrameNode], nodes.BigFrameNode]
152+
) -> SqlWithCtesNode:
153+
return SqlWithCtesNode(
154+
transform(self.child),
155+
self.cte_names,
156+
tuple(transform(cte) for cte in self.cte_defs),
157+
)
158+
159+
160+
@dataclasses.dataclass(frozen=True)
161+
class SqlCteRefNode(nodes.LeafNode):
162+
cte_name: str
163+
cte_schema: tuple[nodes.Field, ...]
164+
165+
@property
166+
def fields(self) -> Sequence[nodes.Field]:
167+
return self.cte_schema
168+
169+
@property
170+
def variables_introduced(self) -> int:
171+
# This operation only renames variables, doesn't actually create new ones
172+
return 0
173+
174+
@property
175+
def defines_namespace(self) -> bool:
176+
return True
177+
178+
@property
179+
def explicitly_ordered(self) -> bool:
180+
return False
181+
182+
@property
183+
def order_ambiguous(self) -> bool:
184+
return True
185+
186+
@property
187+
def row_count(self) -> Optional[int]:
188+
raise NotImplementedError()
189+
190+
@property
191+
def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
192+
return tuple(self.ids)
193+
194+
@property
195+
def consumed_ids(self):
196+
return ()
197+
198+
@property
199+
def _node_expressions(self):
200+
return ()
201+
202+
def remap_vars(
203+
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
204+
) -> SqlCteRefNode:
205+
raise NotImplementedError()
206+
207+
def remap_refs(
208+
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
209+
) -> SqlCteRefNode:
210+
raise NotImplementedError() # type: ignore
211+
212+
87213
@dataclasses.dataclass(frozen=True)
88214
class SqlSelectNode(nodes.UnaryNode):
89215
selections: tuple[nodes.ColumnDef, ...] = ()

0 commit comments

Comments
 (0)