4444class SQLGlotIR :
4545 """Helper class to build SQLGlot Query and generate SQL string."""
4646
47- expr : sge .Select = sg .select ()
47+ expr : typing . Union [ sge .Select , sge . Table ] = sg .select ()
4848 """The SQLGlot expression representing the query."""
4949
5050 dialect = sg .dialects .bigquery .BigQuery
@@ -163,15 +163,9 @@ def select(
163163 sorting : tuple [sge .Ordered , ...] = (),
164164 limit : typing .Optional [int ] = None ,
165165 ) -> SQLGlotIR :
166-
167166 # TODO: Explicitly insert CTEs into plan
168167 if isinstance (self .expr , sge .Select ):
169- new_expr = _select_to_cte (
170- self .expr ,
171- sge .to_identifier (
172- next (self .uid_gen .get_uid_stream ("bfcte_" )), quoted = self .quoted
173- ),
174- )
168+ new_expr , _ = self ._select_to_cte ()
175169 else :
176170 new_expr = sge .Select ().from_ (self .expr )
177171
@@ -272,15 +266,8 @@ def join(
272266 joins_nulls : bool = True ,
273267 ) -> SQLGlotIR :
274268 """Joins the current query with another SQLGlotIR instance."""
275- left_cte_name = sge .to_identifier (
276- next (self .uid_gen .get_uid_stream ("bfcte_" )), quoted = self .quoted
277- )
278- right_cte_name = sge .to_identifier (
279- next (self .uid_gen .get_uid_stream ("bfcte_" )), quoted = self .quoted
280- )
281-
282- left_select = _select_to_cte (self .expr , left_cte_name )
283- right_select = _select_to_cte (right .expr , right_cte_name )
269+ left_select , left_cte_name = self ._select_to_cte ()
270+ right_select , right_cte_name = self ._select_to_cte ()
284271
285272 left_select , left_ctes = _pop_query_ctes (left_select )
286273 right_select , right_ctes = _pop_query_ctes (right_select )
@@ -311,13 +298,9 @@ def isin_join(
311298 joins_nulls : bool = True ,
312299 ) -> SQLGlotIR :
313300 """Joins the current query with another SQLGlotIR instance."""
314- left_cte_name = sge .to_identifier (
315- next (self .uid_gen .get_uid_stream ("bfcte_" )), quoted = self .quoted
316- )
317-
318- left_select = _select_to_cte (self .expr , left_cte_name )
301+ left_select , left_cte_name = self ._select_to_cte ()
319302 # Prefer subquery over CTE for the IN clause's right side to improve SQL readability.
320- right_select = right .expr
303+ right_select = right ._as_select ()
321304
322305 left_select , left_ctes = _pop_query_ctes (left_select )
323306 right_select , right_ctes = _pop_query_ctes (right_select )
@@ -380,21 +363,12 @@ def explode(
380363
381364 def sample (self , fraction : float ) -> SQLGlotIR :
382365 """Uniform samples a fraction of the rows."""
383- uuid_col = sge .to_identifier (
384- next (self .uid_gen .get_uid_stream ("bfcol_" )), quoted = self .quoted
385- )
386- uuid_expr = sge .Alias (this = sge .func ("RAND" ), alias = uuid_col )
387366 condition = sge .LT (
388- this = uuid_col ,
367+ this = sge . func ( "RAND" ) ,
389368 expression = _literal (fraction , dtypes .FLOAT_DTYPE ),
390369 )
391370
392- new_cte_name = sge .to_identifier (
393- next (self .uid_gen .get_uid_stream ("bfcte_" )), quoted = self .quoted
394- )
395- new_expr = _select_to_cte (
396- self .expr .select (uuid_expr , append = True ), new_cte_name
397- ).where (condition , append = False )
371+ new_expr = self ._select_to_cte ()[0 ].where (condition , append = False )
398372 return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
399373
400374 def aggregate (
@@ -418,12 +392,7 @@ def aggregate(
418392 for id , expr in aggregations
419393 ]
420394
421- new_expr = _select_to_cte (
422- self .expr ,
423- sge .to_identifier (
424- next (self .uid_gen .get_uid_stream ("bfcte_" )), quoted = self .quoted
425- ),
426- )
395+ new_expr , _ = self ._select_to_cte ()
427396 new_expr = new_expr .group_by (* by_cols ).select (
428397 * [* by_cols , * aggregations_expr ], append = False
429398 )
@@ -443,7 +412,7 @@ def insert(
443412 destination : bigquery .TableReference ,
444413 ) -> str :
445414 """Generates an INSERT INTO SQL statement from the current SELECT clause."""
446- return sge .insert (self .expr . subquery (), _table (destination )).sql (
415+ return sge .insert (self ._as_from_item (), _table (destination )).sql (
447416 dialect = self .dialect , pretty = self .pretty
448417 )
449418
@@ -467,7 +436,7 @@ def replace(
467436
468437 merge_str = sge .Merge (
469438 this = _table (destination ),
470- using = self .expr . subquery (),
439+ using = self ._as_from_item (),
471440 on = _literal (False , dtypes .BOOL_DTYPE ),
472441 ).sql (dialect = self .dialect , pretty = self .pretty )
473442 return f"{ merge_str } \n { whens_str } "
@@ -490,12 +459,7 @@ def _explode_single_column(
490459 )
491460 selection = sge .Star (replace = [unnested_column_alias .as_ (column )])
492461
493- new_expr = _select_to_cte (
494- self .expr ,
495- sge .to_identifier (
496- next (self .uid_gen .get_uid_stream ("bfcte_" )), quoted = self .quoted
497- ),
498- )
462+ new_expr , _ = self ._select_to_cte ()
499463 # Use LEFT JOIN to preserve rows when unnesting empty arrays.
500464 new_expr = new_expr .select (selection , append = False ).join (
501465 unnest_expr , join_type = "LEFT"
@@ -546,32 +510,46 @@ def _explode_multiple_columns(
546510 for column in columns
547511 ]
548512 )
549- new_expr = _select_to_cte (
550- self .expr ,
551- sge .to_identifier (
552- next (self .uid_gen .get_uid_stream ("bfcte_" )), quoted = self .quoted
553- ),
554- )
513+ new_expr , _ = self ._select_to_cte ()
555514 # Use LEFT JOIN to preserve rows when unnesting empty arrays.
556515 new_expr = new_expr .select (selection , append = False ).join (
557516 unnest_expr , join_type = "LEFT"
558517 )
559518 return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
560519
520+ def _as_from_item (self ) -> typing .Union [sge .Table , sge .Subquery ]:
521+ if isinstance (self .expr , sge .Select ):
522+ return self .expr .subquery ()
523+ else : # table
524+ return self .expr
561525
562- def _select_to_cte (expr : sge .Select , cte_name : sge .Identifier ) -> sge .Select :
563- """Transforms a given sge.Select query by pushing its main SELECT statement
564- into a new CTE and then generates a 'SELECT * FROM new_cte_name'
565- for the new query."""
566- select_expr = expr .copy ()
567- select_expr , existing_ctes = _pop_query_ctes (select_expr )
568- new_cte = sge .CTE (
569- this = select_expr ,
570- alias = cte_name ,
571- )
572- new_select_expr = sge .Select ().select (sge .Star ()).from_ (sge .Table (this = cte_name ))
573- new_select_expr = _set_query_ctes (new_select_expr , [* existing_ctes , new_cte ])
574- return new_select_expr
526+ def _as_select (self ) -> sge .Select :
527+ if isinstance (self .expr , sge .Select ):
528+ return self .expr
529+ else : # table
530+ return sge .Select ().from_ (self .expr )
531+
532+ def _as_subquery (self ) -> sge .Subquery :
533+ return self ._as_select ().subquery ()
534+
535+ def _select_to_cte (self ) -> tuple [sge .Select , sge .Identifier ]:
536+ """Transforms a given sge.Select query by pushing its main SELECT statement
537+ into a new CTE and then generates a 'SELECT * FROM new_cte_name'
538+ for the new query."""
539+ cte_name = sge .to_identifier (
540+ next (self .uid_gen .get_uid_stream ("bfcte_" )), quoted = self .quoted
541+ )
542+ select_expr = self .expr ._as_select ().copy ()
543+ select_expr , existing_ctes = _pop_query_ctes (select_expr )
544+ new_cte = sge .CTE (
545+ this = select_expr ,
546+ alias = cte_name ,
547+ )
548+ new_select_expr = (
549+ sge .Select ().select (sge .Star ()).from_ (sge .Table (this = cte_name ))
550+ )
551+ new_select_expr = _set_query_ctes (new_select_expr , [* existing_ctes , new_cte ])
552+ return new_select_expr , cte_name
575553
576554
577555def _is_null_literal (expr : sge .Expression ) -> bool :
0 commit comments