@@ -165,7 +165,7 @@ def select(
165165 ) -> SQLGlotIR :
166166 # TODO: Explicitly insert CTEs into plan
167167 if isinstance (self .expr , sge .Select ):
168- new_expr , _ = self ._select_to_cte ()
168+ new_expr , _ = self ._as_from_item ()
169169 else :
170170 new_expr = sge .Select ().from_ (self .expr )
171171
@@ -222,21 +222,8 @@ def from_union(
222222 assert (
223223 len (list (selects )) >= 2
224224 ), f"At least two select expressions must be provided, but got { selects } ."
225-
226- existing_ctes : list [sge .CTE ] = []
227- union_selects : list [sge .Select ] = []
228- for select in selects :
229- assert isinstance (
230- select , sge .Select
231- ), f"All provided expressions must be of type sge.Select, but got { type (select )} "
232-
233- select_expr = select .copy ()
234- select_expr , select_ctes = _pop_query_ctes (select_expr )
235- existing_ctes = _merge_ctes (existing_ctes , select_ctes )
236- union_selects .append (select_expr )
237-
238- union_expr : sge .Query = union_selects [0 ].subquery ()
239- for select in union_selects [1 :]:
225+ union_expr : sge .Query = selects [0 ].subquery ()
226+ for select in selects [1 :]:
240227 union_expr = sge .Union (
241228 this = union_expr ,
242229 expression = select .subquery (),
@@ -254,7 +241,6 @@ def from_union(
254241 final_select_expr = (
255242 sge .Select ().select (* selections ).from_ (union_expr .subquery ())
256243 )
257- final_select_expr = _set_query_ctes (final_select_expr , existing_ctes )
258244 return cls (expr = final_select_expr , uid_gen = uid_gen )
259245
260246 def join (
@@ -266,12 +252,8 @@ def join(
266252 joins_nulls : bool = True ,
267253 ) -> SQLGlotIR :
268254 """Joins the current query with another SQLGlotIR instance."""
269- left_select , left_cte_name = self ._select_to_cte ()
270- right_select , right_cte_name = right ._select_to_cte ()
271-
272- left_select , left_ctes = _pop_query_ctes (left_select )
273- right_select , right_ctes = _pop_query_ctes (right_select )
274- merged_ctes = _merge_ctes (left_ctes , right_ctes )
255+ left_from = self ._as_from_item ()
256+ right_from = right ._as_from_item ()
275257
276258 join_on = _and (
277259 tuple (
@@ -283,10 +265,9 @@ def join(
283265 new_expr = (
284266 sge .Select ()
285267 .select (sge .Star ())
286- .from_ (sge . Table ( this = left_cte_name ) )
287- .join (sge . Table ( this = right_cte_name ) , on = join_on , join_type = join_type_str )
268+ .from_ (left_from )
269+ .join (right_from , on = join_on , join_type = join_type_str )
288270 )
289- new_expr = _set_query_ctes (new_expr , merged_ctes )
290271
291272 return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
292273
@@ -298,16 +279,12 @@ def isin_join(
298279 joins_nulls : bool = True ,
299280 ) -> SQLGlotIR :
300281 """Joins the current query with another SQLGlotIR instance."""
301- left_select , left_cte_name = self ._select_to_cte ()
282+ left_from = self ._as_from_item ()
302283 # Prefer subquery over CTE for the IN clause's right side to improve SQL readability.
303284 right_select = right ._as_select ()
304285
305- left_select , left_ctes = _pop_query_ctes (left_select )
306- right_select , right_ctes = _pop_query_ctes (right_select )
307- merged_ctes = _merge_ctes (left_ctes , right_ctes )
308-
309286 left_condition = typed_expr .TypedExpr (
310- sge .Column (this = conditions [0 ].expr , table = left_cte_name ),
287+ sge .Column (this = conditions [0 ].expr , table = left_from ),
311288 conditions [0 ].dtype ,
312289 )
313290
@@ -341,10 +318,9 @@ def isin_join(
341318
342319 new_expr = (
343320 sge .Select ()
344- .select (sge .Column (this = sge .Star (), table = left_cte_name ), new_column )
345- .from_ (sge . Table ( this = left_cte_name ) )
321+ .select (sge .Column (this = sge .Star (), table = left_from ), new_column )
322+ .from_ (left_from )
346323 )
347- new_expr = _set_query_ctes (new_expr , merged_ctes )
348324
349325 return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
350326
@@ -368,7 +344,7 @@ def sample(self, fraction: float) -> SQLGlotIR:
368344 expression = _literal (fraction , dtypes .FLOAT_DTYPE ),
369345 )
370346
371- new_expr = self ._select_to_cte ()[ 0 ] .where (condition , append = False )
347+ new_expr = self ._as_select () .where (condition , append = False )
372348 return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
373349
374350 def aggregate (
@@ -392,7 +368,7 @@ def aggregate(
392368 for id , expr in aggregations
393369 ]
394370
395- new_expr , _ = self ._select_to_cte ()
371+ new_expr = self ._as_select ()
396372 new_expr = new_expr .group_by (* by_cols ).select (
397373 * [* by_cols , * aggregations_expr ], append = False
398374 )
@@ -407,12 +383,26 @@ def aggregate(
407383 new_expr = new_expr .where (condition , append = False )
408384 return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
409385
386+ def with_ctes (
387+ self ,
388+ ctes : tuple [tuple [str , sge .Select ], ...],
389+ ) -> SQLGlotIR :
390+ sge_ctes = [
391+ sge .CTE (
392+ this = cte ,
393+ alias = cte_name ,
394+ )
395+ for cte_name , cte in ctes
396+ ]
397+ select_expr = _set_query_ctes (self ._as_select (), sge_ctes )
398+ return SQLGlotIR (expr = select_expr , uid_gen = self .uid_gen )
399+
410400 def insert (
411401 self ,
412402 destination : bigquery .TableReference ,
413403 ) -> str :
414404 """Generates an INSERT INTO SQL statement from the current SELECT clause."""
415- return sge .insert (self ._as_from_item (), _table (destination )).sql (
405+ return sge .insert (self ._as_select (), _table (destination )).sql (
416406 dialect = self .dialect , pretty = self .pretty
417407 )
418408
@@ -436,7 +426,7 @@ def replace(
436426
437427 merge_str = sge .Merge (
438428 this = _table (destination ),
439- using = self ._as_from_item (),
429+ using = self ._as_select (),
440430 on = _literal (False , dtypes .BOOL_DTYPE ),
441431 ).sql (dialect = self .dialect , pretty = self .pretty )
442432 return f"{ merge_str } \n { whens_str } "
@@ -459,7 +449,7 @@ def _explode_single_column(
459449 )
460450 selection = sge .Star (replace = [unnested_column_alias .as_ (column )])
461451
462- new_expr , _ = self ._select_to_cte ()
452+ new_expr = self ._as_select ()
463453 # Use LEFT JOIN to preserve rows when unnesting empty arrays.
464454 new_expr = new_expr .select (selection , append = False ).join (
465455 unnest_expr , join_type = "LEFT"
@@ -510,7 +500,7 @@ def _explode_multiple_columns(
510500 for column in columns
511501 ]
512502 )
513- new_expr , _ = self ._select_to_cte ()
503+ new_expr = self ._as_select ()
514504 # Use LEFT JOIN to preserve rows when unnesting empty arrays.
515505 new_expr = new_expr .select (selection , append = False ).join (
516506 unnest_expr , join_type = "LEFT"
@@ -532,25 +522,6 @@ def _as_select(self) -> sge.Select:
532522 def _as_subquery (self ) -> sge .Subquery :
533523 return self ._as_select ().subquery ()
534524
535- def _select_to_cte (self ) -> tuple [sge .Select , sge .Identifier ]:
536- """Transforms a given sge.Select query by pushing its main SELECT statement
537- into a new CTE and then generates a 'SELECT * FROM new_cte_name'
538- for the new query."""
539- cte_name = sge .to_identifier (
540- next (self .uid_gen .get_uid_stream ("bfcte_" )), quoted = self .quoted
541- )
542- select_expr = self ._as_select ().copy ()
543- select_expr , existing_ctes = _pop_query_ctes (select_expr )
544- new_cte = sge .CTE (
545- this = select_expr ,
546- alias = cte_name ,
547- )
548- new_select_expr = (
549- sge .Select ().select (sge .Star ()).from_ (sge .Table (this = cte_name ))
550- )
551- new_select_expr = _set_query_ctes (new_select_expr , [* existing_ctes , new_cte ])
552- return new_select_expr , cte_name
553-
554525
555526def _is_null_literal (expr : sge .Expression ) -> bool :
556527 """Checks if the given expression is a NULL literal."""
@@ -743,26 +714,3 @@ def _set_query_ctes(
743714 else :
744715 raise ValueError ("The expression does not support CTEs." )
745716 return new_expr
746-
747-
748- def _merge_ctes (ctes1 : list [sge .CTE ], ctes2 : list [sge .CTE ]) -> list [sge .CTE ]:
749- """Merges two lists of CTEs, de-duplicating by alias name."""
750- seen = {cte .alias : cte for cte in ctes1 }
751- for cte in ctes2 :
752- if cte .alias not in seen :
753- seen [cte .alias ] = cte
754- return list (seen .values ())
755-
756-
757- def _pop_query_ctes (
758- expr : sge .Select ,
759- ) -> tuple [sge .Select , list [sge .CTE ]]:
760- """Pops the CTEs of a given sge.Select expression."""
761- if "with" in expr .arg_types .keys ():
762- expr_ctes = expr .args .pop ("with" , [])
763- return expr , expr_ctes
764- elif "with_" in expr .arg_types .keys ():
765- expr_ctes = expr .args .pop ("with_" , [])
766- return expr , expr_ctes
767- else :
768- raise ValueError ("The expression does not support CTEs." )
0 commit comments