Skip to content

Commit ec3f4d4

Browse files
Ajit Pratap SinghAjit Pratap Singh
authored and committed
fix(parser): polish PIVOT/UNPIVOT round-trip and validation
- Formatter emits AS before PIVOT/UNPIVOT aliases for clean round-trip. - Tokenizer records Quote='[' on SQL Server bracket-quoted identifiers; pivot parser uses renderQuotedIdent to preserve [North] etc. in PivotClause.InValues and UnpivotClause.InColumns. - Reject empty IN lists for both PIVOT and UNPIVOT. - Extract parsePivotAlias helper, collapsing four duplicated alias blocks in select_subquery.go. - Add TestPivotNegativeCases (missing parens, missing FOR/IN, empty IN) and TestPivotBracketedInValuesPreserved. Full test suite passes with -race.
1 parent a1e78c0 commit ec3f4d4

5 files changed

Lines changed: 106 additions & 47 deletions

File tree

pkg/formatter/render.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,13 @@ func tableRefSQL(t *ast.TableReference) string {
12051205
sb.WriteString("))")
12061206
}
12071207
if t.Alias != "" {
1208-
sb.WriteString(" ")
1208+
// PIVOT/UNPIVOT aliases conventionally use AS to avoid ambiguity
1209+
// with the closing paren of the clause.
1210+
if t.Pivot != nil || t.Unpivot != nil {
1211+
sb.WriteString(" AS ")
1212+
} else {
1213+
sb.WriteString(" ")
1214+
}
12091215
sb.WriteString(t.Alias)
12101216
}
12111217
return sb.String()

pkg/sql/parser/pivot.go

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,43 @@ import (
2525
"github.com/ajitpratap0/GoSQLX/pkg/sql/keywords"
2626
)
2727

28+
// renderQuotedIdent reproduces the original delimiters of a quoted identifier
29+
// token so the parsed value round-trips through the formatter. The tokenizer
30+
// strips delimiters but records them in QuoteStyle.
31+
func renderQuotedIdent(tok models.Token) string {
32+
q := tok.Quote
33+
if q == 0 && tok.Word != nil {
34+
q = tok.Word.QuoteStyle
35+
}
36+
switch q {
37+
case '[':
38+
return "[" + tok.Value + "]"
39+
case '"':
40+
return "\"" + tok.Value + "\""
41+
case '`':
42+
return "`" + tok.Value + "`"
43+
}
44+
return tok.Value
45+
}
46+
47+
// parsePivotAlias consumes an optional alias (with or without AS) following a
48+
// PIVOT/UNPIVOT clause. Extracted to avoid four copies of the same logic in
49+
// the table-reference and join paths.
50+
func (p *Parser) parsePivotAlias(ref *ast.TableReference) {
51+
if p.isType(models.TokenTypeAs) {
52+
p.advance() // consume AS
53+
if p.isIdentifier() {
54+
ref.Alias = p.currentToken.Token.Value
55+
p.advance()
56+
}
57+
return
58+
}
59+
if p.isIdentifier() {
60+
ref.Alias = p.currentToken.Token.Value
61+
p.advance()
62+
}
63+
}
64+
2865
// pivotDialectAllowed reports whether PIVOT/UNPIVOT is a recognized clause
2966
// for the parser's current dialect. PIVOT/UNPIVOT are SQL Server / Oracle
3067
// extensions; in other dialects the words must remain valid identifiers.
@@ -107,7 +144,7 @@ func (p *Parser) parsePivotClause() (*ast.PivotClause, error) {
107144
if !p.isIdentifier() && !p.isType(models.TokenTypeNumber) && !p.isStringLiteral() {
108145
return nil, p.expectedError("value in PIVOT IN list")
109146
}
110-
inValues = append(inValues, p.currentToken.Token.Value)
147+
inValues = append(inValues, renderQuotedIdent(p.currentToken.Token))
111148
p.advance()
112149
if p.isType(models.TokenTypeComma) {
113150
p.advance()
@@ -117,6 +154,9 @@ func (p *Parser) parsePivotClause() (*ast.PivotClause, error) {
117154
if !p.isType(models.TokenTypeRParen) {
118155
return nil, p.expectedError(") to close PIVOT IN list")
119156
}
157+
if len(inValues) == 0 {
158+
return nil, p.expectedError("at least one value in PIVOT IN list")
159+
}
120160
p.advance() // close IN list )
121161

122162
if !p.isType(models.TokenTypeRParen) {
@@ -181,7 +221,7 @@ func (p *Parser) parseUnpivotClause() (*ast.UnpivotClause, error) {
181221
if !p.isIdentifier() {
182222
return nil, p.expectedError("column name in UNPIVOT IN list")
183223
}
184-
cols = append(cols, p.currentToken.Token.Value)
224+
cols = append(cols, renderQuotedIdent(p.currentToken.Token))
185225
p.advance()
186226
if p.isType(models.TokenTypeComma) {
187227
p.advance()
@@ -191,6 +231,9 @@ func (p *Parser) parseUnpivotClause() (*ast.UnpivotClause, error) {
191231
if !p.isType(models.TokenTypeRParen) {
192232
return nil, p.expectedError(") to close UNPIVOT IN list")
193233
}
234+
if len(cols) == 0 {
235+
return nil, p.expectedError("at least one column in UNPIVOT IN list")
236+
}
194237
p.advance() // close IN list )
195238

196239
if !p.isType(models.TokenTypeRParen) {

pkg/sql/parser/select_subquery.go

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -134,17 +134,7 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) {
134134
return tableRef, err
135135
}
136136
tableRef.Pivot = pivot
137-
// PIVOT result often has its own alias: PIVOT (...) AS pvt
138-
if p.isType(models.TokenTypeAs) {
139-
p.advance() // consume AS
140-
if p.isIdentifier() {
141-
tableRef.Alias = p.currentToken.Token.Value
142-
p.advance()
143-
}
144-
} else if p.isIdentifier() {
145-
tableRef.Alias = p.currentToken.Token.Value
146-
p.advance()
147-
}
137+
p.parsePivotAlias(&tableRef)
148138
}
149139

150140
// SQL Server / Oracle UNPIVOT clause
@@ -154,17 +144,7 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) {
154144
return tableRef, err
155145
}
156146
tableRef.Unpivot = unpivot
157-
// UNPIVOT result alias: UNPIVOT (...) AS unpvt
158-
if p.isType(models.TokenTypeAs) {
159-
p.advance() // consume AS
160-
if p.isIdentifier() {
161-
tableRef.Alias = p.currentToken.Token.Value
162-
p.advance()
163-
}
164-
} else if p.isIdentifier() {
165-
tableRef.Alias = p.currentToken.Token.Value
166-
p.advance()
167-
}
147+
p.parsePivotAlias(&tableRef)
168148
}
169149

170150
return tableRef, nil
@@ -268,16 +248,7 @@ func (p *Parser) parseJoinedTableRef(joinType string) (ast.TableReference, error
268248
return ref, err
269249
}
270250
ref.Pivot = pivot
271-
if p.isType(models.TokenTypeAs) {
272-
p.advance()
273-
if p.isIdentifier() {
274-
ref.Alias = p.currentToken.Token.Value
275-
p.advance()
276-
}
277-
} else if p.isIdentifier() {
278-
ref.Alias = p.currentToken.Token.Value
279-
p.advance()
280-
}
251+
p.parsePivotAlias(&ref)
281252
}
282253

283254
// SQL Server / Oracle UNPIVOT clause
@@ -287,16 +258,7 @@ func (p *Parser) parseJoinedTableRef(joinType string) (ast.TableReference, error
287258
return ref, err
288259
}
289260
ref.Unpivot = unpivot
290-
if p.isType(models.TokenTypeAs) {
291-
p.advance()
292-
if p.isIdentifier() {
293-
ref.Alias = p.currentToken.Token.Value
294-
p.advance()
295-
}
296-
} else if p.isIdentifier() {
297-
ref.Alias = p.currentToken.Token.Value
298-
p.advance()
299-
}
261+
p.parsePivotAlias(&ref)
300262
}
301263

302264
return ref, nil

pkg/sql/parser/tsql_test.go

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ PIVOT (
433433
if len(ref.Pivot.InValues) != 4 {
434434
t.Errorf("expected 4 IN values, got %d", len(ref.Pivot.InValues))
435435
}
436-
expected := []string{"North", "South", "East", "West"}
436+
expected := []string{"[North]", "[South]", "[East]", "[West]"}
437437
for i, v := range expected {
438438
if i < len(ref.Pivot.InValues) && ref.Pivot.InValues[i] != v {
439439
t.Errorf("IN value [%d]: expected %q, got %q", i, v, ref.Pivot.InValues[i])
@@ -540,6 +540,54 @@ func TestTSQL_PivotWithASAlias(t *testing.T) {
540540
}
541541
}
542542

543+
// TestPivotNegativeCases covers parser error paths for malformed PIVOT/UNPIVOT.
544+
func TestPivotNegativeCases(t *testing.T) {
545+
cases := []struct {
546+
name string
547+
sql string
548+
}{
549+
{"missing_lparen", "SELECT * FROM t PIVOT SUM(x) FOR c IN (a))"},
550+
{"missing_for", "SELECT * FROM t PIVOT (SUM(x) c IN (a))"},
551+
{"missing_in", "SELECT * FROM t PIVOT (SUM(x) FOR c (a))"},
552+
{"missing_in_lparen", "SELECT * FROM t PIVOT (SUM(x) FOR c IN a)"},
553+
{"empty_in_list", "SELECT * FROM t PIVOT (SUM(x) FOR c IN ())"},
554+
{"unpivot_missing_for", "SELECT * FROM t UNPIVOT (v c IN (a))"},
555+
{"unpivot_empty_in_list", "SELECT * FROM t UNPIVOT (v FOR n IN ())"},
556+
}
557+
for _, tc := range cases {
558+
t.Run(tc.name, func(t *testing.T) {
559+
_, err := ParseWithDialect(tc.sql, keywords.DialectSQLServer)
560+
if err == nil {
561+
t.Fatalf("expected parse error, got nil for: %s", tc.sql)
562+
}
563+
})
564+
}
565+
}
566+
567+
// TestPivotBracketedInValuesPreserved verifies SQL Server bracket-quoted IN
568+
// values survive parsing so the formatter can re-emit them.
569+
func TestPivotBracketedInValuesPreserved(t *testing.T) {
570+
sql := `SELECT * FROM sales PIVOT (SUM(amt) FOR region IN ([North], [South])) AS p`
571+
result, err := ParseWithDialect(sql, keywords.DialectSQLServer)
572+
if err != nil {
573+
t.Fatalf("unexpected error: %v", err)
574+
}
575+
stmt, ok := result.Statements[0].(*ast.SelectStatement)
576+
if !ok {
577+
t.Fatalf("expected SelectStatement, got %T", result.Statements[0])
578+
}
579+
got := stmt.From[0].Pivot.InValues
580+
want := []string{"[North]", "[South]"}
581+
if len(got) != len(want) {
582+
t.Fatalf("expected %d values, got %d (%v)", len(want), len(got), got)
583+
}
584+
for i := range want {
585+
if got[i] != want[i] {
586+
t.Errorf("InValues[%d] = %q, want %q", i, got[i], want[i])
587+
}
588+
}
589+
}
590+
543591
// TestPivotIdentifierInNonTSQLDialects verifies PIVOT/UNPIVOT remain valid
544592
// identifiers in dialects that don't recognize the contextual clause.
545593
// Regression for global-tokenizer-keyword leak.

pkg/sql/tokenizer/tokenizer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1283,7 +1283,7 @@ func (t *Tokenizer) readPunctuation() (models.Token, error) {
12831283
ch, chSize := utf8.DecodeRune(t.input[t.pos.Index:])
12841284
if ch == ']' {
12851285
t.pos.AdvanceRune(ch, chSize) // Consume ]
1286-
return models.Token{Type: models.TokenTypeIdentifier, Value: string(ident)}, nil
1286+
return models.Token{Type: models.TokenTypeIdentifier, Value: string(ident), Quote: '['}, nil
12871287
}
12881288
ident = append(ident, t.input[t.pos.Index:t.pos.Index+chSize]...)
12891289
t.pos.AdvanceRune(ch, chSize)

0 commit comments

Comments
 (0)