Skip to content

Commit 660438e

Browse files
kyleconroy and claude committed
Fix UNION grouping for parenthesized queries with DISTINCT->ALL transitions
When a parenthesized UNION query contains a DISTINCT->ALL mode transition, the explain output should group the DISTINCT portion into a nested SelectWithUnionQuery and lift the remaining selects to the outer level.

Changes:
- Parser now keeps parenthesized unions as nested SelectWithUnionQuery
- Added expandNestedUnions() to flatten/expand nested unions appropriately:
  - Single-select nested unions are flattened
  - All-ALL mode nested unions are fully flattened
  - Nested unions with DISTINCT->ALL transitions are expanded to grouped results
- Updated groupSelectsByUnionMode() to find the last non-ALL->ALL transition
- Applied expansion logic to both regular and inherited-WITH explain paths

Fixes stmt13 and stmt28 in 01529_union_distinct_and_setting_union_default_mode

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent f47f283 commit 660438e

File tree

3 files changed

+121
-27
lines changed

3 files changed

+121
-27
lines changed

internal/explain/select.go

Lines changed: 115 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,15 @@ func explainSelectWithUnionQueryWithInheritedWith(sb *strings.Builder, n *ast.Se
222222
fmt.Fprintf(sb, "%sSelectWithUnionQuery (children %d)\n", indent, children)
223223

224224
selects := simplifyUnionSelects(n.Selects)
225-
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(selects))
226-
for _, sel := range selects {
225+
226+
// Expand any nested SelectWithUnionQuery that would be grouped
227+
expandedSelects, expandedModes := expandNestedUnions(selects, n.UnionModes)
228+
229+
// Check if we need to group selects due to mode changes
230+
groupedSelects := groupSelectsByUnionMode(expandedSelects, expandedModes)
231+
232+
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(groupedSelects))
233+
for _, sel := range groupedSelects {
227234
ExplainSelectWithInheritedWith(sb, sel, inheritedWith, depth+2)
228235
}
229236

@@ -299,9 +306,13 @@ func explainSelectWithUnionQuery(sb *strings.Builder, n *ast.SelectWithUnionQuer
299306
// In that case, only the first SELECT is shown since column names come from the first SELECT anyway.
300307
selects := simplifyUnionSelects(n.Selects)
301308

309+
// Expand any nested SelectWithUnionQuery that would be grouped
310+
// This flattens [S1, nested(5)] into [S1, grouped(4), S6] when grouping applies
311+
expandedSelects, expandedModes := expandNestedUnions(selects, n.UnionModes)
312+
302313
// Check if we need to group selects due to mode changes
303314
// e.g., A UNION DISTINCT B UNION ALL C -> (A UNION DISTINCT B) UNION ALL C
304-
groupedSelects := groupSelectsByUnionMode(selects, n.UnionModes)
315+
groupedSelects := groupSelectsByUnionMode(expandedSelects, expandedModes)
305316

306317
// Wrap selects in ExpressionList
307318
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(groupedSelects))
@@ -625,6 +636,99 @@ func simplifyUnionSelects(selects []ast.Statement) []ast.Statement {
625636
return selects
626637
}
627638

639+
// expandNestedUnions expands nested SelectWithUnionQuery elements.
640+
// - If a nested union has only ALL modes, it's completely flattened
641+
// - If a nested union has a DISTINCT->ALL transition, it's expanded to grouped results
642+
// For example, [S1, nested(S2,S3,S4,S5,S6)] with modes [ALL] where nested has modes [ALL,"",DISTINCT,ALL]
643+
// becomes [S1, grouped(S2,S3,S4,S5), S6] with modes [ALL, ALL]
644+
func expandNestedUnions(selects []ast.Statement, unionModes []string) ([]ast.Statement, []string) {
645+
result := make([]ast.Statement, 0, len(selects))
646+
resultModes := make([]string, 0, len(unionModes))
647+
648+
// Helper to check if all modes are ALL
649+
allModesAreAll := func(modes []string) bool {
650+
for _, m := range modes {
651+
normalized := m
652+
if len(m) > 6 && m[:6] == "UNION " {
653+
normalized = m[6:]
654+
}
655+
if normalized != "ALL" && normalized != "" {
656+
// "" can be bare UNION which may default to DISTINCT
657+
// but we treat it as potentially non-ALL
658+
return false
659+
}
660+
// For "" (bare UNION), we check if it's truly all-ALL by also checking
661+
// that DISTINCT is not present
662+
if normalized == "" {
663+
return false // bare UNION may be DISTINCT based on settings
664+
}
665+
}
666+
return true
667+
}
668+
669+
for i, sel := range selects {
670+
if nested, ok := sel.(*ast.SelectWithUnionQuery); ok {
671+
// Single select in parentheses - flatten it
672+
if len(nested.Selects) == 1 {
673+
result = append(result, nested.Selects[0])
674+
if i > 0 && i-1 < len(unionModes) {
675+
resultModes = append(resultModes, unionModes[i-1])
676+
}
677+
continue
678+
}
679+
// Check if all nested modes are ALL - if so, flatten completely
680+
if allModesAreAll(nested.UnionModes) {
681+
// Flatten completely: add outer mode first, then all nested selects and modes
682+
if i > 0 && i-1 < len(unionModes) {
683+
resultModes = append(resultModes, unionModes[i-1])
684+
}
685+
// Add first nested select
686+
if len(nested.Selects) > 0 {
687+
// Recursively expand in case of deeply nested unions
688+
expandedNested, expandedNestedModes := expandNestedUnions(nested.Selects, nested.UnionModes)
689+
for j, s := range expandedNested {
690+
result = append(result, s)
691+
if j < len(expandedNestedModes) {
692+
resultModes = append(resultModes, expandedNestedModes[j])
693+
}
694+
}
695+
}
696+
} else {
697+
// Check if this nested union would be grouped (DISTINCT->ALL transition)
698+
grouped := groupSelectsByUnionMode(nested.Selects, nested.UnionModes)
699+
if len(grouped) > 1 {
700+
// Grouping produced multiple elements - expand them
701+
// The outer mode (if any) applies to the first expanded element
702+
if i > 0 && i-1 < len(unionModes) {
703+
resultModes = append(resultModes, unionModes[i-1])
704+
}
705+
// Add all grouped elements and their modes
706+
for j, g := range grouped {
707+
result = append(result, g)
708+
if j < len(grouped)-1 {
709+
// Mode between grouped elements is ALL (from the transition point)
710+
resultModes = append(resultModes, "UNION ALL")
711+
}
712+
}
713+
} else {
714+
// No grouping, keep as-is
715+
result = append(result, sel)
716+
if i > 0 && i-1 < len(unionModes) {
717+
resultModes = append(resultModes, unionModes[i-1])
718+
}
719+
}
720+
}
721+
} else {
722+
result = append(result, sel)
723+
if i > 0 && i-1 < len(unionModes) {
724+
resultModes = append(resultModes, unionModes[i-1])
725+
}
726+
}
727+
}
728+
729+
return result, resultModes
730+
}
731+
628732
// groupSelectsByUnionMode groups selects when union modes change from DISTINCT to ALL.
629733
// For example, A UNION DISTINCT B UNION ALL C becomes (A UNION DISTINCT B) UNION ALL C.
630734
// This matches ClickHouse's EXPLAIN AST output which nests DISTINCT groups before ALL.
@@ -642,19 +746,17 @@ func groupSelectsByUnionMode(selects []ast.Statement, unionModes []string) []ast
642746
return mode
643747
}
644748

645-
// Only group when DISTINCT transitions to ALL
646-
// Find first DISTINCT mode, then check if it's followed by ALL
647-
firstMode := normalizeMode(unionModes[0])
648-
if firstMode != "DISTINCT" {
649-
return selects
650-
}
651-
652-
// Find where DISTINCT ends and ALL begins
749+
// Find the last DISTINCT->ALL transition
750+
// A transition occurs when a non-ALL mode (DISTINCT or bare "") is followed by ALL
653751
modeChangeIdx := -1
654752
for i := 1; i < len(unionModes); i++ {
655-
if normalizeMode(unionModes[i]) == "ALL" {
753+
prevMode := normalizeMode(unionModes[i-1])
754+
currMode := normalizeMode(unionModes[i])
755+
// Check for non-ALL -> ALL transition
756+
// Non-ALL means DISTINCT or "" (bare UNION, which defaults to DISTINCT)
757+
if currMode == "ALL" && prevMode != "ALL" {
656758
modeChangeIdx = i
657-
break
759+
// Continue to find the LAST such transition
658760
}
659761
}
660762

parser/parser.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -712,10 +712,9 @@ func (p *Parser) parseSelectWithUnion() *ast.SelectWithUnionQuery {
712712
break
713713
}
714714
p.expect(token.RPAREN)
715-
// Flatten nested selects into current query
716-
for _, s := range nested.Selects {
717-
query.Selects = append(query.Selects, s)
718-
}
715+
// Keep parenthesized union as nested SelectWithUnionQuery
716+
// This allows proper grouping in the explain phase
717+
query.Selects = append(query.Selects, nested)
719718
} else {
720719
sel := p.parseSelect()
721720
if sel == nil {
@@ -7692,10 +7691,8 @@ func (p *Parser) parseParenthesizedSelect() *ast.SelectWithUnionQuery {
76927691
break
76937692
}
76947693
p.expect(token.RPAREN)
7695-
// Flatten nested selects into current query
7696-
for _, s := range nested.Selects {
7697-
query.Selects = append(query.Selects, s)
7698-
}
7694+
// Keep parenthesized union as nested SelectWithUnionQuery
7695+
query.Selects = append(query.Selects, nested)
76997696
} else {
77007697
sel := p.parseSelect()
77017698
if sel == nil {
Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1 @@
1-
{
2-
"explain_todo": {
3-
"stmt13": true,
4-
"stmt28": true
5-
}
6-
}
1+
{}

0 commit comments

Comments
 (0)