@@ -222,8 +222,15 @@ func explainSelectWithUnionQueryWithInheritedWith(sb *strings.Builder, n *ast.Se
222222 fmt .Fprintf (sb , "%sSelectWithUnionQuery (children %d)\n " , indent , children )
223223
224224 selects := simplifyUnionSelects (n .Selects )
225- fmt .Fprintf (sb , "%s ExpressionList (children %d)\n " , indent , len (selects ))
226- for _ , sel := range selects {
225+
226+ // Expand any nested SelectWithUnionQuery that would be grouped
227+ expandedSelects , expandedModes := expandNestedUnions (selects , n .UnionModes )
228+
229+ // Check if we need to group selects due to mode changes
230+ groupedSelects := groupSelectsByUnionMode (expandedSelects , expandedModes )
231+
232+ fmt .Fprintf (sb , "%s ExpressionList (children %d)\n " , indent , len (groupedSelects ))
233+ for _ , sel := range groupedSelects {
227234 ExplainSelectWithInheritedWith (sb , sel , inheritedWith , depth + 2 )
228235 }
229236
@@ -299,9 +306,13 @@ func explainSelectWithUnionQuery(sb *strings.Builder, n *ast.SelectWithUnionQuer
299306 // In that case, only the first SELECT is shown since column names come from the first SELECT anyway.
300307 selects := simplifyUnionSelects (n .Selects )
301308
309+ // Expand any nested SelectWithUnionQuery that would be grouped
310+ // This flattens [S1, nested(5)] into [S1, grouped(4), S6] when grouping applies
311+ expandedSelects , expandedModes := expandNestedUnions (selects , n .UnionModes )
312+
302313 // Check if we need to group selects due to mode changes
303314 // e.g., A UNION DISTINCT B UNION ALL C -> (A UNION DISTINCT B) UNION ALL C
304- groupedSelects := groupSelectsByUnionMode (selects , n . UnionModes )
315+ groupedSelects := groupSelectsByUnionMode (expandedSelects , expandedModes )
305316
306317 // Wrap selects in ExpressionList
307318 fmt .Fprintf (sb , "%s ExpressionList (children %d)\n " , indent , len (groupedSelects ))
@@ -625,6 +636,99 @@ func simplifyUnionSelects(selects []ast.Statement) []ast.Statement {
625636 return selects
626637}
627638
639+ // expandNestedUnions expands nested SelectWithUnionQuery elements.
640+ // - If a nested union has only ALL modes, it's completely flattened
641+ // - If a nested union has a DISTINCT->ALL transition, it's expanded to grouped results
642+ // For example, [S1, nested(S2,S3,S4,S5,S6)] with modes [ALL] where nested has modes [ALL,"",DISTINCT,ALL]
643+ // becomes [S1, grouped(S2,S3,S4,S5), S6] with modes [ALL, ALL]
644+ func expandNestedUnions (selects []ast.Statement , unionModes []string ) ([]ast.Statement , []string ) {
645+ result := make ([]ast.Statement , 0 , len (selects ))
646+ resultModes := make ([]string , 0 , len (unionModes ))
647+
648+ // Helper to check if all modes are ALL
649+ allModesAreAll := func (modes []string ) bool {
650+ for _ , m := range modes {
651+ normalized := m
652+ if len (m ) > 6 && m [:6 ] == "UNION " {
653+ normalized = m [6 :]
654+ }
655+ if normalized != "ALL" && normalized != "" {
656+ // "" can be bare UNION which may default to DISTINCT
657+ // but we treat it as potentially non-ALL
658+ return false
659+ }
660+ // For "" (bare UNION), we check if it's truly all-ALL by also checking
661+ // that DISTINCT is not present
662+ if normalized == "" {
663+ return false // bare UNION may be DISTINCT based on settings
664+ }
665+ }
666+ return true
667+ }
668+
669+ for i , sel := range selects {
670+ if nested , ok := sel .(* ast.SelectWithUnionQuery ); ok {
671+ // Single select in parentheses - flatten it
672+ if len (nested .Selects ) == 1 {
673+ result = append (result , nested .Selects [0 ])
674+ if i > 0 && i - 1 < len (unionModes ) {
675+ resultModes = append (resultModes , unionModes [i - 1 ])
676+ }
677+ continue
678+ }
679+ // Check if all nested modes are ALL - if so, flatten completely
680+ if allModesAreAll (nested .UnionModes ) {
681+ // Flatten completely: add outer mode first, then all nested selects and modes
682+ if i > 0 && i - 1 < len (unionModes ) {
683+ resultModes = append (resultModes , unionModes [i - 1 ])
684+ }
685+ // Add first nested select
686+ if len (nested .Selects ) > 0 {
687+ // Recursively expand in case of deeply nested unions
688+ expandedNested , expandedNestedModes := expandNestedUnions (nested .Selects , nested .UnionModes )
689+ for j , s := range expandedNested {
690+ result = append (result , s )
691+ if j < len (expandedNestedModes ) {
692+ resultModes = append (resultModes , expandedNestedModes [j ])
693+ }
694+ }
695+ }
696+ } else {
697+ // Check if this nested union would be grouped (DISTINCT->ALL transition)
698+ grouped := groupSelectsByUnionMode (nested .Selects , nested .UnionModes )
699+ if len (grouped ) > 1 {
700+ // Grouping produced multiple elements - expand them
701+ // The outer mode (if any) applies to the first expanded element
702+ if i > 0 && i - 1 < len (unionModes ) {
703+ resultModes = append (resultModes , unionModes [i - 1 ])
704+ }
705+ // Add all grouped elements and their modes
706+ for j , g := range grouped {
707+ result = append (result , g )
708+ if j < len (grouped )- 1 {
709+ // Mode between grouped elements is ALL (from the transition point)
710+ resultModes = append (resultModes , "UNION ALL" )
711+ }
712+ }
713+ } else {
714+ // No grouping, keep as-is
715+ result = append (result , sel )
716+ if i > 0 && i - 1 < len (unionModes ) {
717+ resultModes = append (resultModes , unionModes [i - 1 ])
718+ }
719+ }
720+ }
721+ } else {
722+ result = append (result , sel )
723+ if i > 0 && i - 1 < len (unionModes ) {
724+ resultModes = append (resultModes , unionModes [i - 1 ])
725+ }
726+ }
727+ }
728+
729+ return result , resultModes
730+ }
731+
628732// groupSelectsByUnionMode groups selects when union modes change from DISTINCT to ALL.
629733// For example, A UNION DISTINCT B UNION ALL C becomes (A UNION DISTINCT B) UNION ALL C.
630734// This matches ClickHouse's EXPLAIN AST output which nests DISTINCT groups before ALL.
@@ -642,19 +746,17 @@ func groupSelectsByUnionMode(selects []ast.Statement, unionModes []string) []ast
642746 return mode
643747 }
644748
645- // Only group when DISTINCT transitions to ALL
646- // Find first DISTINCT mode, then check if it's followed by ALL
647- firstMode := normalizeMode (unionModes [0 ])
648- if firstMode != "DISTINCT" {
649- return selects
650- }
651-
652- // Find where DISTINCT ends and ALL begins
749+ // Find the last DISTINCT->ALL transition
750+ // A transition occurs when a non-ALL mode (DISTINCT or bare "") is followed by ALL
653751 modeChangeIdx := - 1
654752 for i := 1 ; i < len (unionModes ); i ++ {
655- if normalizeMode (unionModes [i ]) == "ALL" {
753+ prevMode := normalizeMode (unionModes [i - 1 ])
754+ currMode := normalizeMode (unionModes [i ])
755+ // Check for non-ALL -> ALL transition
756+ // Non-ALL means DISTINCT or "" (bare UNION, which defaults to DISTINCT)
757+ if currMode == "ALL" && prevMode != "ALL" {
656758 modeChangeIdx = i
657- break
759+ // Continue to find the LAST such transition
658760 }
659761 }
660762
0 commit comments