Skip to content

Commit 6097a19

Browse files
Ajit Pratap SinghAjit Pratap Singh
authored andcommitted
fix(ast): TokenLiteral returns TRY_CAST when Try=true (#483 review)
1 parent a792960 commit 6097a19

2 files changed

Lines changed: 57 additions & 3 deletions

File tree

pkg/sql/ast/ast.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,9 +1044,14 @@ type CastExpression struct {
10441044
Try bool
10451045
}
10461046

1047-
func (c *CastExpression) expressionNode() {}
1048-
func (c CastExpression) TokenLiteral() string { return "CAST" }
1049-
func (c CastExpression) Children() []Node { return []Node{c.Expr} }
1047+
func (c *CastExpression) expressionNode() {}
1048+
func (c CastExpression) TokenLiteral() string {
1049+
if c.Try {
1050+
return "TRY_CAST"
1051+
}
1052+
return "CAST"
1053+
}
1054+
func (c CastExpression) Children() []Node { return []Node{c.Expr} }
10501055

10511056
// AliasedExpression represents an expression with an alias (expr AS alias)
10521057
type AliasedExpression struct {

pkg/sql/parser/snowflake_trycast_nulls_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"testing"
99

1010
"github.com/ajitpratap0/GoSQLX/pkg/gosqlx"
11+
"github.com/ajitpratap0/GoSQLX/pkg/sql/ast"
1112
"github.com/ajitpratap0/GoSQLX/pkg/sql/keywords"
1213
)
1314

@@ -47,3 +48,51 @@ func TestWindowNullTreatment(t *testing.T) {
4748
})
4849
}
4950
}
51+
52+
// TestTryCastASTShape verifies that a TRY_CAST expression has Try=true and
53+
// TokenLiteral() returns "TRY_CAST", while a plain CAST returns "CAST".
54+
func TestTryCastASTShape(t *testing.T) {
55+
tcs := map[string]struct {
56+
query string
57+
wantTry bool
58+
wantLit string
59+
}{
60+
"try_cast": {`SELECT TRY_CAST(value AS INT) FROM events`, true, "TRY_CAST"},
61+
"cast": {`SELECT CAST(value AS INT) FROM events`, false, "CAST"},
62+
}
63+
for name, tc := range tcs {
64+
tc := tc
65+
t.Run(name, func(t *testing.T) {
66+
tree, err := gosqlx.ParseWithDialect(tc.query, keywords.DialectSnowflake)
67+
if err != nil {
68+
t.Fatalf("parse failed: %v", err)
69+
}
70+
var found bool
71+
var visit func(n ast.Node)
72+
visit = func(n ast.Node) {
73+
if n == nil || found {
74+
return
75+
}
76+
if c, ok := n.(*ast.CastExpression); ok {
77+
if c.Try != tc.wantTry {
78+
t.Fatalf("Try: want %v, got %v", tc.wantTry, c.Try)
79+
}
80+
if c.TokenLiteral() != tc.wantLit {
81+
t.Fatalf("TokenLiteral: want %q, got %q", tc.wantLit, c.TokenLiteral())
82+
}
83+
found = true
84+
return
85+
}
86+
for _, ch := range n.Children() {
87+
visit(ch)
88+
}
89+
}
90+
for _, stmt := range tree.Statements {
91+
visit(stmt)
92+
}
93+
if !found {
94+
t.Fatal("CastExpression not found in AST")
95+
}
96+
})
97+
}
98+
}

0 commit comments

Comments
 (0)