Skip to content

Commit 8a4a735

Browse files
authored
fix(pg): preserve completion state during paren lookahead (#113)
* fix(pg): preserve completion state during paren lookahead
* fix(pg): add completion-safe backtrack snapshot
1 parent eea88ff commit 8a4a735

5 files changed

Lines changed: 132 additions & 18 deletions

File tree

pg/parser/backtrack.go

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@ package parser
88
// SCOPE: this is a TOKEN-STREAM snapshot, not a complete parser/lexer
99
// snapshot. It does NOT cover mid-token-content lexer state (literalbuf,
1010
// dolqstart, utf16FirstPart, xcdepth, stateBeforeStrStop, warning flags)
11-
// or completion-mode state (candidates, collecting). Those fields are
12-
// either reset at token boundaries (lexer internals) or not used during
13-
// speculative parses (completion mode), so they don't need to be saved
14-
// here for token-stream rollback to be sound.
11+
// or completion-mode state (candidates, collecting, collectDepth). Lexer internals are
12+
// reset at token boundaries. Completion-mode speculative callers that can
13+
// scan past the cursor should use snapshotTokenStreamAndCompletion.
1514
//
1615
// If a future caller needs to roll back from INSIDE a token (e.g., from
1716
// inside a string literal or dollar-quoted block), this struct is
@@ -42,6 +41,13 @@ type tokenStreamState struct {
4241
lexerState LexerState
4342
}
4443

44+
// tokenStreamAndCompletionState pairs a token-stream snapshot with the
// parser's completion fields (collecting, collectDepth, candidates) so a
// speculative parse can roll all of them back together. It is produced by
// snapshotTokenStreamAndCompletion and consumed by
// restoreTokenStreamAndCompletion.
type tokenStreamAndCompletionState struct {
45+
tokenStream tokenStreamState // result of snapshotTokenStream
46+
collecting bool // p.collecting at snapshot time
47+
collectDepth int // p.collectDepth at snapshot time
48+
candidates *CandidateSet // copy made by CandidateSet.snapshot; nil if the parser had no candidate set
49+
}
50+
4551
// snapshotTokenStream captures the current token-stream position for
4652
// later restoration via restoreTokenStream. See tokenStreamState for
4753
// scope and limitations.
@@ -58,16 +64,31 @@ func (p *Parser) snapshotTokenStream() tokenStreamState {
5864
}
5965
}
6066

67+
// snapshotTokenStreamAndCompletion captures token-stream state plus the
68+
// completion state that advance() can mutate when a speculative walk crosses
69+
// the cursor. Use this for completion-mode lookahead that can scan arbitrary
70+
// user input before rolling back.
71+
func (p *Parser) snapshotTokenStreamAndCompletion() tokenStreamAndCompletionState {
72+
// collecting and collectDepth are plain values, so copying them into the
// struct literal is sufficient.
s := tokenStreamAndCompletionState{
73+
tokenStream: p.snapshotTokenStream(),
74+
collecting: p.collecting,
75+
collectDepth: p.collectDepth,
76+
}
77+
// The candidate set is held by pointer, so store a copy obtained from
// snapshot() instead of the pointer itself; otherwise candidates added
// after this call would also appear in the saved state. s.candidates is
// left nil when the parser has no candidate set.
if p.candidates != nil {
78+
s.candidates = p.candidates.snapshot()
79+
}
80+
return s
81+
}
82+
6183
// restoreTokenStream rewinds parser + lexer state to a previously
6284
// captured snapshot. After restore, the next advance() will emit the
6385
// same token as it would have at the moment snapshotTokenStream() was
6486
// called.
6587
//
6688
// Caller responsibility: do not interleave restore with completion-mode
67-
// queries or with any operation that mutates lexer state outside the
68-
// token stream (string literal scanning, etc). The current speculative
69-
// parse sites in parseFuncArg and parseFuncType only consume keyword
70-
// tokens and punctuation, so they are safe.
89+
// queries or with any operation that mutates lexer state outside the token
90+
// stream. Use restoreTokenStreamAndCompletion for lookahead that may cross
91+
// the completion cursor.
7192
func (p *Parser) restoreTokenStream(s tokenStreamState) {
7293
p.cur = s.cur
7394
p.prev = s.prev
@@ -78,3 +99,14 @@ func (p *Parser) restoreTokenStream(s tokenStreamState) {
7899
p.lexer.start = s.lexerStart
79100
p.lexer.state = s.lexerState
80101
}
102+
103+
// restoreTokenStreamAndCompletion rewinds token-stream and completion state
104+
// captured by snapshotTokenStreamAndCompletion.
105+
func (p *Parser) restoreTokenStreamAndCompletion(s tokenStreamAndCompletionState) {
106+
// Rewind the token stream first, then put back the two completion flags.
p.restoreTokenStream(s.tokenStream)
107+
p.collecting = s.collecting
108+
p.collectDepth = s.collectDepth
109+
// Restore candidates in place: p.candidates keeps pointing at the same
// CandidateSet and only its contents are replaced, so other holders of
// that pointer see the rolled-back state. If either side is nil the
// candidate set is left untouched.
if p.candidates != nil && s.candidates != nil {
110+
p.candidates.restore(s.candidates)
111+
}
112+
}

pg/parser/backtrack_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,52 @@ func TestSnapshotRestoreIdentity(t *testing.T) {
7575
}
7676
}
7777

78+
// TestSnapshotRestoreCompletionState checks that
// restoreTokenStreamAndCompletion undoes both the token-stream position and
// the completion state (collecting flag and candidate set contents) after a
// walk that advances past the completion cursor.
func TestSnapshotRestoreCompletionState(t *testing.T) {
79+
sql := "SELECT * FROM (SELECT FROM t1) a"
80+
// The cursor offset is the length of the text up to and including the
// inner "SELECT ".
cursor := len("SELECT * FROM (SELECT ")
81+
cs := newCandidateSet()
82+
p := &Parser{
83+
lexer: NewLexer(sql),
84+
source: sql,
85+
completing: true,
86+
cursorOff: cursor,
87+
candidates: cs,
88+
}
89+
// One advance before the snapshot; the final assertion expects p.cur to
// be SELECT again after restore.
p.advance()
90+
91+
snap := p.snapshotTokenStreamAndCompletion()
92+
// Walk forward until advance() flips collecting or the input ends.
for p.cur.Type != lex_EOF && !p.collecting {
93+
p.advance()
94+
}
95+
if !p.collecting {
96+
t.Fatal("expected speculative walk to cross cursor")
97+
}
98+
// Add one entry of each kind so the restore has something to remove.
p.addRuleCandidate("leaked_rule")
99+
p.addTokenCandidate(SELECT)
100+
p.addCTEPosition(123)
101+
p.addSelectAliasPosition(456)
102+
103+
p.restoreTokenStreamAndCompletion(snap)
104+
// Assertions go through cs, the same pointer handed to the parser above,
// so they only pass if that object itself was rolled back.
if p.collecting {
105+
t.Fatal("expected collecting to be restored to false")
106+
}
107+
if cs.HasRule("leaked_rule") {
108+
t.Fatal("expected rule candidates to be restored")
109+
}
110+
if cs.HasToken(SELECT) {
111+
t.Fatal("expected token candidates to be restored")
112+
}
113+
if len(cs.CTEPositions) != 0 {
114+
t.Fatalf("expected CTE positions to be restored, got %v", cs.CTEPositions)
115+
}
116+
if len(cs.SelectAliasPositions) != 0 {
117+
t.Fatalf("expected select alias positions to be restored, got %v", cs.SelectAliasPositions)
118+
}
119+
if p.cur.Type != SELECT {
120+
t.Fatalf("expected token stream to be restored to SELECT, got %d", p.cur.Type)
121+
}
122+
}
123+
78124
// walkTokens parses sql and returns the full token sequence (excluding EOF).
79125
func walkTokens(t *testing.T, sql string) []Token {
80126
t.Helper()

pg/parser/complete.go

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,10 @@ type RuleCandidate struct {
7575
// CandidateSet holds the token and rule candidates collected during a
7676
// completion-mode parse.
7777
type CandidateSet struct {
78-
Tokens []int // token type candidates
79-
Rules []RuleCandidate // grammar rule candidates
80-
seen map[int]bool // dedup tokens
81-
seenR map[string]bool // dedup rules
78+
Tokens []int // token type candidates
79+
Rules []RuleCandidate // grammar rule candidates
80+
seen map[int]bool // dedup tokens
81+
seenR map[string]bool // dedup rules
8282

8383
// CTEPositions holds the byte offsets of WITH clause starts encountered
8484
// before the cursor. Bytebase uses these to re-parse CTE definitions
@@ -246,13 +246,17 @@ func (p *Parser) addKeywordsByCategory(categories ...KeywordCategory) {
246246
// snapshot returns a copy of the current candidate set state.
247247
func (cs *CandidateSet) snapshot() *CandidateSet {
248248
s := &CandidateSet{
249-
Tokens: make([]int, len(cs.Tokens)),
250-
Rules: make([]RuleCandidate, len(cs.Rules)),
251-
seen: make(map[int]bool, len(cs.seen)),
252-
seenR: make(map[string]bool, len(cs.seenR)),
249+
Tokens: make([]int, len(cs.Tokens)),
250+
Rules: make([]RuleCandidate, len(cs.Rules)),
251+
seen: make(map[int]bool, len(cs.seen)),
252+
seenR: make(map[string]bool, len(cs.seenR)),
253+
CTEPositions: make([]int, len(cs.CTEPositions)),
254+
SelectAliasPositions: make([]int, len(cs.SelectAliasPositions)),
253255
}
254256
copy(s.Tokens, cs.Tokens)
255257
copy(s.Rules, cs.Rules)
258+
copy(s.CTEPositions, cs.CTEPositions)
259+
copy(s.SelectAliasPositions, cs.SelectAliasPositions)
256260
for k, v := range cs.seen {
257261
s.seen[k] = v
258262
}
@@ -262,6 +266,23 @@ func (cs *CandidateSet) snapshot() *CandidateSet {
262266
return s
263267
}
264268

269+
// restore replaces cs with snapshot's contents while preserving cs's identity.
270+
func (cs *CandidateSet) restore(snapshot *CandidateSet) {
271+
// snapshot must be non-nil: its fields are read unconditionally below.
// Appending onto cs.X[:0] copies the snapshot's elements into cs (reusing
// cs's own backing array when it is large enough), so cs never shares a
// backing array with snapshot.
cs.Tokens = append(cs.Tokens[:0], snapshot.Tokens...)
272+
cs.Rules = append(cs.Rules[:0], snapshot.Rules...)
273+
cs.CTEPositions = append(cs.CTEPositions[:0], snapshot.CTEPositions...)
274+
cs.SelectAliasPositions = append(cs.SelectAliasPositions[:0], snapshot.SelectAliasPositions...)
275+
276+
// Rebuild the dedup maps as fresh copies so later insertions into cs do
// not modify the snapshot's maps.
// NOTE(review): only Tokens, Rules, CTEPositions, SelectAliasPositions,
// seen and seenR are restored here; confirm CandidateSet has no other
// mutable fields that a speculative walk can change.
cs.seen = make(map[int]bool, len(snapshot.seen))
277+
for k, v := range snapshot.seen {
278+
cs.seen[k] = v
279+
}
280+
cs.seenR = make(map[string]bool, len(snapshot.seenR))
281+
for k, v := range snapshot.seenR {
282+
cs.seenR[k] = v
283+
}
284+
}
285+
265286
// diff returns candidates in cs that are not in before.
266287
func (cs *CandidateSet) diff(before *CandidateSet) *CandidateSet {
267288
d := newCandidateSet()

pg/parser/complete_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,21 @@ func TestCollectAfterSelect(t *testing.T) {
7272
}
7373
}
7474

75+
// TestCollectNestedFromSubqueryAfterSelect runs Collect with the cursor
// placed right after the innermost SELECT of a doubly nested FROM-subquery
// and expects the "columnref" and "func_name" rule candidates to be present.
func TestCollectNestedFromSubqueryAfterSelect(t *testing.T) {
76+
prefix := "SELECT * FROM (SELECT * FROM (SELECT "
77+
sql := prefix + " FROM t1) a) b"
78+
// len(prefix) is the cursor offset: the end of the innermost "SELECT ".
candidates := Collect(sql, len(prefix))
79+
if candidates == nil {
80+
t.Fatal("expected non-nil candidates")
81+
}
82+
// t.Error (not Fatal) so both missing rules are reported in a single run.
if !candidates.HasRule("columnref") {
83+
t.Error("expected columnref rule candidate in nested SELECT target list")
84+
}
85+
if !candidates.HasRule("func_name") {
86+
t.Error("expected func_name rule candidate in nested SELECT target list")
87+
}
88+
}
89+
7590
func TestCollectAfterFrom(t *testing.T) {
7691
candidates := Collect("SELECT 1 FROM ", 14)
7792
if candidates == nil {

pg/parser/select.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,8 +1268,8 @@ func (p *Parser) parenBeginsSubquery() bool {
12681268
if p.cur.Type != '(' {
12691269
return false
12701270
}
1271-
snap := p.snapshotTokenStream()
1272-
defer p.restoreTokenStream(snap)
1271+
snap := p.snapshotTokenStreamAndCompletion()
1272+
defer p.restoreTokenStreamAndCompletion(snap)
12731273
return p.consumeMatchedParenIsSubquery()
12741274
}
12751275

0 commit comments

Comments
 (0)