@@ -298,16 +298,21 @@ func explainSelectWithUnionQuery(sb *strings.Builder, n *ast.SelectWithUnionQuer
298298 // ClickHouse optimizes UNION ALL when selects have identical expressions but different aliases.
299299 // In that case, only the first SELECT is shown since column names come from the first SELECT anyway.
300300 selects := simplifyUnionSelects (n .Selects )
301+
302+ // Check if we need to group selects due to mode changes
303+ // e.g., A UNION DISTINCT B UNION ALL C -> (A UNION DISTINCT B) UNION ALL C
304+ groupedSelects := groupSelectsByUnionMode (selects , n .UnionModes )
305+
301306 // Wrap selects in ExpressionList
302- fmt .Fprintf (sb , "%s ExpressionList (children %d)\n " , indent , len (selects ))
307+ fmt .Fprintf (sb , "%s ExpressionList (children %d)\n " , indent , len (groupedSelects ))
303308
304309 // Check if first operand has a WITH clause to be inherited by subsequent operands
305310 var inheritedWith []ast.Expression
306311 if len (selects ) > 0 {
307312 inheritedWith = extractWithClause (selects [0 ])
308313 }
309314
310- for i , sel := range selects {
315+ for i , sel := range groupedSelects {
311316 if i > 0 && len (inheritedWith ) > 0 {
312317 // Subsequent operands inherit the WITH clause from the first operand
313318 explainSelectQueryWithInheritedWith (sb , sel , inheritedWith , depth + 2 )
@@ -620,6 +625,62 @@ func simplifyUnionSelects(selects []ast.Statement) []ast.Statement {
620625 return selects
621626}
622627
628+ // groupSelectsByUnionMode groups selects when union modes change from DISTINCT to ALL.
629+ // For example, A UNION DISTINCT B UNION ALL C becomes (A UNION DISTINCT B) UNION ALL C.
630+ // This matches ClickHouse's EXPLAIN AST output which nests DISTINCT groups before ALL.
631+ // Note: The reverse (ALL followed by DISTINCT) does NOT trigger nesting.
632+ func groupSelectsByUnionMode (selects []ast.Statement , unionModes []string ) []ast.Statement {
633+ if len (selects ) < 3 || len (unionModes ) < 2 {
634+ return selects
635+ }
636+
637+ // Normalize union modes (strip "UNION " prefix if present)
638+ normalizeMode := func (mode string ) string {
639+ if len (mode ) > 6 && mode [:6 ] == "UNION " {
640+ return mode [6 :]
641+ }
642+ return mode
643+ }
644+
645+ // Only group when DISTINCT transitions to ALL
646+ // Find first DISTINCT mode, then check if it's followed by ALL
647+ firstMode := normalizeMode (unionModes [0 ])
648+ if firstMode != "DISTINCT" {
649+ return selects
650+ }
651+
652+ // Find where DISTINCT ends and ALL begins
653+ modeChangeIdx := - 1
654+ for i := 1 ; i < len (unionModes ); i ++ {
655+ if normalizeMode (unionModes [i ]) == "ALL" {
656+ modeChangeIdx = i
657+ break
658+ }
659+ }
660+
661+ // If no DISTINCT->ALL transition found, return as-is
662+ if modeChangeIdx == - 1 {
663+ return selects
664+ }
665+
666+ // Create a nested SelectWithUnionQuery for selects 0..modeChangeIdx (inclusive)
667+ // modeChangeIdx is the index of the union operator, so we include selects[0] through selects[modeChangeIdx]
668+ nestedSelects := selects [:modeChangeIdx + 1 ]
669+ nestedModes := unionModes [:modeChangeIdx ]
670+
671+ nested := & ast.SelectWithUnionQuery {
672+ Selects : nestedSelects ,
673+ UnionModes : nestedModes ,
674+ }
675+
676+ // Result is [nested, selects[modeChangeIdx+1], ...]
677+ result := make ([]ast.Statement , 0 , len (selects )- modeChangeIdx )
678+ result = append (result , nested )
679+ result = append (result , selects [modeChangeIdx + 1 :]... )
680+
681+ return result
682+ }
683+
623684func countSelectQueryChildren (n * ast.SelectQuery ) int {
624685 count := 1 // columns ExpressionList
625686 // WITH clause
0 commit comments