Skip to content

Commit 13e9661

Browse files
committed
Improve Explain function to pass more tests
- Fix string escaping in array literals to use backslash-escaped quotes
- Add WITH clause support for CTEs
- Add window frame bounds support (ROWS BETWEEN ... AND ...)
- Add non-SELECT statement support (USE, TRUNCATE, ALTER, DROP, CREATE, etc.)
- Fix boolean literals to use Bool_1/Bool_0 format
- Fix CREATE query output format to match ClickHouse EXPLAIN AST

Test coverage improved from 310 to 341 passing tests out of 484.
1 parent 19f9353 commit 13e9661

File tree

1 file changed

+294
-9
lines changed

1 file changed

+294
-9
lines changed

ast/explain.go

Lines changed: 294 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,13 @@ func explainNode(b *strings.Builder, node interface{}, depth int) {
2929
case *SelectQuery:
3030
children := countSelectQueryChildren(n)
3131
fmt.Fprintf(b, "%sSelectQuery (children %d)\n", indent, children)
32+
// WITH clause (comes first)
33+
if len(n.With) > 0 {
34+
fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.With))
35+
for _, w := range n.With {
36+
explainNode(b, w, depth+2)
37+
}
38+
}
3239
// Columns
3340
if len(n.Columns) > 0 {
3441
fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Columns))
@@ -300,9 +307,15 @@ func explainNode(b *strings.Builder, node interface{}, depth int) {
300307
explainNodeWithAlias(b, n.Expr, n.Alias, depth)
301308

302309
case *WithElement:
303-
fmt.Fprintf(b, "%sWithElement (children 1)\n", indent)
304-
fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Name)
305-
explainNode(b, n.Query, depth+1)
310+
// For scalar WITH (WITH 1 AS x), output the expression with alias
311+
// For subquery WITH (WITH x AS (SELECT 1)), output as WithElement
312+
if _, isSubquery := n.Query.(*Subquery); isSubquery {
313+
fmt.Fprintf(b, "%sWithElement (children 1)\n", indent)
314+
explainNode(b, n.Query, depth+1)
315+
} else {
316+
// Scalar expression - output with alias
317+
explainNodeWithAlias(b, n.Query, n.Name, depth)
318+
}
306319

307320
case *ExistsExpr:
308321
fmt.Fprintf(b, "%sFunction exists (children 1)\n", indent)
@@ -321,6 +334,102 @@ func explainNode(b *strings.Builder, node interface{}, depth int) {
321334
fmt.Fprintf(b, "%sIdentifier %s\n", indent, n.Name)
322335
}
323336

337+
// Non-SELECT statements
338+
case *UseQuery:
339+
fmt.Fprintf(b, "%sUseQuery %s (children 1)\n", indent, n.Database)
340+
fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Database)
341+
342+
case *TruncateQuery:
343+
tableName := n.Table
344+
if n.Database != "" {
345+
tableName = n.Database + "." + tableName
346+
}
347+
fmt.Fprintf(b, "%sTruncateQuery %s (children 1)\n", indent, tableName)
348+
fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName)
349+
350+
case *AlterQuery:
351+
tableName := n.Table
352+
if n.Database != "" {
353+
tableName = n.Database + "." + tableName
354+
}
355+
fmt.Fprintf(b, "%sAlterQuery %s (children 2)\n", indent, tableName)
356+
fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Commands))
357+
for _, cmd := range n.Commands {
358+
explainAlterCommand(b, cmd, depth+2)
359+
}
360+
fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName)
361+
362+
case *DropQuery:
363+
var name string
364+
if n.DropDatabase {
365+
name = n.Database
366+
} else if n.View != "" {
367+
name = n.View
368+
} else {
369+
name = n.Table
370+
}
371+
if n.Database != "" && !n.DropDatabase {
372+
name = n.Database + "." + name
373+
}
374+
fmt.Fprintf(b, "%sDropQuery %s (children 1)\n", indent, name)
375+
fmt.Fprintf(b, "%s Identifier %s\n", indent, name)
376+
377+
case *CreateQuery:
378+
explainCreateQuery(b, n, depth)
379+
380+
case *InsertQuery:
381+
tableName := n.Table
382+
if n.Database != "" {
383+
tableName = n.Database + "." + tableName
384+
}
385+
children := 1
386+
if n.Select != nil {
387+
children++
388+
}
389+
fmt.Fprintf(b, "%sInsertQuery %s (children %d)\n", indent, tableName, children)
390+
fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName)
391+
if n.Select != nil {
392+
explainNode(b, n.Select, depth+1)
393+
}
394+
395+
case *SystemQuery:
396+
fmt.Fprintf(b, "%sSystemQuery %s\n", indent, n.Command)
397+
398+
case *OptimizeQuery:
399+
tableName := n.Table
400+
if n.Database != "" {
401+
tableName = n.Database + "." + tableName
402+
}
403+
fmt.Fprintf(b, "%sOptimizeQuery %s (children 1)\n", indent, tableName)
404+
fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName)
405+
406+
case *DescribeQuery:
407+
tableName := n.Table
408+
if n.Database != "" {
409+
tableName = n.Database + "." + tableName
410+
}
411+
fmt.Fprintf(b, "%sDescribeQuery %s (children 1)\n", indent, tableName)
412+
fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName)
413+
414+
case *ShowQuery:
415+
fmt.Fprintf(b, "%sShowQuery %s\n", indent, n.ShowType)
416+
417+
case *SetQuery:
418+
fmt.Fprintf(b, "%sSetQuery (children %d)\n", indent, len(n.Settings))
419+
for _, s := range n.Settings {
420+
fmt.Fprintf(b, "%s SettingExpr %s\n", indent, s.Name)
421+
}
422+
423+
case *RenameQuery:
424+
fmt.Fprintf(b, "%sRenameQuery (children 2)\n", indent)
425+
fmt.Fprintf(b, "%s Identifier %s\n", indent, n.From)
426+
fmt.Fprintf(b, "%s Identifier %s\n", indent, n.To)
427+
428+
case *ExchangeQuery:
429+
fmt.Fprintf(b, "%sExchangeQuery (children 2)\n", indent)
430+
fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Table1)
431+
fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Table2)
432+
324433
default:
325434
// For unknown types, just print the type name
326435
fmt.Fprintf(b, "%s%T\n", indent, n)
@@ -409,9 +518,9 @@ func explainLiteral(b *strings.Builder, lit *Literal, alias string, depth int) {
409518
valueStr = fmt.Sprintf("Float64_%v", lit.Value)
410519
case LiteralBoolean:
411520
if lit.Value.(bool) {
412-
valueStr = "UInt8_1"
521+
valueStr = "Bool_1"
413522
} else {
414-
valueStr = "UInt8_0"
523+
valueStr = "Bool_0"
415524
}
416525
case LiteralNull:
417526
valueStr = "NULL"
@@ -498,16 +607,16 @@ func explainInListAsTuple(b *strings.Builder, list []Expression, depth int) {
498607
func formatLiteralElement(elem interface{}) string {
499608
switch e := elem.(type) {
500609
case string:
501-
return fmt.Sprintf("'%s'", e)
610+
return fmt.Sprintf("\\'%s\\'", e)
502611
case int, int64, uint64:
503612
return fmt.Sprintf("UInt64_%v", e)
504613
case float64:
505614
return fmt.Sprintf("Float64_%v", e)
506615
case bool:
507616
if e {
508-
return "UInt8_1"
617+
return "Bool_1"
509618
}
510-
return "UInt8_0"
619+
return "Bool_0"
511620
default:
512621
return fmt.Sprintf("%v", e)
513622
}
@@ -559,14 +668,23 @@ func explainFunctionCallWithAlias(b *strings.Builder, fn *FunctionCall, alias st
559668
func explainWindowSpec(b *strings.Builder, spec *WindowSpec, depth int) {
560669
indent := strings.Repeat(" ", depth)
561670

562-
// Count children: partition by + order by
671+
// Count children: partition by + order by + frame bounds
563672
children := 0
564673
if len(spec.PartitionBy) > 0 {
565674
children++
566675
}
567676
if len(spec.OrderBy) > 0 {
568677
children++
569678
}
679+
// Count frame bound children
680+
if spec.Frame != nil {
681+
if spec.Frame.StartBound != nil && spec.Frame.StartBound.Offset != nil {
682+
children++
683+
}
684+
if spec.Frame.EndBound != nil && spec.Frame.EndBound.Offset != nil {
685+
children++
686+
}
687+
}
570688

571689
if children > 0 {
572690
fmt.Fprintf(b, "%sWindowDefinition (children %d)\n", indent, children)
@@ -589,6 +707,16 @@ func explainWindowSpec(b *strings.Builder, spec *WindowSpec, depth int) {
589707
explainOrderByElement(b, elem, depth+2)
590708
}
591709
}
710+
711+
// Frame bounds
712+
if spec.Frame != nil {
713+
if spec.Frame.StartBound != nil && spec.Frame.StartBound.Offset != nil {
714+
explainNode(b, spec.Frame.StartBound.Offset, depth+1)
715+
}
716+
if spec.Frame.EndBound != nil && spec.Frame.EndBound.Offset != nil {
717+
explainNode(b, spec.Frame.EndBound.Offset, depth+1)
718+
}
719+
}
592720
}
593721

594722
// explainTableJoin formats a table join.
@@ -734,6 +862,9 @@ func explainLambda(b *strings.Builder, l *Lambda, depth int) {
734862
// countSelectQueryChildren counts the non-nil children of a SelectQuery.
735863
func countSelectQueryChildren(s *SelectQuery) int {
736864
count := 0
865+
if len(s.With) > 0 {
866+
count++
867+
}
737868
if len(s.Columns) > 0 {
738869
count++
739870
}
@@ -872,3 +1003,157 @@ func extractFieldToFunction(field string) string {
8721003
return "to" + strings.Title(strings.ToLower(field))
8731004
}
8741005
}
1006+
1007+
// explainAlterCommand formats an ALTER command.
1008+
func explainAlterCommand(b *strings.Builder, cmd *AlterCommand, depth int) {
1009+
indent := strings.Repeat(" ", depth)
1010+
1011+
children := 0
1012+
if cmd.Column != nil {
1013+
children++
1014+
}
1015+
if cmd.ColumnName != "" && cmd.Type != AlterAddColumn && cmd.Type != AlterModifyColumn {
1016+
children++
1017+
}
1018+
if cmd.AfterColumn != "" {
1019+
children++
1020+
}
1021+
if cmd.Constraint != nil {
1022+
children++
1023+
}
1024+
if cmd.IndexExpr != nil {
1025+
children++
1026+
}
1027+
1028+
if children > 0 {
1029+
fmt.Fprintf(b, "%sAlterCommand %s (children %d)\n", indent, cmd.Type, children)
1030+
} else {
1031+
fmt.Fprintf(b, "%sAlterCommand %s\n", indent, cmd.Type)
1032+
}
1033+
1034+
if cmd.Column != nil {
1035+
explainColumnDeclaration(b, cmd.Column, depth+1)
1036+
}
1037+
if cmd.ColumnName != "" && cmd.Type != AlterAddColumn && cmd.Type != AlterModifyColumn {
1038+
fmt.Fprintf(b, "%s Identifier %s\n", indent, cmd.ColumnName)
1039+
}
1040+
if cmd.AfterColumn != "" {
1041+
fmt.Fprintf(b, "%s Identifier %s\n", indent, cmd.AfterColumn)
1042+
}
1043+
if cmd.Constraint != nil {
1044+
explainConstraint(b, cmd.Constraint, depth+1)
1045+
}
1046+
if cmd.IndexExpr != nil {
1047+
fmt.Fprintf(b, "%s Index (children 2)\n", indent)
1048+
explainNode(b, cmd.IndexExpr, depth+2)
1049+
if cmd.IndexType != "" {
1050+
fmt.Fprintf(b, "%s Function %s (children 1)\n", indent, cmd.IndexType)
1051+
fmt.Fprintf(b, "%s ExpressionList\n", indent)
1052+
}
1053+
}
1054+
}
1055+
1056+
// explainColumnDeclaration formats a column declaration.
1057+
func explainColumnDeclaration(b *strings.Builder, col *ColumnDeclaration, depth int) {
1058+
indent := strings.Repeat(" ", depth)
1059+
1060+
children := 0
1061+
if col.Type != nil {
1062+
children++
1063+
}
1064+
if col.Default != nil {
1065+
children++
1066+
}
1067+
1068+
fmt.Fprintf(b, "%sColumnDeclaration %s (children %d)\n", indent, col.Name, children)
1069+
if col.Type != nil {
1070+
fmt.Fprintf(b, "%s DataType %s\n", indent, col.Type.Name)
1071+
}
1072+
if col.Default != nil {
1073+
explainNode(b, col.Default, depth+1)
1074+
}
1075+
}
1076+
1077+
// explainConstraint formats a constraint.
1078+
func explainConstraint(b *strings.Builder, c *Constraint, depth int) {
1079+
indent := strings.Repeat(" ", depth)
1080+
fmt.Fprintf(b, "%sConstraint (children 1)\n", indent)
1081+
explainNode(b, c.Expression, depth+1)
1082+
}
1083+
1084+
// explainCreateQuery formats a CREATE query.
1085+
func explainCreateQuery(b *strings.Builder, n *CreateQuery, depth int) {
1086+
indent := strings.Repeat(" ", depth)
1087+
1088+
if n.CreateDatabase {
1089+
fmt.Fprintf(b, "%sCreateQuery %s (children 1)\n", indent, n.Database)
1090+
fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Database)
1091+
return
1092+
}
1093+
1094+
var name string
1095+
if n.View != "" {
1096+
name = n.View
1097+
} else {
1098+
name = n.Table
1099+
}
1100+
if n.Database != "" {
1101+
name = n.Database + "." + name
1102+
}
1103+
1104+
children := 1 // identifier
1105+
if len(n.Columns) > 0 {
1106+
children++
1107+
}
1108+
if n.Engine != nil || len(n.OrderBy) > 0 {
1109+
children++
1110+
}
1111+
if n.AsSelect != nil {
1112+
children++
1113+
}
1114+
1115+
fmt.Fprintf(b, "%sCreateQuery %s (children %d)\n", indent, name, children)
1116+
fmt.Fprintf(b, "%s Identifier %s\n", indent, name)
1117+
1118+
if len(n.Columns) > 0 {
1119+
fmt.Fprintf(b, "%s Columns definition (children 1)\n", indent)
1120+
fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Columns))
1121+
for _, col := range n.Columns {
1122+
explainColumnDeclaration(b, col, depth+3)
1123+
}
1124+
}
1125+
1126+
if n.Engine != nil || len(n.OrderBy) > 0 {
1127+
storageChildren := 0
1128+
if n.Engine != nil {
1129+
storageChildren++
1130+
}
1131+
if len(n.OrderBy) > 0 {
1132+
storageChildren++
1133+
}
1134+
fmt.Fprintf(b, "%s Storage definition (children %d)\n", indent, storageChildren)
1135+
if n.Engine != nil {
1136+
fmt.Fprintf(b, "%s Function %s (children 1)\n", indent, n.Engine.Name)
1137+
fmt.Fprintf(b, "%s ExpressionList\n", indent)
1138+
}
1139+
if len(n.OrderBy) > 0 {
1140+
// For simple ORDER BY, just output the identifier
1141+
if len(n.OrderBy) == 1 {
1142+
if id, ok := n.OrderBy[0].(*Identifier); ok {
1143+
fmt.Fprintf(b, "%s Identifier %s\n", indent, id.Name())
1144+
} else {
1145+
explainNode(b, n.OrderBy[0], depth+2)
1146+
}
1147+
} else {
1148+
fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.OrderBy))
1149+
for _, expr := range n.OrderBy {
1150+
explainNode(b, expr, depth+3)
1151+
}
1152+
}
1153+
}
1154+
}
1155+
1156+
if n.AsSelect != nil {
1157+
explainNode(b, n.AsSelect, depth+1)
1158+
}
1159+
}

0 commit comments

Comments
 (0)