Skip to content

Commit 1f97d96

Browse files
committed
Add more EXPLAIN AST output fixes (437/484 tests passing)
- Add EXISTS subquery wrapping
- Add IN subquery wrapping
- Add FORMAT clause handling for SELECT and INSERT
- Add EXPLAIN query handling with type normalization
- Add ALTER command type normalization (FREEZE -> FREEZE_ALL)
- Fix ADD_CONSTRAINT to not output constraint name separately
- Add backslash escaping in string literals
1 parent c08774d commit 1f97d96

1 file changed

Lines changed: 69 additions & 15 deletions

File tree

ast/explain.go

Lines changed: 69 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -19,12 +19,25 @@ func explainNode(b *strings.Builder, node interface{}, depth int) {
1919

2020
switch n := node.(type) {
2121
case *SelectWithUnionQuery:
22-
children := len(n.Selects)
23-
fmt.Fprintf(b, "%sSelectWithUnionQuery (children 1)\n", indent)
24-
fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, children)
22+
// Check if first select has Format clause
23+
var format *Identifier
24+
if len(n.Selects) > 0 {
25+
if sq, ok := n.Selects[0].(*SelectQuery); ok && sq.Format != nil {
26+
format = sq.Format
27+
}
28+
}
29+
unionChildren := 1 // ExpressionList
30+
if format != nil {
31+
unionChildren++
32+
}
33+
fmt.Fprintf(b, "%sSelectWithUnionQuery (children %d)\n", indent, unionChildren)
34+
fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Selects))
2535
for _, sel := range n.Selects {
2636
explainNode(b, sel, depth+2)
2737
}
38+
if format != nil {
39+
fmt.Fprintf(b, "%s Identifier %s\n", indent, format.Name())
40+
}
2841

2942
case *SelectQuery:
3043
children := countSelectQueryChildren(n)
@@ -244,7 +257,9 @@ func explainNode(b *strings.Builder, node interface{}, depth int) {
244257
fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent)
245258
explainNode(b, n.Expr, depth+2)
246259
if n.Query != nil {
247-
explainNode(b, n.Query, depth+2)
260+
// Wrap query in Subquery node
261+
fmt.Fprintf(b, "%s Subquery (children 1)\n", indent)
262+
explainNode(b, n.Query, depth+3)
248263
} else {
249264
// List is shown as a Tuple literal
250265
explainInListAsTuple(b, n.List, depth+2)
@@ -345,7 +360,9 @@ func explainNode(b *strings.Builder, node interface{}, depth int) {
345360
case *ExistsExpr:
346361
fmt.Fprintf(b, "%sFunction exists (children 1)\n", indent)
347362
fmt.Fprintf(b, "%s ExpressionList (children 1)\n", indent)
348-
explainNode(b, n.Query, depth+2)
363+
// Wrap query in Subquery node
364+
fmt.Fprintf(b, "%s Subquery (children 1)\n", indent)
365+
explainNode(b, n.Query, depth+3)
349366

350367
case *DataType:
351368
// Data types in expressions (like in CAST)
@@ -412,12 +429,21 @@ func explainNode(b *strings.Builder, node interface{}, depth int) {
412429
if n.Database != "" {
413430
tableName = n.Database + "." + tableName
414431
}
415-
children := 1
432+
children := 1 // Always have table identifier
433+
if len(n.Columns) > 0 {
434+
children++ // column list
435+
}
416436
if n.Select != nil {
417437
children++
418438
}
419-
fmt.Fprintf(b, "%sInsertQuery %s (children %d)\n", indent, tableName, children)
439+
fmt.Fprintf(b, "%sInsertQuery (children %d)\n", indent, children)
420440
fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName)
441+
if len(n.Columns) > 0 {
442+
fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Columns))
443+
for _, col := range n.Columns {
444+
fmt.Fprintf(b, "%s Identifier %s\n", indent, col.Name())
445+
}
446+
}
421447
if n.Select != nil {
422448
explainNode(b, n.Select, depth+1)
423449
}
@@ -461,6 +487,16 @@ func explainNode(b *strings.Builder, node interface{}, depth int) {
461487
fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Table1)
462488
fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Table2)
463489

490+
case *ExplainQuery:
491+
explainType := string(n.ExplainType)
492+
if explainType == "" {
493+
explainType = "EXPLAIN"
494+
} else {
495+
explainType = "EXPLAIN " + explainType
496+
}
497+
fmt.Fprintf(b, "%sExplain %s (children 1)\n", indent, explainType)
498+
explainNode(b, n.Statement, depth+1)
499+
464500
default:
465501
// For unknown types, just print the type name
466502
fmt.Fprintf(b, "%s%T\n", indent, n)
@@ -542,7 +578,9 @@ func explainLiteral(b *strings.Builder, lit *Literal, alias string, depth int) {
542578

543579
switch lit.Type {
544580
case LiteralString:
545-
valueStr = fmt.Sprintf("\\'%v\\'", lit.Value)
581+
// Escape backslashes in string literals (ClickHouse doubles them)
582+
strVal := strings.ReplaceAll(fmt.Sprintf("%v", lit.Value), "\\", "\\\\")
583+
valueStr = fmt.Sprintf("\\'%s\\'", strVal)
546584
case LiteralInteger:
547585
valueStr = fmt.Sprintf("UInt64_%v", lit.Value)
548586
case LiteralFloat:
@@ -585,7 +623,8 @@ func formatArrayLiteral(value interface{}) string {
585623
if lit, ok := elem.(*Literal); ok {
586624
switch lit.Type {
587625
case LiteralString:
588-
parts[i] = fmt.Sprintf("\\'%v\\'", lit.Value)
626+
escaped := strings.ReplaceAll(fmt.Sprintf("%v", lit.Value), "\\", "\\\\")
627+
parts[i] = fmt.Sprintf("\\'%s\\'", escaped)
589628
case LiteralInteger:
590629
parts[i] = fmt.Sprintf("UInt64_%v", lit.Value)
591630
case LiteralFloat:
@@ -622,7 +661,8 @@ func formatTupleLiteral(value interface{}) string {
622661
if lit, ok := elem.(*Literal); ok {
623662
switch lit.Type {
624663
case LiteralString:
625-
parts[i] = fmt.Sprintf("\\'%v\\'", lit.Value)
664+
escaped := strings.ReplaceAll(fmt.Sprintf("%v", lit.Value), "\\", "\\\\")
665+
parts[i] = fmt.Sprintf("\\'%s\\'", escaped)
626666
case LiteralInteger:
627667
parts[i] = fmt.Sprintf("UInt64_%v", lit.Value)
628668
case LiteralFloat:
@@ -670,7 +710,8 @@ func explainInListAsTuple(b *strings.Builder, list []Expression, depth int) {
670710
func formatLiteralElement(elem interface{}) string {
671711
switch e := elem.(type) {
672712
case string:
673-
return fmt.Sprintf("\\'%s\\'", e)
713+
escaped := strings.ReplaceAll(e, "\\", "\\\\")
714+
return fmt.Sprintf("\\'%s\\'", escaped)
674715
case int, int64, uint64:
675716
return fmt.Sprintf("UInt64_%v", e)
676717
case float64:
@@ -1097,6 +1138,16 @@ func extractFieldToFunction(field string) string {
10971138
}
10981139
}
10991140

1141+
// normalizeAlterCommandType normalizes ALTER command types to match ClickHouse output: AlterFreeze is rendered as "FREEZE_ALL"; every other type is returned as its own string value.
1142+
func normalizeAlterCommandType(t AlterCommandType) string {
1143+
switch t {
1144+
case AlterFreeze:
1145+
return "FREEZE_ALL"
1146+
default:
1147+
return string(t)
1148+
}
1149+
}
1150+
11001151
// explainAlterCommand formats an ALTER command.
11011152
func explainAlterCommand(b *strings.Builder, cmd *AlterCommand, depth int) {
11021153
indent := strings.Repeat(" ", depth)
@@ -1123,14 +1174,16 @@ func explainAlterCommand(b *strings.Builder, cmd *AlterCommand, depth int) {
11231174
if cmd.Index != "" && cmd.IndexExpr == nil {
11241175
children++
11251176
}
1126-
if cmd.ConstraintName != "" {
1177+
// Don't count ConstraintName for ADD_CONSTRAINT as it's part of the Constraint structure
1178+
if cmd.ConstraintName != "" && cmd.Type != AlterAddConstraint {
11271179
children++
11281180
}
11291181

1182+
cmdType := normalizeAlterCommandType(cmd.Type)
11301183
if children > 0 {
1131-
fmt.Fprintf(b, "%sAlterCommand %s (children %d)\n", indent, cmd.Type, children)
1184+
fmt.Fprintf(b, "%sAlterCommand %s (children %d)\n", indent, cmdType, children)
11321185
} else {
1133-
fmt.Fprintf(b, "%sAlterCommand %s\n", indent, cmd.Type)
1186+
fmt.Fprintf(b, "%sAlterCommand %s\n", indent, cmdType)
11341187
}
11351188

11361189
if cmd.Column != nil {
@@ -1160,7 +1213,8 @@ func explainAlterCommand(b *strings.Builder, cmd *AlterCommand, depth int) {
11601213
if cmd.Index != "" && cmd.IndexExpr == nil {
11611214
fmt.Fprintf(b, "%s Identifier %s\n", indent, cmd.Index)
11621215
}
1163-
if cmd.ConstraintName != "" {
1216+
// Don't output ConstraintName for ADD_CONSTRAINT
1217+
if cmd.ConstraintName != "" && cmd.Type != AlterAddConstraint {
11641218
fmt.Fprintf(b, "%s Identifier %s\n", indent, cmd.ConstraintName)
11651219
}
11661220
}

0 commit comments

Comments
 (0)