Skip to content

Commit a9a51be

Browse files
Ajit Pratap Singhclaude
authored andcommitted
fix(mariadb): address second code review pass — Pos, NO CACHE, CONNECT BY, dedup
Parser dispatch (parser.go, ddl.go, select.go): - Populate Pos on CreateSequenceStatement (at SEQUENCE token in ddl.go) - Populate Pos on DropSequenceStatement (at DROP token in parser.go) - Populate Pos on AlterSequenceStatement (at ALTER token in parser.go) - Populate Pos on ConnectByClause (at CONNECT token in select.go) - Populate Pos on PeriodDefinition (at PERIOD token in ddl.go) mariadb.go: - Fix NO CACHE (two-token) to also set opts.NoCache=true, matching NOCACHE - Fix parseConnectByCondition to handle complex AND/OR chains: CONNECT BY PRIOR id = parent_id AND active = 1 now fully parsed - Extract isMariaDBClauseStart() method (was duplicated closure in two functions) - Populate Pos on ForSystemTimeClause (at SYSTEM_TIME token) - Add comment clarifying IF NOT EXISTS is a non-standard permissive extension select_subquery.go: - Remove both isMariaDBClauseKeyword closures, replace with p.isMariaDBClauseStart() ast.go: - Update DropSequenceStatement doc to show [IF EXISTS | IF NOT EXISTS] Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent f003e2b commit a9a51be

6 files changed

Lines changed: 110 additions & 66 deletions

File tree

pkg/sql/ast/ast.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1878,7 +1878,7 @@ func (s *CreateSequenceStatement) Children() []Node {
18781878

18791879
// DropSequenceStatement represents:
18801880
//
1881-
// DROP SEQUENCE [IF EXISTS] name
1881+
// DROP SEQUENCE [IF EXISTS | IF NOT EXISTS] name
18821882
type DropSequenceStatement struct {
18831883
Name *Identifier
18841884
IfExists bool

pkg/sql/parser/ddl.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,16 @@ func (p *Parser) parseCreateStatement() (ast.Statement, error) {
8383
p.advance() // Consume INDEX
8484
return p.parseCreateIndex(true) // Unique
8585
} else if p.isMariaDB() && p.isTokenMatch("SEQUENCE") {
86-
p.advance() // Consume SEQUENCE
87-
return p.parseCreateSequenceStatement(orReplace)
86+
seqPos := p.currentLocation() // position of SEQUENCE token
87+
p.advance() // Consume SEQUENCE
88+
stmt, err := p.parseCreateSequenceStatement(orReplace)
89+
if err != nil {
90+
return nil, err
91+
}
92+
if stmt.Pos.IsZero() {
93+
stmt.Pos = seqPos
94+
}
95+
return stmt, nil
8896
}
8997
return nil, p.expectedError("TABLE, VIEW, MATERIALIZED VIEW, or INDEX after CREATE")
9098
}
@@ -126,10 +134,12 @@ func (p *Parser) parseCreateTable(temporary bool) (*ast.CreateTableStatement, er
126134
for {
127135
// MariaDB: PERIOD FOR name (start_col, end_col) — application-time or system-time period
128136
if p.isMariaDB() && p.isTokenMatch("PERIOD") {
137+
periodPos := p.currentLocation() // position of PERIOD keyword
129138
pd, err := p.parsePeriodDefinition()
130139
if err != nil {
131140
return nil, err
132141
}
142+
pd.Pos = periodPos
133143
stmt.PeriodDefinitions = append(stmt.PeriodDefinitions, pd)
134144
} else if p.isAnyType(models.TokenTypePrimary, models.TokenTypeForeign,
135145
models.TokenTypeUnique, models.TokenTypeCheck, models.TokenTypeConstraint) {

pkg/sql/parser/mariadb.go

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,25 @@ func (p *Parser) isMariaDB() bool {
2828
return p.dialect == string(keywords.DialectMariaDB)
2929
}
3030

31+
// isMariaDBClauseStart returns true when the current token is the start of a
32+
// MariaDB hierarchical-query clause (CONNECT BY or START WITH) rather than a
33+
// table alias. Used to guard alias parsing in FROM and JOIN table references.
34+
func (p *Parser) isMariaDBClauseStart() bool {
35+
if !p.isMariaDB() {
36+
return false
37+
}
38+
val := strings.ToUpper(p.currentToken.Token.Value)
39+
if val == "CONNECT" {
40+
next := p.peekToken()
41+
return strings.EqualFold(next.Token.Value, "BY")
42+
}
43+
if val == "START" {
44+
next := p.peekToken()
45+
return strings.EqualFold(next.Token.Value, "WITH")
46+
}
47+
return false
48+
}
49+
3150
// parseCreateSequenceStatement parses:
3251
//
3352
// CREATE [OR REPLACE] SEQUENCE [IF NOT EXISTS] name [options...]
@@ -73,7 +92,9 @@ func (p *Parser) parseDropSequenceStatement() (*ast.DropSequenceStatement, error
7392
if strings.EqualFold(p.currentToken.Token.Value, "IF") {
7493
p.advance()
7594
if strings.EqualFold(p.currentToken.Token.Value, "NOT") {
76-
// IF NOT EXISTS — treated as "no error if absent" (same semantics as IF EXISTS)
95+
// IF NOT EXISTS is a non-standard permissive extension (MariaDB only supports
96+
// IF EXISTS natively). We accept it and reuse the IfExists flag since both
97+
// forms mean "suppress the error if the sequence is absent".
7798
p.advance()
7899
if !strings.EqualFold(p.currentToken.Token.Value, "EXISTS") {
79100
return nil, p.expectedError("EXISTS")
@@ -181,6 +202,7 @@ func (p *Parser) parseSequenceOptions() (ast.SequenceOptions, error) {
181202
opts.NoCycle = true
182203
case "CACHE":
183204
opts.Cache = nil
205+
opts.NoCache = true
184206
default:
185207
return opts, fmt.Errorf("unexpected token after NO in SEQUENCE options: %s", sub)
186208
}
@@ -239,9 +261,11 @@ func (p *Parser) parseForSystemTimeClause() (*ast.ForSystemTimeClause, error) {
239261
if !strings.EqualFold(p.currentToken.Token.Value, "SYSTEM_TIME") {
240262
return nil, fmt.Errorf("expected SYSTEM_TIME after FOR, got %q", p.currentToken.Token.Value)
241263
}
264+
sysTimePos := p.currentLocation() // position of SYSTEM_TIME token
242265
p.advance()
243266

244267
clause := &ast.ForSystemTimeClause{}
268+
clause.Pos = sysTimePos
245269
word := strings.ToUpper(p.currentToken.Token.Value)
246270

247271
switch word {
@@ -327,12 +351,15 @@ func (p *Parser) parseTemporalPointExpression() (ast.Expression, error) {
327351
// parseConnectByCondition parses the condition expression for CONNECT BY.
328352
// It handles the PRIOR prefix operator in either position:
329353
//
330-
// CONNECT BY PRIOR id = parent_id (PRIOR on left)
331-
// CONNECT BY id = PRIOR parent_id (PRIOR on right)
354+
// CONNECT BY PRIOR id = parent_id (PRIOR on left)
355+
// CONNECT BY id = PRIOR parent_id (PRIOR on right)
356+
// CONNECT BY PRIOR id = parent_id AND active = 1 (complex with AND/OR)
332357
//
333358
// PRIOR references the value from the parent row in the hierarchy.
334359
// It is modeled as UnaryExpression{Operator: ast.Prior, Expr: <column>}.
335360
func (p *Parser) parseConnectByCondition() (ast.Expression, error) {
361+
var base ast.Expression
362+
336363
// Case 1: PRIOR col op col
337364
if strings.EqualFold(p.currentToken.Token.Value, "PRIOR") {
338365
p.advance()
@@ -351,38 +378,57 @@ func (p *Parser) parseConnectByCondition() (ast.Expression, error) {
351378
if err != nil {
352379
return nil, err
353380
}
354-
return &ast.BinaryExpression{Left: priorExpr, Operator: op, Right: right}, nil
381+
base = &ast.BinaryExpression{Left: priorExpr, Operator: op, Right: right}
382+
} else {
383+
base = priorExpr
355384
}
356-
return priorExpr, nil
357-
}
358-
359-
// Case 2: col op PRIOR col (PRIOR on the right-hand side)
360-
left, err := p.parsePrimaryExpression()
361-
if err != nil {
362-
return nil, err
363-
}
364-
if p.isType(models.TokenTypeEq) || p.isType(models.TokenTypeNeq) ||
365-
p.isType(models.TokenTypeLt) || p.isType(models.TokenTypeGt) ||
366-
p.isType(models.TokenTypeLtEq) || p.isType(models.TokenTypeGtEq) {
367-
op := p.currentToken.Token.Value
368-
p.advance()
369-
// Check for PRIOR on the right side
370-
if strings.EqualFold(p.currentToken.Token.Value, "PRIOR") {
385+
} else {
386+
// Case 2: col op PRIOR col (PRIOR on the right-hand side)
387+
// or plain expression (no PRIOR)
388+
left, err := p.parsePrimaryExpression()
389+
if err != nil {
390+
return nil, err
391+
}
392+
if p.isType(models.TokenTypeEq) || p.isType(models.TokenTypeNeq) ||
393+
p.isType(models.TokenTypeLt) || p.isType(models.TokenTypeGt) ||
394+
p.isType(models.TokenTypeLtEq) || p.isType(models.TokenTypeGtEq) {
395+
op := p.currentToken.Token.Value
371396
p.advance()
372-
priorIdent := p.parseIdent()
373-
if priorIdent == nil || priorIdent.Name == "" {
374-
return nil, p.expectedError("column name after PRIOR")
397+
// Check for PRIOR on the right side
398+
if strings.EqualFold(p.currentToken.Token.Value, "PRIOR") {
399+
p.advance()
400+
priorIdent := p.parseIdent()
401+
if priorIdent == nil || priorIdent.Name == "" {
402+
return nil, p.expectedError("column name after PRIOR")
403+
}
404+
priorExpr := &ast.UnaryExpression{Operator: ast.Prior, Expr: priorIdent}
405+
base = &ast.BinaryExpression{Left: left, Operator: op, Right: priorExpr}
406+
} else {
407+
right, err := p.parsePrimaryExpression()
408+
if err != nil {
409+
return nil, err
410+
}
411+
base = &ast.BinaryExpression{Left: left, Operator: op, Right: right}
375412
}
376-
priorExpr := &ast.UnaryExpression{Operator: ast.Prior, Expr: priorIdent}
377-
return &ast.BinaryExpression{Left: left, Operator: op, Right: priorExpr}, nil
413+
} else {
414+
base = left
378415
}
379-
right, err := p.parsePrimaryExpression()
416+
}
417+
418+
// Handle AND/OR chaining for complex conditions like:
419+
// PRIOR id = parent_id AND active = 1
420+
for strings.EqualFold(p.currentToken.Token.Value, "AND") ||
421+
strings.EqualFold(p.currentToken.Token.Value, "OR") {
422+
logicOp := p.currentToken.Token.Value
423+
p.advance()
424+
rest, err := p.parseConnectByCondition()
380425
if err != nil {
381426
return nil, err
382427
}
383-
return &ast.BinaryExpression{Left: left, Operator: op, Right: right}, nil
428+
base = &ast.BinaryExpression{Left: base, Operator: logicOp, Right: rest}
384429
}
385-
return left, nil
430+
431+
return base, nil
386432
}
387433

388434
// parsePeriodDefinition parses: PERIOD FOR name (start_col, end_col)

pkg/sql/parser/parser.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -629,11 +629,19 @@ func (p *Parser) parseStatement() (ast.Statement, error) {
629629
}
630630
return stmt, nil
631631
case models.TokenTypeAlter:
632+
stmtPos := p.currentLocation()
632633
p.advance()
633634
// MariaDB: ALTER SEQUENCE [IF EXISTS] name [options...]
634635
if p.isMariaDB() && p.isTokenMatch("SEQUENCE") {
635636
p.advance() // Consume SEQUENCE
636-
return p.parseAlterSequenceStatement()
637+
stmt, err := p.parseAlterSequenceStatement()
638+
if err != nil {
639+
return nil, err
640+
}
641+
if stmt.Pos.IsZero() {
642+
stmt.Pos = stmtPos
643+
}
644+
return stmt, nil
637645
}
638646
return p.parseAlterTableStmt()
639647
case models.TokenTypeMerge:
@@ -643,11 +651,19 @@ func (p *Parser) parseStatement() (ast.Statement, error) {
643651
p.advance()
644652
return p.parseCreateStatement()
645653
case models.TokenTypeDrop:
654+
stmtPos := p.currentLocation()
646655
p.advance()
647-
// MariaDB: DROP SEQUENCE [IF EXISTS] name
656+
// MariaDB: DROP SEQUENCE [IF EXISTS | IF NOT EXISTS] name
648657
if p.isMariaDB() && p.isTokenMatch("SEQUENCE") {
649658
p.advance() // Consume SEQUENCE
650-
return p.parseDropSequenceStatement()
659+
stmt, err := p.parseDropSequenceStatement()
660+
if err != nil {
661+
return nil, err
662+
}
663+
if stmt.Pos.IsZero() {
664+
stmt.Pos = stmtPos
665+
}
666+
return stmt, nil
651667
}
652668
return p.parseDropStatement()
653669
case models.TokenTypeRefresh:

pkg/sql/parser/select.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,14 @@ func (p *Parser) parseSelectStatement() (ast.Statement, error) {
124124
selectStmt.StartWith = startExpr
125125
}
126126
if strings.EqualFold(p.currentToken.Token.Value, "CONNECT") {
127-
p.advance() // Consume CONNECT
127+
connectPos := p.currentLocation() // position of CONNECT keyword
128+
p.advance() // Consume CONNECT
128129
if !strings.EqualFold(p.currentToken.Token.Value, "BY") {
129130
return nil, fmt.Errorf("expected BY after CONNECT, got %q", p.currentToken.Token.Value)
130131
}
131132
p.advance() // Consume BY
132133
cb := &ast.ConnectByClause{}
134+
cb.Pos = connectPos
133135
if strings.EqualFold(p.currentToken.Token.Value, "NOCYCLE") {
134136
cb.NoCycle = true
135137
p.advance() // Consume NOCYCLE

pkg/sql/parser/select_subquery.go

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -87,22 +87,7 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) {
8787
// Check for table alias (required for derived tables, optional for regular tables).
8888
// Guard: in MariaDB, CONNECT followed by BY is a hierarchical query clause, not an alias.
8989
// Similarly, START followed by WITH is a hierarchical query seed, not an alias.
90-
isMariaDBClauseKeyword := func() bool {
91-
if !p.isMariaDB() {
92-
return false
93-
}
94-
val := strings.ToUpper(p.currentToken.Token.Value)
95-
if val == "CONNECT" {
96-
next := p.peekToken()
97-
return strings.EqualFold(next.Token.Value, "BY")
98-
}
99-
if val == "START" {
100-
next := p.peekToken()
101-
return strings.EqualFold(next.Token.Value, "WITH")
102-
}
103-
return false
104-
}
105-
if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !isMariaDBClauseKeyword() {
90+
if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() {
10691
if p.isType(models.TokenTypeAs) {
10792
p.advance() // Consume AS
10893
if !p.isIdentifier() {
@@ -194,22 +179,7 @@ func (p *Parser) parseJoinedTableRef(joinType string) (ast.TableReference, error
194179
// Optional alias.
195180
// Guard: in MariaDB, CONNECT followed by BY is a hierarchical query clause, not an alias.
196181
// Similarly, START followed by WITH is a hierarchical query seed, not an alias.
197-
isMariaDBClauseKeyword := func() bool {
198-
if !p.isMariaDB() {
199-
return false
200-
}
201-
val := strings.ToUpper(p.currentToken.Token.Value)
202-
if val == "CONNECT" {
203-
next := p.peekToken()
204-
return strings.EqualFold(next.Token.Value, "BY")
205-
}
206-
if val == "START" {
207-
next := p.peekToken()
208-
return strings.EqualFold(next.Token.Value, "WITH")
209-
}
210-
return false
211-
}
212-
if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !isMariaDBClauseKeyword() {
182+
if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() {
213183
if p.isType(models.TokenTypeAs) {
214184
p.advance()
215185
if !p.isIdentifier() {

0 commit comments

Comments
 (0)