Skip to content

Commit 1bc5dae

Browse files
authored
Add support for ANY/ALL comparison operators with subqueries (#110)
1 parent 3f729a2 commit 1bc5dae

File tree

149 files changed

+1046
-731
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

149 files changed

+1046
-731
lines changed

ast/ast.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ type SelectQuery struct {
7272
Having Expression `json:"having,omitempty"`
7373
Qualify Expression `json:"qualify,omitempty"`
7474
Window []*WindowDefinition `json:"window,omitempty"`
75-
OrderBy []*OrderByElement `json:"order_by,omitempty"`
75+
OrderBy []*OrderByElement `json:"order_by,omitempty"`
76+
Interpolate []*InterpolateElement `json:"interpolate,omitempty"`
7677
Limit Expression `json:"limit,omitempty"`
7778
LimitBy []Expression `json:"limit_by,omitempty"`
7879
LimitByLimit Expression `json:"limit_by_limit,omitempty"` // LIMIT value before BY (e.g., LIMIT 1 BY x LIMIT 3)
@@ -212,6 +213,17 @@ type OrderByElement struct {
212213
func (o *OrderByElement) Pos() token.Position { return o.Position }
213214
func (o *OrderByElement) End() token.Position { return o.Position }
214215

216+
// InterpolateElement represents a single column interpolation in an INTERPOLATE clause.
217+
// Example: INTERPOLATE (value AS value + 1)
// NOTE(review): presumably Column holds the name before AS and Value the
// expression after it — confirm against the parser.
218+
type InterpolateElement struct {
219+
Position token.Position `json:"-"`
220+
Column string `json:"column"`
221+
Value Expression `json:"value,omitempty"` // nil when only the column name is given
222+
}
223+
224+
// Pos returns the position recorded for this element.
func (i *InterpolateElement) Pos() token.Position { return i.Position }
225+
// End returns the same position as Pos; no separate end position is stored.
func (i *InterpolateElement) End() token.Position { return i.Position }
226+
215227
// SettingExpr represents a setting expression.
216228
type SettingExpr struct {
217229
Position token.Position `json:"-"`
@@ -284,6 +296,7 @@ type CreateQuery struct {
284296
FunctionName string `json:"function_name,omitempty"`
285297
FunctionBody Expression `json:"function_body,omitempty"`
286298
UserName string `json:"user_name,omitempty"`
299+
Format string `json:"format,omitempty"` // For FORMAT clause
287300
}
288301

289302
func (c *CreateQuery) Pos() token.Position { return c.Position }
@@ -493,6 +506,7 @@ type DropQuery struct {
493506
OnCluster string `json:"on_cluster,omitempty"`
494507
DropDatabase bool `json:"drop_database,omitempty"`
495508
Sync bool `json:"sync,omitempty"`
509+
Format string `json:"format,omitempty"` // For FORMAT clause
496510
}
497511

498512
func (d *DropQuery) Pos() token.Position { return d.Position }
@@ -512,6 +526,20 @@ func (u *UndropQuery) Pos() token.Position { return u.Position }
512526
func (u *UndropQuery) End() token.Position { return u.Position }
513527
func (u *UndropQuery) statementNode() {}
514528

529+
// UpdateQuery represents a standalone UPDATE statement.
530+
// In ClickHouse, UPDATE is syntactic sugar for ALTER TABLE ... UPDATE.
531+
type UpdateQuery struct {
532+
Position token.Position `json:"-"`
533+
Database string `json:"database,omitempty"`
534+
Table string `json:"table"`
535+
Assignments []*Assignment `json:"assignments"`
536+
Where Expression `json:"where,omitempty"`
537+
}
538+
539+
// Pos returns the position recorded for this statement.
func (u *UpdateQuery) Pos() token.Position { return u.Position }
540+
// End returns the same position as Pos; no separate end position is stored.
func (u *UpdateQuery) End() token.Position { return u.Position }
541+
// statementNode is an empty marker method.
// NOTE(review): presumably it is what makes UpdateQuery satisfy the Statement interface — confirm.
func (u *UpdateQuery) statementNode() {}
542+
515543
// AlterQuery represents an ALTER statement.
516544
type AlterQuery struct {
517545
Position token.Position `json:"-"`
@@ -520,6 +548,7 @@ type AlterQuery struct {
520548
Commands []*AlterCommand `json:"commands"`
521549
OnCluster string `json:"on_cluster,omitempty"`
522550
Settings []*SettingExpr `json:"settings,omitempty"`
551+
Format string `json:"format,omitempty"` // For FORMAT clause
523552
}
524553

525554
func (a *AlterQuery) Pos() token.Position { return a.Position }
@@ -624,6 +653,7 @@ const (
624653
AlterDropStatistics AlterCommandType = "DROP_STATISTICS"
625654
AlterClearStatistics AlterCommandType = "CLEAR_STATISTICS"
626655
AlterMaterializeStatistics AlterCommandType = "MATERIALIZE_STATISTICS"
656+
AlterModifyComment AlterCommandType = "MODIFY_COMMENT"
627657
)
628658

629659
// TruncateQuery represents a TRUNCATE statement.

internal/explain/explain.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ import (
1212
// This affects how negated literals with aliases are formatted
1313
var inSubqueryContext bool
1414

15+
// inCreateQueryContext is a package-level flag to track when we're inside a CreateQuery
16+
// This affects whether FORMAT is output at SelectWithUnionQuery level (it shouldn't be, as CreateQuery outputs it)
17+
var inCreateQueryContext bool
18+
1519
// Explain returns the EXPLAIN AST output for a statement, matching ClickHouse's format.
1620
func Explain(stmt ast.Statement) string {
1721
var sb strings.Builder
@@ -57,6 +61,8 @@ func Node(sb *strings.Builder, node interface{}, depth int) {
5761
// Expressions
5862
case *ast.OrderByElement:
5963
explainOrderByElement(sb, n, indent, depth)
64+
case *ast.InterpolateElement:
65+
explainInterpolateElement(sb, n, indent, depth)
6066
case *ast.Identifier:
6167
explainIdentifier(sb, n, indent)
6268
case *ast.Literal:
@@ -236,6 +242,8 @@ func Node(sb *strings.Builder, node interface{}, depth int) {
236242
explainCheckQuery(sb, n, indent)
237243
case *ast.CreateIndexQuery:
238244
explainCreateIndexQuery(sb, n, indent, depth)
245+
case *ast.UpdateQuery:
246+
explainUpdateQuery(sb, n, indent, depth)
239247

240248
// Types
241249
case *ast.DataType:
@@ -262,6 +270,8 @@ func Node(sb *strings.Builder, node interface{}, depth int) {
262270
explainDictionaryLayout(sb, n, indent, depth)
263271
case *ast.DictionaryRange:
264272
explainDictionaryRange(sb, n, indent, depth)
273+
case *ast.Assignment:
274+
explainAssignment(sb, n, indent, depth)
265275

266276
default:
267277
// For unhandled types, just print the type name

internal/explain/expressions.go

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,31 @@ func containsNonLiteralExpressions(exprs []ast.Expression) bool {
235235
return false
236236
}
237237

238+
// containsNonLiteralInNested checks if an array or tuple literal contains
239+
// non-literal elements at any nesting level (identifiers, function calls, etc.)
240+
func containsNonLiteralInNested(lit *ast.Literal) bool {
241+
if lit.Type != ast.LiteralArray && lit.Type != ast.LiteralTuple {
242+
return false
243+
}
244+
exprs, ok := lit.Value.([]ast.Expression)
245+
if !ok {
246+
return false
247+
}
248+
for _, e := range exprs {
249+
// Check if this element is a non-literal (identifier, function call, etc.)
250+
if _, isLit := e.(*ast.Literal); !isLit {
251+
return true
252+
}
253+
// Recursively check nested arrays/tuples
254+
if innerLit, ok := e.(*ast.Literal); ok {
255+
if containsNonLiteralInNested(innerLit) {
256+
return true
257+
}
258+
}
259+
}
260+
return false
261+
}
262+
238263
// containsTuples checks if a slice of expressions contains any tuple literals
239264
func containsTuples(exprs []ast.Expression) bool {
240265
for _, e := range exprs {
@@ -377,10 +402,23 @@ func explainUnaryExpr(sb *strings.Builder, n *ast.UnaryExpr, indent string, dept
377402
// Convert positive integer to negative
378403
switch val := lit.Value.(type) {
379404
case int64:
380-
fmt.Fprintf(sb, "%sLiteral Int64_%d\n", indent, -val)
405+
negVal := -val
406+
// ClickHouse normalizes -0 to UInt64_0
407+
if negVal == 0 {
408+
fmt.Fprintf(sb, "%sLiteral UInt64_0\n", indent)
409+
} else if negVal > 0 {
410+
fmt.Fprintf(sb, "%sLiteral UInt64_%d\n", indent, negVal)
411+
} else {
412+
fmt.Fprintf(sb, "%sLiteral Int64_%d\n", indent, negVal)
413+
}
381414
return
382415
case uint64:
383-
fmt.Fprintf(sb, "%sLiteral Int64_-%d\n", indent, val)
416+
// ClickHouse normalizes -0 to UInt64_0
417+
if val == 0 {
418+
fmt.Fprintf(sb, "%sLiteral UInt64_0\n", indent)
419+
} else {
420+
fmt.Fprintf(sb, "%sLiteral Int64_-%d\n", indent, val)
421+
}
384422
return
385423
}
386424
case ast.LiteralFloat:
@@ -432,11 +470,23 @@ func explainAliasedExpr(sb *strings.Builder, n *ast.AliasedExpr, depth int) {
432470
needsFunctionFormat = true
433471
break
434472
}
473+
// Also check if nested arrays/tuples contain non-literal elements
474+
if lit, ok := expr.(*ast.Literal); ok {
475+
if containsNonLiteralInNested(lit) {
476+
needsFunctionFormat = true
477+
break
478+
}
479+
}
435480
}
436481
if needsFunctionFormat {
437482
// Render as Function tuple with alias
438483
fmt.Fprintf(sb, "%sFunction tuple (alias %s) (children %d)\n", indent, escapeAlias(n.Alias), 1)
439-
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(exprs))
484+
// For empty ExpressionList, don't include children count
485+
if len(exprs) > 0 {
486+
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(exprs))
487+
} else {
488+
fmt.Fprintf(sb, "%s ExpressionList\n", indent)
489+
}
440490
for _, expr := range exprs {
441491
Node(sb, expr, depth+2)
442492
}
@@ -463,6 +513,11 @@ func explainAliasedExpr(sb *strings.Builder, n *ast.AliasedExpr, depth int) {
463513
needsFunctionFormat = true
464514
break
465515
}
516+
// Check for function calls - use Function array
517+
if _, ok := expr.(*ast.FunctionCall); ok {
518+
needsFunctionFormat = true
519+
break
520+
}
466521
}
467522
if needsFunctionFormat {
468523
// Render as Function array with alias
@@ -577,6 +632,9 @@ func explainAliasedExpr(sb *strings.Builder, n *ast.AliasedExpr, depth int) {
577632
case *ast.CaseExpr:
578633
// CASE expressions with alias
579634
explainCaseExprWithAlias(sb, e, n.Alias, indent, depth)
635+
case *ast.ExistsExpr:
636+
// EXISTS expressions with alias
637+
explainExistsExprWithAlias(sb, e, n.Alias, indent, depth)
580638
default:
581639
// For other types, recursively explain and add alias info
582640
Node(sb, n.Expr, depth)

internal/explain/format.go

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,36 @@ func escapeStringLiteral(s string) string {
7474
return sb.String()
7575
}
7676

77+
// escapeStringForTypeParam escapes special characters in s for use inside a
// type parameter. Type strings are embedded inside another string literal, so
// each escape carries extra backslashes. Bytes without a dedicated escape are
// copied through unchanged.
func escapeStringForTypeParam(s string) string {
	var out strings.Builder
	for _, c := range []byte(s) {
		// Pick the replacement for this byte; empty means "emit as-is".
		escaped := ""
		switch c {
		case '\\':
			escaped = "\\\\\\\\\\\\\\\\" // backslash becomes 8 backslashes
		case '\'':
			escaped = "\\\\\\\\\\'" // single quote becomes 5 backslashes + quote
		case '\n':
			escaped = "\\\\\\\\n" // newline becomes 4 backslashes + n
		case '\t':
			escaped = "\\\\\\\\t" // tab becomes 4 backslashes + t
		case '\r':
			escaped = "\\\\\\\\r" // carriage return becomes 4 backslashes + r
		case '\x00':
			escaped = "\\\\\\\\0" // null becomes 4 backslashes + 0
		case '\b':
			escaped = "\\\\\\\\b" // backspace becomes 4 backslashes + b
		case '\f':
			escaped = "\\\\\\\\f" // form feed becomes 4 backslashes + f
		}
		if escaped == "" {
			out.WriteByte(c)
			continue
		}
		out.WriteString(escaped)
	}
	return out.String()
}
106+
77107
// FormatLiteral formats a literal value for EXPLAIN AST output
78108
func FormatLiteral(lit *ast.Literal) string {
79109
switch lit.Type {
@@ -270,7 +300,9 @@ func formatBinaryExprForType(expr *ast.BinaryExpr) string {
270300
// Format left side
271301
if lit, ok := expr.Left.(*ast.Literal); ok {
272302
if lit.Type == ast.LiteralString {
273-
left = fmt.Sprintf("\\\\\\'%s\\\\\\'", lit.Value)
303+
// Use extra escaping for type parameters since they're embedded in another string literal
304+
escaped := escapeStringForTypeParam(fmt.Sprintf("%v", lit.Value))
305+
left = fmt.Sprintf("\\\\\\'%s\\\\\\'", escaped)
274306
} else {
275307
left = fmt.Sprintf("%v", lit.Value)
276308
}
@@ -285,13 +317,24 @@ func formatBinaryExprForType(expr *ast.BinaryExpr) string {
285317
right = fmt.Sprintf("%v", lit.Value)
286318
} else if ident, ok := expr.Right.(*ast.Identifier); ok {
287319
right = ident.Name()
320+
} else if unary, ok := expr.Right.(*ast.UnaryExpr); ok {
321+
// Handle unary expressions like -100
322+
right = formatUnaryExprForType(unary)
288323
} else {
289324
right = fmt.Sprintf("%v", expr.Right)
290325
}
291326

292327
return left + " " + expr.Op + " " + right
293328
}
294329

330+
// formatUnaryExprForType formats a unary expression for use in type parameters (e.g., -100)
331+
func formatUnaryExprForType(expr *ast.UnaryExpr) string {
332+
if lit, ok := expr.Operand.(*ast.Literal); ok {
333+
return expr.Op + fmt.Sprintf("%v", lit.Value)
334+
}
335+
return expr.Op + fmt.Sprintf("%v", expr.Operand)
336+
}
337+
295338
// NormalizeFunctionName normalizes function names to match ClickHouse's EXPLAIN AST output
296339
func NormalizeFunctionName(name string) string {
297340
// ClickHouse normalizes certain function names in EXPLAIN AST
@@ -314,9 +357,9 @@ func NormalizeFunctionName(name string) string {
314357
"least": "least",
315358
"concat_ws": "concat",
316359
"position": "position",
317-
// SQL standard ANY/ALL subquery operators
318-
"anymatch": "in",
319-
"allmatch": "notIn",
360+
// SQL standard ANY/ALL subquery operators - simple cases
361+
"anyequals": "in",
362+
"allnotequals": "notIn",
320363
}
321364
if n, ok := normalized[strings.ToLower(name)]; ok {
322365
return n
@@ -351,6 +394,8 @@ func OperatorToFunction(op string) string {
351394
return "lessOrEquals"
352395
case ">=":
353396
return "greaterOrEquals"
397+
case "<=>":
398+
return "isNotDistinctFrom"
354399
case "AND":
355400
return "and"
356401
case "OR":

0 commit comments

Comments
 (0)