diff --git a/pkg/sql/ast/ast.go b/pkg/sql/ast/ast.go index 08942754..dc0d280a 100644 --- a/pkg/sql/ast/ast.go +++ b/pkg/sql/ast/ast.go @@ -309,21 +309,57 @@ func (e ExistsExpression) Children() []Node { return []Node{e.Subquery} } -// InExpression represents expr IN (values) +// InExpression represents expr IN (values) or expr IN (subquery) type InExpression struct { - Expr Expression - List []Expression - Not bool + Expr Expression + List []Expression // For value list: IN (1, 2, 3) + Subquery Statement // For subquery: IN (SELECT ...) + Not bool } func (i *InExpression) expressionNode() {} func (i InExpression) TokenLiteral() string { return "IN" } func (i InExpression) Children() []Node { children := []Node{i.Expr} - children = append(children, nodifyExpressions(i.List)...) + if i.Subquery != nil { + children = append(children, i.Subquery) + } else { + children = append(children, nodifyExpressions(i.List)...) + } return children } +// SubqueryExpression represents a scalar subquery (SELECT ...) +type SubqueryExpression struct { + Subquery Statement +} + +func (s *SubqueryExpression) expressionNode() {} +func (s SubqueryExpression) TokenLiteral() string { return "SUBQUERY" } +func (s SubqueryExpression) Children() []Node { return []Node{s.Subquery} } + +// AnyExpression represents expr op ANY (subquery) +type AnyExpression struct { + Expr Expression + Operator string + Subquery Statement +} + +func (a *AnyExpression) expressionNode() {} +func (a AnyExpression) TokenLiteral() string { return "ANY" } +func (a AnyExpression) Children() []Node { return []Node{a.Expr, a.Subquery} } + +// AllExpression represents expr op ALL (subquery) +type AllExpression struct { + Expr Expression + Operator string + Subquery Statement +} + +func (al *AllExpression) expressionNode() {} +func (al AllExpression) TokenLiteral() string { return "ALL" } +func (al AllExpression) Children() []Node { return []Node{al.Expr, al.Subquery} } + // BetweenExpression represents expr BETWEEN lower AND upper type BetweenExpression struct { Expr Expression diff --git a/pkg/sql/parser/parser.go b/pkg/sql/parser/parser.go index f1154e3a..d18abd7e 100644 --- a/pkg/sql/parser/parser.go +++ b/pkg/sql/parser/parser.go @@ -382,7 +382,7 @@ func (p *Parser) parseComparisonExpression() (ast.Expression, error) { } } - // Check for BETWEEN operator: expr [NOT] BETWEEN lower AND upper + // Check for BETWEEN operator if p.currentToken.Type == "BETWEEN" { p.advance() // Consume BETWEEN @@ -392,9 +392,9 @@ func (p *Parser) parseComparisonExpression() (ast.Expression, error) { return nil, fmt.Errorf("failed to parse BETWEEN lower bound: %w", err) } - // Expect AND + // Expect AND keyword if p.currentToken.Type != "AND" { - return nil, p.expectedError("AND in BETWEEN expression") + return nil, p.expectedError("AND") } p.advance() // Consume AND @@ -412,7 +412,7 @@ func (p *Parser) parseComparisonExpression() (ast.Expression, error) { }, nil } - // Check for LIKE/ILIKE operator: expr [NOT] LIKE pattern [ESCAPE escape_char] + // Check for LIKE/ILIKE operator if p.currentToken.Type == "LIKE" || p.currentToken.Type == "ILIKE" { operator := p.currentToken.Literal p.advance() // Consume LIKE/ILIKE @@ -431,36 +431,54 @@ func (p *Parser) parseComparisonExpression() (ast.Expression, error) { }, nil } - // Check for IN operator: expr [NOT] IN (values) + // Check for IN operator if p.currentToken.Type == "IN" { p.advance() // Consume IN // Expect opening parenthesis if p.currentToken.Type != "(" { - return nil, p.expectedError("( after IN") + return nil, p.expectedError("(") } p.advance() // Consume ( + // Check if this is a subquery (starts with SELECT or WITH) + if p.currentToken.Type == "SELECT" || p.currentToken.Type == "WITH" { + // Parse subquery + subquery, err := p.parseSubquery() + if err != nil { + return nil, fmt.Errorf("failed to parse IN subquery: %w", err) + } + + // Expect closing parenthesis + if p.currentToken.Type != ")" { + return nil, p.expectedError(")") + } + p.advance() // Consume ) + + return &ast.InExpression{ + Expr: left, + Subquery: subquery, + Not: notPrefix, + }, nil + } + // Parse value list var values []ast.Expression for { - val, err := p.parseExpression() + value, err := p.parseExpression() if err != nil { return nil, fmt.Errorf("failed to parse IN value: %w", err) } - values = append(values, val) + values = append(values, value) if p.currentToken.Type == "," { p.advance() // Consume comma - } else { + } else if p.currentToken.Type == ")" { break + } else { + return nil, p.expectedError(", or )") } } - - // Expect closing parenthesis - if p.currentToken.Type != ")" { - return nil, p.expectedError(") to close IN list") - } p.advance() // Consume ) return &ast.InExpression{ @@ -470,7 +488,13 @@ func (p *Parser) parseComparisonExpression() (ast.Expression, error) { }, nil } - // Check for IS NULL / IS NOT NULL: expr IS [NOT] NULL + // If NOT was consumed but no BETWEEN/LIKE/IN follows, we need to handle this case + // Put back the NOT by creating a NOT expression with left as the operand + if notPrefix { + return nil, fmt.Errorf("expected BETWEEN, LIKE, or IN after NOT") + } + + // Check for IS NULL / IS NOT NULL if p.currentToken.Type == "IS" { p.advance() // Consume IS @@ -490,17 +514,55 @@ func (p *Parser) parseComparisonExpression() (ast.Expression, error) { }, nil } - return nil, p.expectedError("NULL after IS") + return nil, p.expectedError("NULL") } // Check if this is a comparison binary expression if p.currentToken.Type == "=" || p.currentToken.Type == "<" || p.currentToken.Type == ">" || p.currentToken.Type == "!=" || - p.currentToken.Type == "<=" || p.currentToken.Type == ">=" { + p.currentToken.Type == "<=" || p.currentToken.Type == ">=" || + p.currentToken.Type == "<>" { // Save the operator operator := p.currentToken.Literal p.advance() + // Check for ANY/ALL subquery operators + if p.currentToken.Type == "ANY" || p.currentToken.Type == "ALL" { + quantifier := p.currentToken.Type + p.advance() // Consume ANY/ALL + + // Expect opening parenthesis + if p.currentToken.Type != "(" { + return nil, p.expectedError("(") + } + p.advance() // Consume ( + + // Parse subquery + subquery, err := p.parseSubquery() + if err != nil { + return nil, fmt.Errorf("failed to parse %s subquery: %w", quantifier, err) + } + + // Expect closing parenthesis + if p.currentToken.Type != ")" { + return nil, p.expectedError(")") + } + p.advance() // Consume ) + + if quantifier == "ANY" { + return &ast.AnyExpression{ + Expr: left, + Operator: operator, + Subquery: subquery, + }, nil + } + return &ast.AllExpression{ + Expr: left, + Operator: operator, + Subquery: subquery, + }, nil + } + // Parse the right side of the expression right, err := p.parsePrimaryExpression() if err != nil { @@ -521,6 +583,10 @@ func (p *Parser) parseComparisonExpression() (ast.Expression, error) { // parsePrimaryExpression parses a primary expression (literals, identifiers, function calls) func (p *Parser) parsePrimaryExpression() (ast.Expression, error) { switch p.currentToken.Type { + case "CASE": + // Handle CASE expressions (both simple and searched forms) + return p.parseCaseExpression() + case "IDENT": // Handle identifiers and function calls identName := p.currentToken.Literal @@ -590,12 +656,101 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) { p.advance() return &ast.LiteralValue{Value: value, Type: "placeholder"}, nil + case "NULL": + // Handle NULL literal + p.advance() + return &ast.LiteralValue{Value: nil, Type: "null"}, nil + + case "(": + // Handle parenthesized expression or subquery + p.advance() // Consume ( + + // Check if this is a subquery (starts with SELECT or WITH) + if p.currentToken.Type == "SELECT" || p.currentToken.Type == "WITH" { + // Parse subquery + subquery, err := p.parseSubquery() + if err != nil { + return nil, fmt.Errorf("failed to parse subquery: %w", err) + } + // Expect closing parenthesis + if p.currentToken.Type != ")" { + return nil, p.expectedError(")") + } + p.advance() // Consume ) + return &ast.SubqueryExpression{Subquery: subquery}, nil + } + + // Regular parenthesized expression + expr, err := p.parseExpression() + if err != nil { + return nil, err + } + + // Expect closing parenthesis + if p.currentToken.Type != ")" { + return nil, p.expectedError(")") + } + p.advance() // Consume ) + return expr, nil + + case "EXISTS": + // Handle EXISTS (subquery) + p.advance() // Consume EXISTS + + // Expect opening parenthesis + if p.currentToken.Type != "(" { + return nil, p.expectedError("(") + } + p.advance() // Consume ( + + // Parse the subquery + subquery, err := p.parseSubquery() + if err != nil { + return nil, fmt.Errorf("failed to parse EXISTS subquery: %w", err) + } + + // Expect closing parenthesis + if p.currentToken.Type != ")" { + return nil, p.expectedError(")") + } + p.advance() // Consume ) + + return &ast.ExistsExpression{Subquery: subquery}, nil + case "NOT": - // Handle NOT as unary operator for boolean negation - // e.g., WHERE NOT active, WHERE NOT (a AND b) + // Handle NOT expression (NOT EXISTS, NOT boolean) p.advance() // Consume NOT - // Parse the following expression at comparison level - // This handles: NOT active, NOT (a > b), NOT EXISTS (...) + + if p.currentToken.Type == "EXISTS" { + // NOT EXISTS (subquery) + p.advance() // Consume EXISTS + + if p.currentToken.Type != "(" { + return nil, p.expectedError("(") + } + p.advance() // Consume ( + + subquery, err := p.parseSubquery() + if err != nil { + return nil, fmt.Errorf("failed to parse NOT EXISTS subquery: %w", err) + } + + if p.currentToken.Type != ")" { + return nil, p.expectedError(")") + } + p.advance() // Consume ) + + // Return NOT EXISTS as a BinaryExpression with NOT flag + return &ast.BinaryExpression{ + Left: &ast.ExistsExpression{Subquery: subquery}, + Operator: "NOT", + Right: nil, + Not: true, + }, nil + } + + // NOT followed by other expression (boolean negation) + // Parse at comparison level for proper precedence: NOT (a > b), NOT active expr, err := p.parseComparisonExpression() if err != nil { return nil, err @@ -610,6 +765,98 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) { } } +// parseCaseExpression parses a CASE expression (both simple and searched forms) +// +// Simple CASE: CASE expr WHEN value THEN result ... [ELSE result] END +// Searched CASE: CASE WHEN condition THEN result ... [ELSE result] END +func (p *Parser) parseCaseExpression() (*ast.CaseExpression, error) { + p.advance() // Consume CASE + + caseExpr := &ast.CaseExpression{ + WhenClauses: make([]ast.WhenClause, 0), + } + + // Check if this is a simple CASE (has a value expression) or searched CASE (no value) + // Simple CASE: CASE expr WHEN value THEN result + // Searched CASE: CASE WHEN condition THEN result + if p.currentToken.Type != "WHEN" { + // This is a simple CASE - parse the value expression + value, err := p.parseExpression() + if err != nil { + return nil, fmt.Errorf("failed to parse CASE value: %w", err) + } + caseExpr.Value = value + } + + // Parse WHEN clauses (at least one required) + for p.currentToken.Type == "WHEN" { + p.advance() // Consume WHEN + + // Parse the condition/value expression + condition, err := p.parseExpression() + if err != nil { + return nil, fmt.Errorf("failed to parse WHEN condition: %w", err) + } + + // Expect THEN keyword + if p.currentToken.Type != "THEN" { + return nil, p.expectedError("THEN") + } + p.advance() // Consume THEN + + // Parse the result expression + result, err := p.parseExpression() + if err != nil { + return nil, fmt.Errorf("failed to parse THEN result: %w", err) + } + + caseExpr.WhenClauses = append(caseExpr.WhenClauses, ast.WhenClause{ + Condition: condition, + Result: result, + }) + } + + // Check that we have at least one WHEN clause + if len(caseExpr.WhenClauses) == 0 { + return nil, fmt.Errorf("CASE expression requires at least one WHEN clause") + } + + // Parse optional ELSE clause + if p.currentToken.Type == "ELSE" { + p.advance() // Consume ELSE + + elseResult, err := p.parseExpression() + if err != nil { + return nil, fmt.Errorf("failed to parse ELSE result: %w", err) + } + caseExpr.ElseClause = elseResult + } + + // Expect END keyword + if p.currentToken.Type != "END" { + return nil, p.expectedError("END") + } + p.advance() // Consume END + + return caseExpr, nil +} + +// parseSubquery parses a subquery (SELECT or WITH statement). +// Expects current token to be SELECT or WITH. +func (p *Parser) parseSubquery() (ast.Statement, error) { + if p.currentToken.Type == "WITH" { + // WITH statement handles its own token consumption + return p.parseWithStatement() + } + + if p.currentToken.Type == "SELECT" { + p.advance() // Consume SELECT + return p.parseSelectWithSetOperations() + } + + return nil, fmt.Errorf("expected SELECT or WITH, got %s", p.currentToken.Type) +} + // parseFunctionCall parses a function call with optional OVER clause for window functions. // // Examples: diff --git a/pkg/sql/parser/parser_coverage_test.go b/pkg/sql/parser/parser_coverage_test.go index f439a3a4..55d1a60b 100644 --- a/pkg/sql/parser/parser_coverage_test.go +++ b/pkg/sql/parser/parser_coverage_test.go @@ -303,7 +303,7 @@ func TestParser_ExpressionEdgeCases(t *testing.T) { {Type: "TRUE", Literal: "TRUE"}, {Type: ")", Literal: ")"}, }, - wantErr: true, // NOT with parentheses not yet supported + wantErr: false, // NOT with parentheses now supported }, { name: "BETWEEN expression", @@ -623,7 +623,7 @@ func TestParser_CTEEdgeCases(t *testing.T) { {Type: "IDENT", Literal: "active"}, {Type: ")", Literal: ")"}, }, - wantErr: true, // Subqueries in WHERE not yet supported + wantErr: false, // Subqueries in WHERE now supported }, { name: "CTE with DELETE statement", @@ -654,7 +654,7 @@ func TestParser_CTEEdgeCases(t *testing.T) { {Type: "IDENT", Literal: "old_records"}, {Type: ")", Literal: ")"}, }, - wantErr: true, // Subqueries in WHERE not yet supported + wantErr: false, // Subqueries in WHERE now supported }, } @@ -777,7 +777,7 @@ func TestParser_TableDrivenComplexScenarios(t *testing.T) { {Type: "STRING", Literal: "USA"}, {Type: ")", Literal: ")"}, }, - wantErr: true, // Subqueries in WHERE not yet supported + wantErr: false, // Subqueries in WHERE now supported }, { name: "CASE expression in SELECT", @@ -797,7 +797,7 @@ func TestParser_TableDrivenComplexScenarios(t *testing.T) { {Type: "FROM", Literal: "FROM"}, {Type: "IDENT", Literal: "users"}, }, - wantErr: true, // CASE expressions not yet supported + wantErr: false, // CASE expressions now supported }, { name: "DISTINCT with aggregate", diff --git a/pkg/sql/parser/parser_error_recovery_test.go b/pkg/sql/parser/parser_error_recovery_test.go index 427f2f60..4e55b336 100644 --- a/pkg/sql/parser/parser_error_recovery_test.go +++ b/pkg/sql/parser/parser_error_recovery_test.go @@ -298,7 +298,7 @@ func TestUpdateStatement_EdgeCases(t *testing.T) { { name: "UPDATE with NULL", sql: "UPDATE users SET deleted_at = NULL WHERE id = 1", - shouldErr: true, // NULL not supported as value + shouldErr: false, // NULL now supported as value }, { name: "UPDATE without WHERE (dangerous)", @@ -565,7 +565,7 @@ func TestExpressions_BoundaryValues(t *testing.T) { { name: "NULL comparison", sql: "SELECT * FROM users WHERE deleted_at = NULL", - shouldErr: true, // NULL not supported as value, should use IS NULL + shouldErr: false, // NULL now supported as value (though IS NULL is preferred) }, { name: "true boolean",