Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 6a1107e

Browse files
refactor: Add cte factoring to new compiler
1 parent 61c17e3 commit 6a1107e

File tree

7 files changed

+257
-85
lines changed

7 files changed

+257
-85
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
4848
output_cols=output_names,
4949
limit=request.peek_count,
5050
)
51+
# Extract CTEs early, as later rewriters could otherwise make common subtrees diverge
52+
result_node = typing.cast(nodes.ResultNode, rewrite.extract_ctes(result_node))
5153
if request.sort_rows:
5254
# Can only pullup slice if we are doing ORDER BY in outermost SELECT
5355
# Need to do this before replacing unsupported ops, as that will rewrite slice ops
@@ -103,11 +105,21 @@ def _compile_result_node(root: nodes.ResultNode) -> str:
103105
root = _remap_variables(root, uid_gen)
104106
root = typing.cast(nodes.ResultNode, rewrite.defer_selection(root))
105107

108+
# TODO: Extract out CTEs to a with_ctes node?
109+
cte_nodes = _get_ctes(root)
110+
106111
# Have to bind schema as the final step before compilation.
107112
# Probably, should defer even further
108113
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
109114

110115
sqlglot_ir = compile_node(rewrite.as_sql_nodes(root), uid_gen)
116+
sqlglot_ir = sqlglot_ir.with_ctes(
117+
tuple(
118+
(compile_node(cte_node, uid_gen)._as_select(), cte_node.name)
119+
for cte_node in cte_nodes
120+
)
121+
)
122+
111123
return sqlglot_ir.sql
112124

113125

@@ -247,6 +259,31 @@ def compile_isin_join(
247259
)
248260

249261

262+
@_compile_node.register
263+
def compile_cte_node(node: nodes.CteRefNode, _)
264+
table = node.name
265+
return ir.SQLGlotIR.from_table(
266+
table.project_id,
267+
table.dataset_id,
268+
table.table_id,
269+
uid_gen=child.uid_gen,
270+
sql_predicate=node.source.sql_predicate,
271+
system_time=node.source.at_time,
272+
)
273+
274+
@_compile_node.register
275+
def compile_cte_node(node: nodes.CteNode, _)
276+
table = node.source.table
277+
return ir.SQLGlotIR.from_table(
278+
table.project_id,
279+
table.dataset_id,
280+
table.table_id,
281+
uid_gen=child.uid_gen,
282+
sql_predicate=node.source.sql_predicate,
283+
system_time=node.source.at_time,
284+
)
285+
286+
250287
@_compile_node.register
251288
def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlotIR:
252289
assert len(children) >= 1
@@ -312,3 +349,16 @@ def _replace_unsupported_ops(node: nodes.BigFrameNode):
312349
node = nodes.bottom_up(node, rewrite.rewrite_slice)
313350
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
314351
return node
352+
353+
354+
def _get_ctes(root: nodes.ResultNode) -> typing.Sequence[nodes.CteNode]:
355+
"""
356+
Get ctes from plan in topological order.
357+
"""
358+
359+
def merge_list(node, cte_list):
360+
if isinstance(node, nodes.CteNode):
361+
return (*cte_list, node)
362+
return cte_list
363+
364+
return root.reduce_up(merge_list)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 31 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def select(
165165
) -> SQLGlotIR:
166166
# TODO: Explicitly insert CTEs into plan
167167
if isinstance(self.expr, sge.Select):
168-
new_expr, _ = self._select_to_cte()
168+
new_expr, _ = self._as_from_item()
169169
else:
170170
new_expr = sge.Select().from_(self.expr)
171171

@@ -222,21 +222,8 @@ def from_union(
222222
assert (
223223
len(list(selects)) >= 2
224224
), f"At least two select expressions must be provided, but got {selects}."
225-
226-
existing_ctes: list[sge.CTE] = []
227-
union_selects: list[sge.Select] = []
228-
for select in selects:
229-
assert isinstance(
230-
select, sge.Select
231-
), f"All provided expressions must be of type sge.Select, but got {type(select)}"
232-
233-
select_expr = select.copy()
234-
select_expr, select_ctes = _pop_query_ctes(select_expr)
235-
existing_ctes = _merge_ctes(existing_ctes, select_ctes)
236-
union_selects.append(select_expr)
237-
238-
union_expr: sge.Query = union_selects[0].subquery()
239-
for select in union_selects[1:]:
225+
union_expr: sge.Query = selects[0].subquery()
226+
for select in selects[1:]:
240227
union_expr = sge.Union(
241228
this=union_expr,
242229
expression=select.subquery(),
@@ -254,7 +241,6 @@ def from_union(
254241
final_select_expr = (
255242
sge.Select().select(*selections).from_(union_expr.subquery())
256243
)
257-
final_select_expr = _set_query_ctes(final_select_expr, existing_ctes)
258244
return cls(expr=final_select_expr, uid_gen=uid_gen)
259245

260246
def join(
@@ -266,12 +252,8 @@ def join(
266252
joins_nulls: bool = True,
267253
) -> SQLGlotIR:
268254
"""Joins the current query with another SQLGlotIR instance."""
269-
left_select, left_cte_name = self._select_to_cte()
270-
right_select, right_cte_name = right._select_to_cte()
271-
272-
left_select, left_ctes = _pop_query_ctes(left_select)
273-
right_select, right_ctes = _pop_query_ctes(right_select)
274-
merged_ctes = _merge_ctes(left_ctes, right_ctes)
255+
left_from = self._as_from_item()
256+
right_from = right._as_from_item()
275257

276258
join_on = _and(
277259
tuple(
@@ -283,10 +265,9 @@ def join(
283265
new_expr = (
284266
sge.Select()
285267
.select(sge.Star())
286-
.from_(sge.Table(this=left_cte_name))
287-
.join(sge.Table(this=right_cte_name), on=join_on, join_type=join_type_str)
268+
.from_(left_from)
269+
.join(right_from, on=join_on, join_type=join_type_str)
288270
)
289-
new_expr = _set_query_ctes(new_expr, merged_ctes)
290271

291272
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
292273

@@ -298,16 +279,12 @@ def isin_join(
298279
joins_nulls: bool = True,
299280
) -> SQLGlotIR:
300281
"""Joins the current query with another SQLGlotIR instance."""
301-
left_select, left_cte_name = self._select_to_cte()
282+
left_from = self._as_from_item()
302283
# Prefer subquery over CTE for the IN clause's right side to improve SQL readability.
303284
right_select = right._as_select()
304285

305-
left_select, left_ctes = _pop_query_ctes(left_select)
306-
right_select, right_ctes = _pop_query_ctes(right_select)
307-
merged_ctes = _merge_ctes(left_ctes, right_ctes)
308-
309286
left_condition = typed_expr.TypedExpr(
310-
sge.Column(this=conditions[0].expr, table=left_cte_name),
287+
sge.Column(this=conditions[0].expr, table=left_from),
311288
conditions[0].dtype,
312289
)
313290

@@ -341,10 +318,9 @@ def isin_join(
341318

342319
new_expr = (
343320
sge.Select()
344-
.select(sge.Column(this=sge.Star(), table=left_cte_name), new_column)
345-
.from_(sge.Table(this=left_cte_name))
321+
.select(sge.Column(this=sge.Star(), table=left_from), new_column)
322+
.from_(left_from)
346323
)
347-
new_expr = _set_query_ctes(new_expr, merged_ctes)
348324

349325
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
350326

@@ -368,7 +344,7 @@ def sample(self, fraction: float) -> SQLGlotIR:
368344
expression=_literal(fraction, dtypes.FLOAT_DTYPE),
369345
)
370346

371-
new_expr = self._select_to_cte()[0].where(condition, append=False)
347+
new_expr = self._as_select().where(condition, append=False)
372348
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
373349

374350
def aggregate(
@@ -392,7 +368,7 @@ def aggregate(
392368
for id, expr in aggregations
393369
]
394370

395-
new_expr, _ = self._select_to_cte()
371+
new_expr = self._as_select()
396372
new_expr = new_expr.group_by(*by_cols).select(
397373
*[*by_cols, *aggregations_expr], append=False
398374
)
@@ -407,12 +383,26 @@ def aggregate(
407383
new_expr = new_expr.where(condition, append=False)
408384
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
409385

386+
def with_ctes(
387+
self,
388+
ctes: tuple[tuple[str, sge.Select], ...],
389+
) -> SQLGlotIR:
390+
sge_ctes = [
391+
sge.CTE(
392+
this=cte,
393+
alias=cte_name,
394+
)
395+
for cte_name, cte in ctes
396+
]
397+
select_expr = _set_query_ctes(self._as_select(), sge_ctes)
398+
return SQLGlotIR(expr=select_expr, uid_gen=self.uid_gen)
399+
410400
def insert(
411401
self,
412402
destination: bigquery.TableReference,
413403
) -> str:
414404
"""Generates an INSERT INTO SQL statement from the current SELECT clause."""
415-
return sge.insert(self._as_from_item(), _table(destination)).sql(
405+
return sge.insert(self._as_select(), _table(destination)).sql(
416406
dialect=self.dialect, pretty=self.pretty
417407
)
418408

@@ -436,7 +426,7 @@ def replace(
436426

437427
merge_str = sge.Merge(
438428
this=_table(destination),
439-
using=self._as_from_item(),
429+
using=self._as_select(),
440430
on=_literal(False, dtypes.BOOL_DTYPE),
441431
).sql(dialect=self.dialect, pretty=self.pretty)
442432
return f"{merge_str}\n{whens_str}"
@@ -459,7 +449,7 @@ def _explode_single_column(
459449
)
460450
selection = sge.Star(replace=[unnested_column_alias.as_(column)])
461451

462-
new_expr, _ = self._select_to_cte()
452+
new_expr = self._as_select()
463453
# Use LEFT JOIN to preserve rows when unnesting empty arrays.
464454
new_expr = new_expr.select(selection, append=False).join(
465455
unnest_expr, join_type="LEFT"
@@ -510,7 +500,7 @@ def _explode_multiple_columns(
510500
for column in columns
511501
]
512502
)
513-
new_expr, _ = self._select_to_cte()
503+
new_expr = self._as_select()
514504
# Use LEFT JOIN to preserve rows when unnesting empty arrays.
515505
new_expr = new_expr.select(selection, append=False).join(
516506
unnest_expr, join_type="LEFT"
@@ -532,25 +522,6 @@ def _as_select(self) -> sge.Select:
532522
def _as_subquery(self) -> sge.Subquery:
533523
return self._as_select().subquery()
534524

535-
def _select_to_cte(self) -> tuple[sge.Select, sge.Identifier]:
536-
"""Transforms a given sge.Select query by pushing its main SELECT statement
537-
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
538-
for the new query."""
539-
cte_name = sge.to_identifier(
540-
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
541-
)
542-
select_expr = self._as_select().copy()
543-
select_expr, existing_ctes = _pop_query_ctes(select_expr)
544-
new_cte = sge.CTE(
545-
this=select_expr,
546-
alias=cte_name,
547-
)
548-
new_select_expr = (
549-
sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
550-
)
551-
new_select_expr = _set_query_ctes(new_select_expr, [*existing_ctes, new_cte])
552-
return new_select_expr, cte_name
553-
554525

555526
def _is_null_literal(expr: sge.Expression) -> bool:
556527
"""Checks if the given expression is a NULL literal."""
@@ -743,26 +714,3 @@ def _set_query_ctes(
743714
else:
744715
raise ValueError("The expression does not support CTEs.")
745716
return new_expr
746-
747-
748-
def _merge_ctes(ctes1: list[sge.CTE], ctes2: list[sge.CTE]) -> list[sge.CTE]:
749-
"""Merges two lists of CTEs, de-duplicating by alias name."""
750-
seen = {cte.alias: cte for cte in ctes1}
751-
for cte in ctes2:
752-
if cte.alias not in seen:
753-
seen[cte.alias] = cte
754-
return list(seen.values())
755-
756-
757-
def _pop_query_ctes(
758-
expr: sge.Select,
759-
) -> tuple[sge.Select, list[sge.CTE]]:
760-
"""Pops the CTEs of a given sge.Select expression."""
761-
if "with" in expr.arg_types.keys():
762-
expr_ctes = expr.args.pop("with", [])
763-
return expr, expr_ctes
764-
elif "with_" in expr.arg_types.keys():
765-
expr_ctes = expr.args.pop("with_", [])
766-
return expr, expr_ctes
767-
else:
768-
raise ValueError("The expression does not support CTEs.")

bigframes/core/nodes.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,6 +1715,73 @@ def _node_expressions(self):
17151715
return tuple(ref for ref, _ in self.output_cols)
17161716

17171717

1718+
@dataclasses.dataclass(frozen=True, eq=False)
1719+
class CteRefNode(UnaryNode):
1720+
cols: tuple[ex.DerefOp, ...]
1721+
1722+
@property
1723+
def fields(self) -> Sequence[Field]:
1724+
# Fields property here is for output schema, not to be consumed by a parent node.
1725+
input_fields_by_id = {field.id: field for field in self.child.fields}
1726+
return tuple(input_fields_by_id[ref.id] for ref in self.cols)
1727+
1728+
@property
1729+
def variables_introduced(self) -> int:
1730+
# This operation only renames variables, doesn't actually create new ones
1731+
return 0
1732+
1733+
@property
1734+
def row_count(self) -> Optional[int]:
1735+
return self.child.row_count
1736+
1737+
@property
1738+
def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
1739+
return ()
1740+
1741+
def remap_vars(
1742+
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
1743+
) -> CteRefNode:
1744+
return self
1745+
1746+
def remap_refs(
1747+
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
1748+
) -> CteRefNode:
1749+
new_cols = tuple(id.remap_column_refs(mappings) for id in self.cols)
1750+
return dataclasses.replace(self, cols=new_cols)
1751+
1752+
1753+
@dataclasses.dataclass(frozen=True, eq=False)
1754+
class CteNode(UnaryNode):
1755+
name: str
1756+
1757+
@property
1758+
def fields(self) -> Sequence[Field]:
1759+
return self.child.fields
1760+
1761+
@property
1762+
def variables_introduced(self) -> int:
1763+
# This operation only renames variables, doesn't actually create new ones
1764+
return 0
1765+
1766+
@property
1767+
def row_count(self) -> Optional[int]:
1768+
return self.child.row_count
1769+
1770+
@property
1771+
def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
1772+
return ()
1773+
1774+
def remap_vars(
1775+
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
1776+
) -> CteNode:
1777+
return self
1778+
1779+
def remap_refs(
1780+
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
1781+
) -> CteNode:
1782+
return self
1783+
1784+
17181785
# Tree operators
17191786
def top_down(
17201787
root: BigFrameNode,

bigframes/core/rewrite/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from bigframes.core.rewrite.as_sql import as_sql_nodes
16+
from bigframes.core.rewrite.ctes import extract_ctes
1617
from bigframes.core.rewrite.fold_row_count import fold_row_counts
1718
from bigframes.core.rewrite.identifiers import remap_variables
1819
from bigframes.core.rewrite.implicit_align import try_row_join
@@ -34,6 +35,7 @@
3435

3536
__all__ = [
3637
"as_sql_nodes",
38+
"extract_ctes",
3739
"legacy_join_as_projection",
3840
"try_row_join",
3941
"rewrite_slice",

0 commit comments

Comments
 (0)