Skip to content

Commit f480288

Browse files
ajitpratap0Ajit Pratap Singh
andauthored
feat(parser): Snowflake LATERAL FLATTEN and named arguments (#483) (#492)
Co-authored-by: Ajit Pratap Singh <ajitpratapsingh@Ajits-Mac-mini-2655.local>
1 parent 60101ec commit f480288

File tree

5 files changed

+156
-0
lines changed

5 files changed

+156
-0
lines changed

pkg/sql/ast/ast.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,11 @@ type TableReference struct {
228228
Lateral bool // LATERAL keyword for correlated subqueries (PostgreSQL)
229229
TableHints []string // SQL Server table hints: WITH (NOLOCK), WITH (ROWLOCK, UPDLOCK), etc.
230230
Final bool // ClickHouse FINAL modifier: forces MergeTree part merge
231+
// TableFunc is a function-call table reference such as
232+
// Snowflake LATERAL FLATTEN(input => col), TABLE(my_func(1,2)),
233+
// IDENTIFIER('t'), or PostgreSQL unnest(array_col). When set, Name
234+
// holds the function name and TableFunc carries the call itself.
235+
TableFunc *FunctionCall
231236
// ForSystemTime is the MariaDB temporal table clause (10.3.4+).
232237
// Example: SELECT * FROM t FOR SYSTEM_TIME AS OF '2024-01-01'
233238
ForSystemTime *ForSystemTimeClause // MariaDB temporal query
@@ -254,6 +259,9 @@ func (t TableReference) Children() []Node {
254259
if t.Subquery != nil {
255260
nodes = append(nodes, t.Subquery)
256261
}
262+
if t.TableFunc != nil {
263+
nodes = append(nodes, t.TableFunc)
264+
}
257265
if t.Pivot != nil {
258266
nodes = append(nodes, t.Pivot)
259267
}
@@ -1042,6 +1050,24 @@ func (u *UnaryExpression) TokenLiteral() string {
10421050

10431051
func (u UnaryExpression) Children() []Node { return []Node{u.Expr} }
10441052

1053+
// NamedArgument represents a function argument of the form `name => expr`,
1054+
// used by Snowflake (FLATTEN(input => col), GENERATOR(rowcount => 100)),
1055+
// BigQuery, Oracle, and PostgreSQL procedural calls.
1056+
type NamedArgument struct {
1057+
Name string
1058+
Value Expression
1059+
Pos models.Location
1060+
}
1061+
1062+
func (n *NamedArgument) expressionNode() {}
1063+
func (n NamedArgument) TokenLiteral() string { return n.Name }
1064+
func (n NamedArgument) Children() []Node {
1065+
if n.Value == nil {
1066+
return nil
1067+
}
1068+
return []Node{n.Value}
1069+
}
1070+
10451071
// CastExpression represents CAST(expr AS type) or TRY_CAST(expr AS type).
10461072
// Try is set when the expression originated from TRY_CAST (Snowflake / SQL
10471073
// Server / BigQuery), which returns NULL on conversion failure instead of

pkg/sql/parser/pivot.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,17 @@ func (p *Parser) parseUnpivotClause() (*ast.UnpivotClause, error) {
264264
Pos: pos,
265265
}, nil
266266
}
267+
268+
// supportsTableFunction reports whether the current dialect allows
269+
// function-call style table references in the FROM list — Snowflake
270+
// (FLATTEN, TABLE, IDENTIFIER, GENERATOR), BigQuery (UNNEST), and
271+
// PostgreSQL (unnest, generate_series, json_each, ...).
272+
func (p *Parser) supportsTableFunction() bool {
273+
switch p.dialect {
274+
case string(keywords.DialectSnowflake),
275+
string(keywords.DialectBigQuery),
276+
string(keywords.DialectPostgreSQL):
277+
return true
278+
}
279+
return false
280+
}

pkg/sql/parser/select_subquery.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) {
8282
Name: qualifiedName,
8383
Lateral: isLateral,
8484
}
85+
86+
// Function-call table reference (Snowflake FLATTEN, TABLE(...),
87+
// IDENTIFIER(...), PostgreSQL unnest(...), BigQuery UNNEST(...)).
88+
// If the parsed name is followed by '(' at FROM position, reparse
89+
// it as a function call. Gated to dialects that actually use this.
90+
if p.isType(models.TokenTypeLParen) && p.supportsTableFunction() {
91+
funcCall, ferr := p.parseFunctionCall(qualifiedName)
92+
if ferr != nil {
93+
return tableRef, ferr
94+
}
95+
tableRef.TableFunc = funcCall
96+
}
8597
}
8698

8799
// Check for table alias (required for derived tables, optional for regular tables).
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright 2026 GoSQLX Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
5+
package parser_test
6+
7+
import (
8+
"testing"
9+
10+
"github.com/ajitpratap0/GoSQLX/pkg/gosqlx"
11+
"github.com/ajitpratap0/GoSQLX/pkg/sql/ast"
12+
"github.com/ajitpratap0/GoSQLX/pkg/sql/keywords"
13+
)
14+
15+
// TestSnowflakeFromTableFunctions verifies function-call style table refs
16+
// (LATERAL FLATTEN, TABLE(...), IDENTIFIER(...), GENERATOR(...)) parse in
17+
// the Snowflake dialect. Regression for #483.
18+
func TestSnowflakeFromTableFunctions(t *testing.T) {
19+
queries := map[string]string{
20+
"lateral_flatten_named": `SELECT value FROM LATERAL FLATTEN(input => array_col)`,
21+
22+
"lateral_flatten_with_alias": `SELECT f.value
23+
FROM events, LATERAL FLATTEN(input => events.tags) f`,
24+
25+
"table_of_udf": `SELECT * FROM TABLE(my_func(1, 2))`,
26+
27+
"identifier_wrapped": `SELECT * FROM IDENTIFIER('my_table')`,
28+
29+
"generator_named_args": `SELECT seq4()
30+
FROM TABLE(GENERATOR(rowcount => 100))`,
31+
}
32+
for name, q := range queries {
33+
q := q
34+
t.Run(name, func(t *testing.T) {
35+
if _, err := gosqlx.ParseWithDialect(q, keywords.DialectSnowflake); err != nil {
36+
t.Fatalf("parse failed: %v", err)
37+
}
38+
})
39+
}
40+
}
41+
42+
// TestNamedArgumentASTShape verifies the NamedArgument AST node is produced
43+
// for `name => expr` and is reachable via the visitor pattern.
44+
func TestNamedArgumentASTShape(t *testing.T) {
45+
q := `SELECT * FROM LATERAL FLATTEN(input => tags)`
46+
tree, err := gosqlx.ParseWithDialect(q, keywords.DialectSnowflake)
47+
if err != nil {
48+
t.Fatalf("parse failed: %v", err)
49+
}
50+
var found bool
51+
var visit func(n ast.Node)
52+
visit = func(n ast.Node) {
53+
if n == nil || found {
54+
return
55+
}
56+
if na, ok := n.(*ast.NamedArgument); ok {
57+
if na.Name != "input" {
58+
t.Fatalf("NamedArgument.Name: want %q, got %q", "input", na.Name)
59+
}
60+
if na.Value == nil {
61+
t.Fatal("NamedArgument.Value nil")
62+
}
63+
found = true
64+
return
65+
}
66+
for _, c := range n.Children() {
67+
visit(c)
68+
}
69+
}
70+
for _, stmt := range tree.Statements {
71+
visit(stmt)
72+
}
73+
if !found {
74+
t.Fatal("NamedArgument not found in AST")
75+
}
76+
}

pkg/sql/parser/window.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,34 @@ func (p *Parser) parseFunctionCall(funcName string) (*ast.FunctionCall, error) {
4747
// Parse arguments if not empty
4848
if !p.isType(models.TokenTypeRParen) {
4949
for !p.isType(models.TokenTypeOrder) {
50+
// Named argument form: `name => expr` (Snowflake FLATTEN,
51+
// BigQuery, Oracle, PostgreSQL procedural calls). Detect by a
52+
// bare identifier immediately followed by =>.
53+
if p.isIdentifier() &&
54+
p.peekToken().Token.Type == models.TokenTypeRArrow {
55+
namePos := p.currentLocation()
56+
argName := p.currentToken.Token.Value
57+
p.advance() // name
58+
p.advance() // =>
59+
value, err := p.parseExpression()
60+
if err != nil {
61+
return nil, err
62+
}
63+
arguments = append(arguments, &ast.NamedArgument{
64+
Name: argName,
65+
Value: value,
66+
Pos: namePos,
67+
})
68+
if p.isType(models.TokenTypeComma) {
69+
p.advance()
70+
continue
71+
}
72+
if p.isType(models.TokenTypeRParen) {
73+
break
74+
}
75+
return nil, p.expectedError(", or )")
76+
}
77+
5078
arg, err := p.parseExpression()
5179
if err != nil {
5280
return nil, err

0 commit comments

Comments
 (0)