diff --git a/pkg/sql/ast/ast.go b/pkg/sql/ast/ast.go index f7250a93..f105bc92 100644 --- a/pkg/sql/ast/ast.go +++ b/pkg/sql/ast/ast.go @@ -246,6 +246,8 @@ type TableReference struct { // Unpivot is the SQL Server / Oracle UNPIVOT clause for column-to-row transformation. // Example: SELECT * FROM t UNPIVOT (sales FOR region IN (north_sales, south_sales)) AS unpvt Unpivot *UnpivotClause + // MatchRecognize is the SQL:2016 row-pattern recognition clause (Snowflake, Oracle). + MatchRecognize *MatchRecognizeClause } func (t *TableReference) statementNode() {} @@ -275,6 +277,9 @@ func (t TableReference) Children() []Node { if t.Unpivot != nil { nodes = append(nodes, t.Unpivot) } + if t.MatchRecognize != nil { + nodes = append(nodes, t.MatchRecognize) + } return nodes } @@ -2147,6 +2152,62 @@ type PeriodDefinition struct { Pos models.Location // Source position of the PERIOD FOR keyword (1-based line and column) } +// MatchRecognizeClause represents the SQL:2016 MATCH_RECOGNIZE clause for +// row-pattern recognition in a FROM clause (Snowflake, Oracle, Databricks). +// +// MATCH_RECOGNIZE ( +// PARTITION BY symbol +// ORDER BY ts +// MEASURES MATCH_NUMBER() AS m +// ALL ROWS PER MATCH +// PATTERN (UP+ DOWN+) +// DEFINE UP AS price > PREV(price), DOWN AS price < PREV(price) +// ) +type MatchRecognizeClause struct { + PartitionBy []Expression + OrderBy []OrderByExpression + Measures []MeasureDef + RowsPerMatch string // "ONE ROW PER MATCH" or "ALL ROWS PER MATCH" (empty = default) + AfterMatch string // raw text: "SKIP TO NEXT ROW", "SKIP PAST LAST ROW", etc. + Pattern string // raw pattern text: "UP+ DOWN+" + Definitions []PatternDef + Pos models.Location +} + +// MeasureDef is one MEASURES entry: expr AS alias. +type MeasureDef struct { + Expr Expression + Alias string +} + +// PatternDef is one DEFINE entry: variable_name AS boolean_condition. +type PatternDef struct { + Name string + Condition Expression +} + +func (m *MatchRecognizeClause) expressionNode() {} +func (m MatchRecognizeClause) TokenLiteral() string { return "MATCH_RECOGNIZE" } +func (m MatchRecognizeClause) Children() []Node { + var nodes []Node + nodes = append(nodes, nodifyExpressions(m.PartitionBy)...) + for _, ob := range m.OrderBy { + ob := ob + nodes = append(nodes, &ob) + } + for _, md := range m.Measures { + if md.Expr != nil { + nodes = append(nodes, md.Expr) + } + } + for _, pd := range m.Definitions { + if pd.Condition != nil { + nodes = append(nodes, pd.Condition) + } + } + return nodes +} + // expressionNode satisfies the Expression interface so PeriodDefinition can be // stored in CreateTableStatement.PeriodDefinitions without a separate interface type. // Semantically it is a table column constraint, not a scalar expression. diff --git a/pkg/sql/parser/expressions_literal.go b/pkg/sql/parser/expressions_literal.go index 350f6c94..ad9b8165 100644 --- a/pkg/sql/parser/expressions_literal.go +++ b/pkg/sql/parser/expressions_literal.go @@ -104,8 +104,9 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) { return funcCall, nil } - // Handle keywords that can be used as function names in MySQL (IF, REPLACE, etc.) - if (p.isType(models.TokenTypeIf) || p.isType(models.TokenTypeReplace)) && p.peekToken().Token.Type == models.TokenTypeLParen { + // Handle keywords that can be used as function names (IF, REPLACE, FIRST, LAST, etc.) + if (p.isType(models.TokenTypeIf) || p.isType(models.TokenTypeReplace) || + p.isType(models.TokenTypeFirst) || p.isType(models.TokenTypeLast)) && p.peekToken().Token.Type == models.TokenTypeLParen { kwPos := p.currentLocation() identName := p.currentToken.Token.Value p.advance() diff --git a/pkg/sql/parser/match_recognize.go b/pkg/sql/parser/match_recognize.go new file mode 100644 index 00000000..2bd7fa31 --- /dev/null +++ b/pkg/sql/parser/match_recognize.go @@ -0,0 +1,228 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 + +// Package parser - match_recognize.go +// SQL:2016 MATCH_RECOGNIZE clause for row-pattern recognition (Snowflake, Oracle). + +package parser + +import ( + "strings" + + "github.com/ajitpratap0/GoSQLX/pkg/models" + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" +) + +// isMatchRecognizeKeyword returns true if the current token is the contextual +// MATCH_RECOGNIZE keyword in a dialect that supports it. +func (p *Parser) isMatchRecognizeKeyword() bool { + if p.dialect != string(keywords.DialectSnowflake) && + p.dialect != string(keywords.DialectOracle) { + return false + } + return strings.EqualFold(p.currentToken.Token.Value, "MATCH_RECOGNIZE") +} + +// parseMatchRecognize parses the MATCH_RECOGNIZE clause. The current token +// must be MATCH_RECOGNIZE. +// +// Grammar: +// +// MATCH_RECOGNIZE ( +// [PARTITION BY expr, ...] +// [ORDER BY expr [ASC|DESC], ...] +// [MEASURES measure_expr AS alias, ...] +// [ONE ROW PER MATCH | ALL ROWS PER MATCH] +// [AFTER MATCH SKIP ...] +// PATTERN ( pattern_regex ) +// DEFINE var AS condition, ... +// ) +func (p *Parser) parseMatchRecognize() (*ast.MatchRecognizeClause, error) { + pos := p.currentLocation() + p.advance() // Consume MATCH_RECOGNIZE + + if !p.isType(models.TokenTypeLParen) { + return nil, p.expectedError("( after MATCH_RECOGNIZE") + } + p.advance() // Consume ( + + clause := &ast.MatchRecognizeClause{Pos: pos} + + // Parse sub-clauses in order. Each is optional except PATTERN and DEFINE. + // PARTITION BY + if p.isType(models.TokenTypePartition) { + p.advance() // PARTITION + if p.isType(models.TokenTypeBy) { + p.advance() // BY + } + for { + expr, err := p.parseExpression() + if err != nil { + return nil, err + } + clause.PartitionBy = append(clause.PartitionBy, expr) + if !p.isType(models.TokenTypeComma) { + break + } + p.advance() + } + } + + // ORDER BY + if p.isType(models.TokenTypeOrder) { + p.advance() // ORDER + if p.isType(models.TokenTypeBy) { + p.advance() // BY + } + for { + expr, err := p.parseExpression() + if err != nil { + return nil, err + } + entry := ast.OrderByExpression{Expression: expr, Ascending: true} + if p.isType(models.TokenTypeAsc) { + p.advance() + } else if p.isType(models.TokenTypeDesc) { + entry.Ascending = false + p.advance() + } + clause.OrderBy = append(clause.OrderBy, entry) + if !p.isType(models.TokenTypeComma) { + break + } + p.advance() + } + } + + // MEASURES + if strings.EqualFold(p.currentToken.Token.Value, "MEASURES") { + p.advance() // MEASURES + for { + expr, err := p.parseExpression() + if err != nil { + return nil, err + } + alias := "" + if p.isType(models.TokenTypeAs) { + p.advance() // AS + alias = p.currentToken.Token.Value + p.advance() // alias name + } + clause.Measures = append(clause.Measures, ast.MeasureDef{ + Expr: expr, + Alias: alias, + }) + if !p.isType(models.TokenTypeComma) { + break + } + p.advance() + } + } + + // ONE ROW PER MATCH / ALL ROWS PER MATCH + if strings.EqualFold(p.currentToken.Token.Value, "ONE") { + clause.RowsPerMatch = "ONE ROW PER MATCH" + p.advance() // ONE + p.advance() // ROW + p.advance() // PER + p.advance() // MATCH + } else if p.isType(models.TokenTypeAll) { + clause.RowsPerMatch = "ALL ROWS PER MATCH" + p.advance() // ALL + p.advance() // ROWS + p.advance() // PER + p.advance() // MATCH + } + + // AFTER MATCH SKIP ... — consume as raw text until PATTERN or DEFINE + if strings.EqualFold(p.currentToken.Token.Value, "AFTER") { + var parts []string + for { + val := strings.ToUpper(p.currentToken.Token.Value) + if val == "PATTERN" || val == "DEFINE" { + break + } + if p.isType(models.TokenTypeEOF) || p.isType(models.TokenTypeRParen) { + break + } + parts = append(parts, p.currentToken.Token.Value) + p.advance() + } + clause.AfterMatch = strings.Join(parts, " ") + } + + // PATTERN ( regex ) + if strings.EqualFold(p.currentToken.Token.Value, "PATTERN") { + p.advance() // PATTERN + if !p.isType(models.TokenTypeLParen) { + return nil, p.expectedError("( after PATTERN") + } + p.advance() // Consume ( + + // Collect pattern tokens as raw text until the matching ')' + var patParts []string + depth := 1 + for depth > 0 { + if p.isType(models.TokenTypeEOF) { + return nil, p.expectedError(") to close PATTERN") + } + if p.isType(models.TokenTypeLParen) { + depth++ + patParts = append(patParts, "(") + } else if p.isType(models.TokenTypeRParen) { + depth-- + if depth > 0 { + patParts = append(patParts, ")") + } + } else { + patParts = append(patParts, p.currentToken.Token.Value) + } + p.advance() + } + clause.Pattern = strings.Join(patParts, " ") + } + + // DEFINE var AS condition, ... + if strings.EqualFold(p.currentToken.Token.Value, "DEFINE") { + p.advance() // DEFINE + for { + if p.isType(models.TokenTypeRParen) || p.isType(models.TokenTypeEOF) { + break + } + name := p.currentToken.Token.Value + p.advance() // variable name + + if !p.isType(models.TokenTypeAs) { + return nil, p.expectedError("AS after pattern variable " + name) + } + p.advance() // AS + + cond, err := p.parseExpression() + if err != nil { + return nil, err + } + clause.Definitions = append(clause.Definitions, ast.PatternDef{ + Name: name, + Condition: cond, + }) + if !p.isType(models.TokenTypeComma) { + break + } + p.advance() + } + } + + // Expect closing ) + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(") to close MATCH_RECOGNIZE") + } + p.advance() // Consume ) + + return clause, nil +} diff --git a/pkg/sql/parser/select_subquery.go b/pkg/sql/parser/select_subquery.go index f5377799..76fb0171 100644 --- a/pkg/sql/parser/select_subquery.go +++ b/pkg/sql/parser/select_subquery.go @@ -179,7 +179,7 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) { // Similarly, START followed by WITH is a hierarchical query seed, not an alias. // Don't consume PIVOT/UNPIVOT as a table alias — they are contextual // keywords in SQL Server/Oracle and must reach the pivot-clause parser below. - if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() && !p.isPivotKeyword() && !p.isUnpivotKeyword() && !p.isQualifyKeyword() && !p.isMinusSetOp() && !p.isSnowflakeTimeTravelStart() && !p.isSampleKeyword() { + if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() && !p.isPivotKeyword() && !p.isUnpivotKeyword() && !p.isQualifyKeyword() && !p.isMinusSetOp() && !p.isSnowflakeTimeTravelStart() && !p.isSampleKeyword() && !p.isMatchRecognizeKeyword() { if p.isType(models.TokenTypeAs) { p.advance() // Consume AS if !p.isIdentifier() { @@ -237,6 +237,23 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) { p.parsePivotAlias(&tableRef) } + // Snowflake / Oracle MATCH_RECOGNIZE clause + if p.isMatchRecognizeKeyword() { + mr, err := p.parseMatchRecognize() + if err != nil { + return tableRef, err + } + tableRef.MatchRecognize = mr + // Optional alias after MATCH_RECOGNIZE (...) + if p.isType(models.TokenTypeAs) { + p.advance() + } + if p.isIdentifier() { + tableRef.Alias = p.currentToken.Token.Value + p.advance() + } + } + return tableRef, nil } @@ -293,7 +310,7 @@ func (p *Parser) parseJoinedTableRef(joinType string) (ast.TableReference, error // Similarly, START followed by WITH is a hierarchical query seed, not an alias. // Don't consume PIVOT/UNPIVOT as a table alias — they are contextual // keywords in SQL Server/Oracle and must reach the pivot-clause parser below. - if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() && !p.isPivotKeyword() && !p.isUnpivotKeyword() && !p.isQualifyKeyword() && !p.isMinusSetOp() && !p.isSnowflakeTimeTravelStart() && !p.isSampleKeyword() { + if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() && !p.isPivotKeyword() && !p.isUnpivotKeyword() && !p.isQualifyKeyword() && !p.isMinusSetOp() && !p.isSnowflakeTimeTravelStart() && !p.isSampleKeyword() && !p.isMatchRecognizeKeyword() { if p.isType(models.TokenTypeAs) { p.advance() if !p.isIdentifier() { diff --git a/pkg/sql/parser/snowflake_match_recognize_test.go b/pkg/sql/parser/snowflake_match_recognize_test.go new file mode 100644 index 00000000..425a72f7 --- /dev/null +++ b/pkg/sql/parser/snowflake_match_recognize_test.go @@ -0,0 +1,136 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); + +package parser_test + +import ( + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/gosqlx" + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" +) + +// TestSnowflakeMatchRecognize verifies the SQL:2016 MATCH_RECOGNIZE clause +// parses for the Snowflake dialect. This was the last remaining Snowflake QA +// failure. Regression for #483. +func TestSnowflakeMatchRecognize(t *testing.T) { + queries := map[string]string{ + "basic_up_down": `SELECT * FROM stock_price MATCH_RECOGNIZE ( + PARTITION BY symbol + ORDER BY ts + MEASURES MATCH_NUMBER() AS m + ALL ROWS PER MATCH + PATTERN (UP+ DOWN+) + DEFINE UP AS price > PREV(price), DOWN AS price < PREV(price) + )`, + + "one_row_per_match": `SELECT * FROM events MATCH_RECOGNIZE ( + ORDER BY ts + MEASURES FIRST(ts) AS start_ts, LAST(ts) AS end_ts + ONE ROW PER MATCH + PATTERN (A B+ C) + DEFINE A AS status = 'start', B AS status = 'running', C AS status = 'done' + )`, + + "with_alias": `SELECT mr.* FROM events MATCH_RECOGNIZE ( + ORDER BY ts + PATTERN (A+ B) + DEFINE A AS val > 0, B AS val <= 0 + ) AS mr`, + + "pattern_alternation": `SELECT * FROM t MATCH_RECOGNIZE ( + ORDER BY ts + PATTERN ((A | B) C+) + DEFINE A AS x = 1, B AS x = 2, C AS x = 3 + )`, + + "measures_only": `SELECT * FROM t MATCH_RECOGNIZE ( + ORDER BY id + MEASURES COUNT(*) AS cnt + ALL ROWS PER MATCH + PATTERN (X+) + DEFINE X AS val > 10 + )`, + + "partition_and_order": `SELECT * FROM t MATCH_RECOGNIZE ( + PARTITION BY region, category + ORDER BY ts DESC + PATTERN (A B) + DEFINE A AS revenue > 100, B AS revenue < 50 + )`, + } + for name, q := range queries { + q := q + t.Run(name, func(t *testing.T) { + if _, err := gosqlx.ParseWithDialect(q, keywords.DialectSnowflake); err != nil { + t.Fatalf("parse failed: %v", err) + } + }) + } +} + +// TestMatchRecognizeASTShape verifies the MatchRecognizeClause AST node is +// populated and reachable via Children() traversal. +func TestMatchRecognizeASTShape(t *testing.T) { + q := `SELECT * FROM stock_price MATCH_RECOGNIZE ( + PARTITION BY symbol + ORDER BY ts + MEASURES MATCH_NUMBER() AS m + ALL ROWS PER MATCH + PATTERN (UP+ DOWN+) + DEFINE UP AS price > PREV(price), DOWN AS price < PREV(price) + )` + tree, err := gosqlx.ParseWithDialect(q, keywords.DialectSnowflake) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + + var mr *ast.MatchRecognizeClause + var visit func(n ast.Node) + visit = func(n ast.Node) { + if n == nil || mr != nil { + return + } + if m, ok := n.(*ast.MatchRecognizeClause); ok { + mr = m + return + } + for _, c := range n.Children() { + visit(c) + } + } + for _, s := range tree.Statements { + visit(s) + } + if mr == nil { + t.Fatal("MatchRecognizeClause not found in AST") + } + if len(mr.PartitionBy) != 1 { + t.Fatalf("PartitionBy: want 1, got %d", len(mr.PartitionBy)) + } + if len(mr.OrderBy) != 1 { + t.Fatalf("OrderBy: want 1, got %d", len(mr.OrderBy)) + } + if len(mr.Measures) != 1 || mr.Measures[0].Alias != "m" { + t.Fatalf("Measures: want [{alias:m}], got %+v", mr.Measures) + } + if mr.RowsPerMatch != "ALL ROWS PER MATCH" { + t.Fatalf("RowsPerMatch: want %q, got %q", "ALL ROWS PER MATCH", mr.RowsPerMatch) + } + if mr.Pattern == "" { + t.Fatal("Pattern is empty") + } + if len(mr.Definitions) != 2 { + t.Fatalf("Definitions: want 2 (UP, DOWN), got %d", len(mr.Definitions)) + } + if mr.Definitions[0].Name != "UP" { + t.Fatalf("Definitions[0].Name: want UP, got %s", mr.Definitions[0].Name) + } + // Verify Children() includes the sub-expressions + children := mr.Children() + if len(children) < 4 { + t.Fatalf("Children(): want >=4 (partition+order+measure+2 defs), got %d", len(children)) + } +}