Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions pkg/sql/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,11 @@ type TableReference struct {
Lateral bool // LATERAL keyword for correlated subqueries (PostgreSQL)
TableHints []string // SQL Server table hints: WITH (NOLOCK), WITH (ROWLOCK, UPDLOCK), etc.
Final bool // ClickHouse FINAL modifier: forces MergeTree part merge
// TableFunc is a function-call table reference such as
// Snowflake LATERAL FLATTEN(input => col), TABLE(my_func(1,2)),
// IDENTIFIER('t'), or PostgreSQL unnest(array_col). When set, Name
// holds the function name and TableFunc carries the call itself.
TableFunc *FunctionCall
// ForSystemTime is the MariaDB temporal table clause (10.3.4+).
// Example: SELECT * FROM t FOR SYSTEM_TIME AS OF '2024-01-01'
ForSystemTime *ForSystemTimeClause // MariaDB temporal query
Expand All @@ -254,6 +259,9 @@ func (t TableReference) Children() []Node {
if t.Subquery != nil {
nodes = append(nodes, t.Subquery)
}
if t.TableFunc != nil {
nodes = append(nodes, t.TableFunc)
}
if t.Pivot != nil {
nodes = append(nodes, t.Pivot)
}
Expand Down Expand Up @@ -1042,6 +1050,24 @@ func (u *UnaryExpression) TokenLiteral() string {

func (u UnaryExpression) Children() []Node { return []Node{u.Expr} }

// NamedArgument represents a function argument of the form `name => expr`,
// used by Snowflake (FLATTEN(input => col), GENERATOR(rowcount => 100)),
// BigQuery, Oracle, and PostgreSQL procedural calls.
type NamedArgument struct {
Name string
Value Expression
Pos models.Location
}

func (n *NamedArgument) expressionNode() {}
func (n NamedArgument) TokenLiteral() string { return n.Name }
func (n NamedArgument) Children() []Node {
if n.Value == nil {
return nil
}
return []Node{n.Value}
}

// CastExpression represents CAST(expr AS type) or TRY_CAST(expr AS type).
// Try is set when the expression originated from TRY_CAST (Snowflake / SQL
// Server / BigQuery), which returns NULL on conversion failure instead of
Expand Down
14 changes: 14 additions & 0 deletions pkg/sql/parser/pivot.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,17 @@ func (p *Parser) parseUnpivotClause() (*ast.UnpivotClause, error) {
Pos: pos,
}, nil
}

// supportsTableFunction reports whether the current dialect allows
// function-call style table references in the FROM list — Snowflake
// (FLATTEN, TABLE, IDENTIFIER, GENERATOR), BigQuery (UNNEST), and
// PostgreSQL (unnest, generate_series, json_each, ...).
func (p *Parser) supportsTableFunction() bool {
switch p.dialect {
case string(keywords.DialectSnowflake),
string(keywords.DialectBigQuery),
string(keywords.DialectPostgreSQL):
return true
}
return false
}
12 changes: 12 additions & 0 deletions pkg/sql/parser/select_subquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) {
Name: qualifiedName,
Lateral: isLateral,
}

// Function-call table reference (Snowflake FLATTEN, TABLE(...),
// IDENTIFIER(...), PostgreSQL unnest(...), BigQuery UNNEST(...)).
// If the parsed name is followed by '(' at FROM position, reparse
// it as a function call. Gated to dialects that actually use this.
if p.isType(models.TokenTypeLParen) && p.supportsTableFunction() {
funcCall, ferr := p.parseFunctionCall(qualifiedName)
if ferr != nil {
return tableRef, ferr
}
tableRef.TableFunc = funcCall
}
}

// Check for table alias (required for derived tables, optional for regular tables).
Expand Down
76 changes: 76 additions & 0 deletions pkg/sql/parser/snowflake_lateral_flatten_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright 2026 GoSQLX Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");

package parser_test

import (
"testing"

"github.com/ajitpratap0/GoSQLX/pkg/gosqlx"
"github.com/ajitpratap0/GoSQLX/pkg/sql/ast"
"github.com/ajitpratap0/GoSQLX/pkg/sql/keywords"
)

// TestSnowflakeFromTableFunctions verifies function-call style table refs
// (LATERAL FLATTEN, TABLE(...), IDENTIFIER(...), GENERATOR(...)) parse in
// the Snowflake dialect. Regression for #483.
func TestSnowflakeFromTableFunctions(t *testing.T) {
queries := map[string]string{
"lateral_flatten_named": `SELECT value FROM LATERAL FLATTEN(input => array_col)`,

"lateral_flatten_with_alias": `SELECT f.value
FROM events, LATERAL FLATTEN(input => events.tags) f`,

"table_of_udf": `SELECT * FROM TABLE(my_func(1, 2))`,

"identifier_wrapped": `SELECT * FROM IDENTIFIER('my_table')`,

"generator_named_args": `SELECT seq4()
FROM TABLE(GENERATOR(rowcount => 100))`,
}
for name, q := range queries {
q := q
t.Run(name, func(t *testing.T) {
if _, err := gosqlx.ParseWithDialect(q, keywords.DialectSnowflake); err != nil {
t.Fatalf("parse failed: %v", err)
}
})
}
}

// TestNamedArgumentASTShape verifies the NamedArgument AST node is produced
// for `name => expr` and is reachable via the visitor pattern.
func TestNamedArgumentASTShape(t *testing.T) {
q := `SELECT * FROM LATERAL FLATTEN(input => tags)`
tree, err := gosqlx.ParseWithDialect(q, keywords.DialectSnowflake)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
var found bool
var visit func(n ast.Node)
visit = func(n ast.Node) {
if n == nil || found {
return
}
if na, ok := n.(*ast.NamedArgument); ok {
if na.Name != "input" {
t.Fatalf("NamedArgument.Name: want %q, got %q", "input", na.Name)
}
if na.Value == nil {
t.Fatal("NamedArgument.Value nil")
}
found = true
return
}
for _, c := range n.Children() {
visit(c)
}
}
for _, stmt := range tree.Statements {
visit(stmt)
}
if !found {
t.Fatal("NamedArgument not found in AST")
}
}
28 changes: 28 additions & 0 deletions pkg/sql/parser/window.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,34 @@ func (p *Parser) parseFunctionCall(funcName string) (*ast.FunctionCall, error) {
// Parse arguments if not empty
if !p.isType(models.TokenTypeRParen) {
for !p.isType(models.TokenTypeOrder) {
// Named argument form: `name => expr` (Snowflake FLATTEN,
// BigQuery, Oracle, PostgreSQL procedural calls). Detect by a
// bare identifier immediately followed by =>.
if p.isIdentifier() &&
p.peekToken().Token.Type == models.TokenTypeRArrow {
namePos := p.currentLocation()
argName := p.currentToken.Token.Value
p.advance() // name
p.advance() // =>
value, err := p.parseExpression()
if err != nil {
return nil, err
}
arguments = append(arguments, &ast.NamedArgument{
Name: argName,
Value: value,
Pos: namePos,
})
if p.isType(models.TokenTypeComma) {
p.advance()
continue
}
if p.isType(models.TokenTypeRParen) {
break
}
return nil, p.expectedError(", or )")
}

arg, err := p.parseExpression()
if err != nil {
return nil, err
Expand Down
Loading