Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 24fdb3a

Browse files
clean up select vs table expr cases in ir
1 parent e92b318 commit 24fdb3a

File tree

10 files changed

+60
-89
lines changed

10 files changed

+60
-89
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ def _compile_result_node(root: nodes.ResultNode) -> str:
108108
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
109109

110110
sqlglot_ir = compile_node(rewrite.as_sql_nodes(root), uid_gen)
111-
print(sqlglot_ir.sql)
112111
return sqlglot_ir.sql
113112

114113

@@ -261,7 +260,7 @@ def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlo
261260
]
262261

263262
return ir.SQLGlotIR.from_union(
264-
[child.expr for child in children],
263+
[child._as_select() for child in children],
265264
output_aliases=output_aliases,
266265
uid_gen=uid_gen,
267266
)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 45 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
class SQLGlotIR:
4545
"""Helper class to build SQLGlot Query and generate SQL string."""
4646

47-
expr: sge.Select = sg.select()
47+
expr: typing.Union[sge.Select, sge.Table] = sg.select()
4848
"""The SQLGlot expression representing the query."""
4949

5050
dialect = sg.dialects.bigquery.BigQuery
@@ -163,15 +163,9 @@ def select(
163163
sorting: tuple[sge.Ordered, ...] = (),
164164
limit: typing.Optional[int] = None,
165165
) -> SQLGlotIR:
166-
167166
# TODO: Explicitly insert CTEs into plan
168167
if isinstance(self.expr, sge.Select):
169-
new_expr = _select_to_cte(
170-
self.expr,
171-
sge.to_identifier(
172-
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
173-
),
174-
)
168+
new_expr, _ = self._select_to_cte()
175169
else:
176170
new_expr = sge.Select().from_(self.expr)
177171

@@ -272,15 +266,8 @@ def join(
272266
joins_nulls: bool = True,
273267
) -> SQLGlotIR:
274268
"""Joins the current query with another SQLGlotIR instance."""
275-
left_cte_name = sge.to_identifier(
276-
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
277-
)
278-
right_cte_name = sge.to_identifier(
279-
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
280-
)
281-
282-
left_select = _select_to_cte(self.expr, left_cte_name)
283-
right_select = _select_to_cte(right.expr, right_cte_name)
269+
left_select, left_cte_name = self._select_to_cte()
270+
right_select, right_cte_name = self._select_to_cte()
284271

285272
left_select, left_ctes = _pop_query_ctes(left_select)
286273
right_select, right_ctes = _pop_query_ctes(right_select)
@@ -311,13 +298,9 @@ def isin_join(
311298
joins_nulls: bool = True,
312299
) -> SQLGlotIR:
313300
"""Joins the current query with another SQLGlotIR instance."""
314-
left_cte_name = sge.to_identifier(
315-
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
316-
)
317-
318-
left_select = _select_to_cte(self.expr, left_cte_name)
301+
left_select, left_cte_name = self._select_to_cte()
319302
# Prefer subquery over CTE for the IN clause's right side to improve SQL readability.
320-
right_select = right.expr
303+
right_select = right._as_select()
321304

322305
left_select, left_ctes = _pop_query_ctes(left_select)
323306
right_select, right_ctes = _pop_query_ctes(right_select)
@@ -380,21 +363,12 @@ def explode(
380363

381364
def sample(self, fraction: float) -> SQLGlotIR:
382365
"""Uniform samples a fraction of the rows."""
383-
uuid_col = sge.to_identifier(
384-
next(self.uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted
385-
)
386-
uuid_expr = sge.Alias(this=sge.func("RAND"), alias=uuid_col)
387366
condition = sge.LT(
388-
this=uuid_col,
367+
this=sge.func("RAND"),
389368
expression=_literal(fraction, dtypes.FLOAT_DTYPE),
390369
)
391370

392-
new_cte_name = sge.to_identifier(
393-
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
394-
)
395-
new_expr = _select_to_cte(
396-
self.expr.select(uuid_expr, append=True), new_cte_name
397-
).where(condition, append=False)
371+
new_expr = self._select_to_cte()[0].where(condition, append=False)
398372
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
399373

400374
def aggregate(
@@ -418,12 +392,7 @@ def aggregate(
418392
for id, expr in aggregations
419393
]
420394

421-
new_expr = _select_to_cte(
422-
self.expr,
423-
sge.to_identifier(
424-
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
425-
),
426-
)
395+
new_expr, _ = self._select_to_cte()
427396
new_expr = new_expr.group_by(*by_cols).select(
428397
*[*by_cols, *aggregations_expr], append=False
429398
)
@@ -443,7 +412,7 @@ def insert(
443412
destination: bigquery.TableReference,
444413
) -> str:
445414
"""Generates an INSERT INTO SQL statement from the current SELECT clause."""
446-
return sge.insert(self.expr.subquery(), _table(destination)).sql(
415+
return sge.insert(self._as_from_item(), _table(destination)).sql(
447416
dialect=self.dialect, pretty=self.pretty
448417
)
449418

@@ -467,7 +436,7 @@ def replace(
467436

468437
merge_str = sge.Merge(
469438
this=_table(destination),
470-
using=self.expr.subquery(),
439+
using=self._as_from_item(),
471440
on=_literal(False, dtypes.BOOL_DTYPE),
472441
).sql(dialect=self.dialect, pretty=self.pretty)
473442
return f"{merge_str}\n{whens_str}"
@@ -490,12 +459,7 @@ def _explode_single_column(
490459
)
491460
selection = sge.Star(replace=[unnested_column_alias.as_(column)])
492461

493-
new_expr = _select_to_cte(
494-
self.expr,
495-
sge.to_identifier(
496-
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
497-
),
498-
)
462+
new_expr, _ = self._select_to_cte()
499463
# Use LEFT JOIN to preserve rows when unnesting empty arrays.
500464
new_expr = new_expr.select(selection, append=False).join(
501465
unnest_expr, join_type="LEFT"
@@ -546,32 +510,46 @@ def _explode_multiple_columns(
546510
for column in columns
547511
]
548512
)
549-
new_expr = _select_to_cte(
550-
self.expr,
551-
sge.to_identifier(
552-
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
553-
),
554-
)
513+
new_expr, _ = self._select_to_cte()
555514
# Use LEFT JOIN to preserve rows when unnesting empty arrays.
556515
new_expr = new_expr.select(selection, append=False).join(
557516
unnest_expr, join_type="LEFT"
558517
)
559518
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
560519

520+
def _as_from_item(self) -> typing.Union[sge.Table, sge.Subquery]:
521+
if isinstance(self.expr, sge.Select):
522+
return self.expr.subquery()
523+
else: # table
524+
return self.expr
561525

562-
def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
563-
"""Transforms a given sge.Select query by pushing its main SELECT statement
564-
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
565-
for the new query."""
566-
select_expr = expr.copy()
567-
select_expr, existing_ctes = _pop_query_ctes(select_expr)
568-
new_cte = sge.CTE(
569-
this=select_expr,
570-
alias=cte_name,
571-
)
572-
new_select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
573-
new_select_expr = _set_query_ctes(new_select_expr, [*existing_ctes, new_cte])
574-
return new_select_expr
526+
def _as_select(self) -> sge.Select:
527+
if isinstance(self.expr, sge.Select):
528+
return self.expr
529+
else: # table
530+
return sge.Select().from_(self.expr)
531+
532+
def _as_subquery(self) -> sge.Subquery:
533+
return self._as_select().subquery()
534+
535+
def _select_to_cte(self) -> tuple[sge.Select, sge.Identifier]:
536+
"""Transforms a given sge.Select query by pushing its main SELECT statement
537+
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
538+
for the new query."""
539+
cte_name = sge.to_identifier(
540+
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
541+
)
542+
select_expr = self.expr._as_select().copy()
543+
select_expr, existing_ctes = _pop_query_ctes(select_expr)
544+
new_cte = sge.CTE(
545+
this=select_expr,
546+
alias=cte_name,
547+
)
548+
new_select_expr = (
549+
sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
550+
)
551+
new_select_expr = _set_query_ctes(new_select_expr, [*existing_ctes, new_cte])
552+
return new_select_expr, cte_name
575553

576554

577555
def _is_null_literal(expr: sge.Expression) -> bool:

bigframes/core/sql_nodes.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,6 @@ class SqlSelectNode(nodes.UnaryNode):
9191
sorting: tuple[OrderingExpression, ...] = ()
9292
limit: Optional[int] = None
9393

94-
def __post_init__(self):
95-
try:
96-
self.fields
97-
except Exception:
98-
...
99-
10094
@functools.cached_property
10195
def fields(self) -> Sequence[nodes.Field]:
10296
fields = []

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ WITH `bfcte_0` AS (
55
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
66
), `bfcte_1` AS (
77
SELECT
8-
`int64_col` AS `bfcol_6`,
9-
`int64_too` AS `bfcol_7`
8+
`rowindex` AS `bfcol_2`,
9+
`int64_col` AS `bfcol_3`
1010
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
1111
), `bfcte_2` AS (
1212
SELECT

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ WITH `bfcte_0` AS (
55
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
66
), `bfcte_1` AS (
77
SELECT
8-
`rowindex` AS `bfcol_6`,
9-
`bool_col` AS `bfcol_7`
8+
`rowindex` AS `bfcol_2`,
9+
`bool_col` AS `bfcol_3`
1010
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
1111
), `bfcte_2` AS (
1212
SELECT

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ WITH `bfcte_0` AS (
55
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
66
), `bfcte_1` AS (
77
SELECT
8-
`rowindex` AS `bfcol_6`,
9-
`float64_col` AS `bfcol_7`
8+
`rowindex` AS `bfcol_2`,
9+
`float64_col` AS `bfcol_3`
1010
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
1111
), `bfcte_2` AS (
1212
SELECT

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ WITH `bfcte_0` AS (
55
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
66
), `bfcte_1` AS (
77
SELECT
8-
`rowindex` AS `bfcol_6`,
9-
`int64_col` AS `bfcol_7`
8+
`rowindex` AS `bfcol_2`,
9+
`int64_col` AS `bfcol_3`
1010
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
1111
), `bfcte_2` AS (
1212
SELECT

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ WITH `bfcte_0` AS (
55
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
66
), `bfcte_1` AS (
77
SELECT
8-
`rowindex` AS `bfcol_6`,
9-
`numeric_col` AS `bfcol_7`
8+
`rowindex` AS `bfcol_2`,
9+
`numeric_col` AS `bfcol_3`
1010
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
1111
), `bfcte_2` AS (
1212
SELECT

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ WITH `bfcte_0` AS (
55
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
66
), `bfcte_1` AS (
77
SELECT
8-
`rowindex` AS `bfcol_4`,
9-
`string_col` AS `bfcol_5`
8+
`rowindex` AS `bfcol_0`,
9+
`string_col` AS `bfcol_1`
1010
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
1111
), `bfcte_2` AS (
1212
SELECT

tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ WITH `bfcte_0` AS (
55
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
66
), `bfcte_1` AS (
77
SELECT
8-
`rowindex` AS `bfcol_4`,
9-
`time_col` AS `bfcol_5`
8+
`rowindex` AS `bfcol_0`,
9+
`time_col` AS `bfcol_1`
1010
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
1111
), `bfcte_2` AS (
1212
SELECT

0 commit comments

Comments
 (0)