Skip to content

Commit 32c873c

Browse files
committed
fix CTE autocompletion
1 parent 10053f2 commit 32c873c

5 files changed

Lines changed: 928 additions & 13 deletions

File tree

src/autocomplete/content-assist.ts

Lines changed: 322 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,8 @@ export interface ContentAssistResult {
6262
nextTokenTypes: TokenType[]
6363
/** Tables/aliases found in the query (for column suggestions) */
6464
tablesInScope: TableRef[]
65+
/** Columns from CTEs, keyed by CTE name (lowercase) */
66+
cteColumns: Record<string, { name: string; type: string }[]>
6567
/** The tokens before the cursor */
6668
tokensBefore: IToken[]
6769
/** Whether the cursor is in the middle of a word (partial token being typed) */
@@ -94,8 +96,62 @@ function normalizeTableName(value: unknown): string | undefined {
9496
return undefined
9597
}
9698

97-
function extractTablesFromAst(ast: unknown): TableRef[] {
99+
/**
 * Combined output of a table-extraction pass: the table references that were
 * collected together with the columns inferred for each CTE.
 */
interface ExtractResult {
  /** Table references collected from the query */
  tables: TableRef[]
  /** Inferred CTE columns, keyed by the CTE name in lowercase */
  cteColumns: Record<string, Array<{ name: string; type: string }>>
}
103+
104+
/**
 * Infer the output column names of a CTE from the select items of its inner
 * SELECT.
 *
 * For each item the name is resolved in this order:
 *   1. the explicit alias, when it is a string;
 *   2. the last part of a qualified column reference (`t.col` → `col`);
 *   3. the function name, for a function-call expression.
 * Star items (`*`, `t.*`) are skipped because they cannot be expanded without
 * schema information. Any other expression shape is skipped as well.
 *
 * @param columns - Select-item nodes of the CTE's inner SELECT. Entries that
 *   are not objects are ignored.
 * @returns One `{ name, type }` entry per resolvable item, in input order.
 *   `type` is always the empty string; no type inference is attempted here.
 */
function extractCteColumnNames(
  columns: unknown[],
): { name: string; type: string }[] {
  const result: { name: string; type: string }[] = []
  for (const col of columns) {
    if (!col || typeof col !== "object") continue
    const item = col as Record<string, unknown>

    // Only expression select items can yield a name. This also filters out
    // "star" and "qualifiedStar" items, which can't be resolved without schema.
    if (item.type !== "selectItem") continue

    // Prefer alias
    if (typeof item.alias === "string") {
      result.push({ name: item.alias, type: "" })
      continue
    }

    const expr = item.expression as Record<string, unknown> | undefined
    if (!expr || typeof expr !== "object") continue

    // Column reference → use last part of qualified name
    if (expr.type === "column" && expr.name && typeof expr.name === "object") {
      const parts = (expr.name as Record<string, unknown>).parts
      if (Array.isArray(parts) && parts.length > 0) {
        const last: unknown = parts[parts.length - 1]
        // Narrow at runtime instead of asserting: a non-string part must not
        // end up in the result disguised as a column name.
        if (typeof last === "string") {
          result.push({ name: last, type: "" })
        }
        continue
      }
    }

    // Function call → use function name
    if (expr.type === "function" && typeof expr.name === "string") {
      result.push({ name: expr.name, type: "" })
    }
  }
  return result
}
151+
152+
function extractTablesFromAst(ast: unknown): ExtractResult {
98153
const tables: TableRef[] = []
154+
const cteColumns: Record<string, { name: string; type: string }[]> = {}
99155
const visited = new WeakSet()
100156

101157
function visit(node: unknown) {
@@ -105,6 +161,23 @@ function extractTablesFromAst(ast: unknown): TableRef[] {
105161

106162
const n = node as Record<string, unknown>
107163

164+
// Handle CTE definitions — surface CTE names as virtual tables
165+
// and extract column names from the inner SELECT
166+
if (n.type === "cte" && typeof n.name === "string") {
167+
tables.push({ table: n.name })
168+
const query = n.query as Record<string, unknown> | undefined
169+
if (query && Array.isArray(query.columns)) {
170+
const cols = extractCteColumnNames(query.columns)
171+
cteColumns[n.name.toLowerCase()] = cols
172+
} else {
173+
// Register CTE name even if no columns are extractable (e.g., SELECT *)
174+
cteColumns[n.name.toLowerCase()] = []
175+
}
176+
// Don't recurse into the CTE's inner query — its tables are not
177+
// in the outer scope. We already extracted the CTE name and columns.
178+
return
179+
}
180+
108181
// Handle table references in FROM clause
109182
if (n.type === "tableRef") {
110183
const tableName = normalizeTableName(n.table ?? n.name)
@@ -181,14 +254,196 @@ function extractTablesFromAst(ast: unknown): TableRef[] {
181254
visit(ast)
182255
}
183256

184-
return tables
257+
return { tables, cteColumns }
258+
}
259+
260+
/**
 * Result of scanning the leading WITH clause of a token stream for CTE
 * definitions.
 */
interface CteTokenExtractResult {
  /** Names of the CTEs that were found, in definition order */
  cteNames: string[]
  /** Columns inferred for each CTE, keyed by the CTE name in lowercase */
  cteColumns: Record<string, Array<{ name: string; type: string }>>
  /** Token index where the CTE block ends (first token of the outer query) */
  outerQueryStart: number
}
273+
274+
/**
 * Extract CTE names and column information from the token stream by finding
 * CTE boundaries and parsing each inner SELECT independently.
 *
 * This is used when the full SQL fails to parse (incomplete outer query)
 * but the CTE definitions themselves are complete.
 *
 * Only the shape `WITH name AS ( ... ) [, name AS ( ... )]*` at the very start
 * of the token stream is recognised; scanning stops at the first token that
 * does not fit it.
 *
 * @param fullSql - The SQL text the tokens were produced from. Token offsets
 *   are used to slice each CTE body out of it.
 * @param tokens - Tokens of `fullSql`, in order.
 * @returns The CTE names (quoted identifiers are returned without their
 *   quotes), the inferred columns keyed by lowercase name (an empty array when
 *   nothing could be inferred), and the token index at which scanning of the
 *   CTE block stopped. When the stream does not start with WITH, nothing is
 *   extracted and that index is 0.
 */
function extractCtesFromTokens(
  fullSql: string,
  tokens: IToken[],
): CteTokenExtractResult {
  const cteNames: string[] = []
  const cteColumns: Record<string, { name: string; type: string }[]> = {}
  if (tokens.length === 0 || tokens[0].tokenType.name !== "With")
    return { cteNames, cteColumns, outerQueryStart: 0 }

  const isIdent = (t: IToken | undefined): boolean =>
    !!t &&
    (t.tokenType.name === "Identifier" ||
      t.tokenType.name === "QuotedIdentifier" ||
      IDENTIFIER_KEYWORD_TOKENS.has(t.tokenType.name))

  let i = 1 // skip WITH
  while (i < tokens.length) {
    // Expect: cteName AS (
    if (!isIdent(tokens[i])) break
    const nameToken = tokens[i]
    // Strip the surrounding quotes from quoted identifiers so the registered
    // name agrees with the unquoted name findCteContainingCursor reports.
    // Without this, a quoted CTE can never be matched for self-reference
    // removal.
    const cteName =
      nameToken.tokenType.name === "QuotedIdentifier"
        ? nameToken.image.slice(1, -1)
        : nameToken.image
    i++
    if (tokens[i]?.tokenType.name !== "As") break
    i++
    if (tokens[i]?.tokenType.name !== "LParen") break
    const innerStart = i + 1
    i++

    // Find matching RParen
    let depth = 1
    while (i < tokens.length && depth > 0) {
      if (tokens[i].tokenType.name === "LParen") depth++
      if (tokens[i].tokenType.name === "RParen") depth--
      i++
    }

    // Always register the CTE name, even if we can't extract columns
    cteNames.push(cteName)
    // Initialize with empty array; will be populated below if columns are found
    cteColumns[cteName.toLowerCase()] = []

    if (depth === 0) {
      // tokens[innerStart..i-2] is the inner SELECT body
      const innerEnd = i - 1 // RParen index
      if (innerStart < innerEnd) {
        const lastInner = tokens[innerEnd - 1]
        const innerSql = fullSql.substring(
          tokens[innerStart].startOffset,
          lastInner.startOffset + lastInner.image.length,
        )
        try {
          const { cst } = parseRaw(innerSql)
          const ast = visitor.visit(cst) as Statement[]
          if (ast && ast.length > 0) {
            const stmt = ast[0] as unknown as Record<string, unknown>
            if (Array.isArray(stmt.columns)) {
              const cols = extractCteColumnNames(stmt.columns)
              if (cols.length > 0) {
                cteColumns[cteName.toLowerCase()] = cols
              }
            }
          }
        } catch {
          // Inner SELECT parse failed, skip this CTE's columns
        }
      }
    } else {
      // Unclosed paren — CTE body is incomplete, can't extract columns
      break
    }

    // After RParen, expect Comma (another CTE) or SELECT/INSERT/UPDATE
    if (tokens[i]?.tokenType.name === "Comma") {
      i++ // next CTE
      continue
    }
    break
  }

  return { cteNames, cteColumns, outerQueryStart: i }
}
360+
361+
// =============================================================================
362+
// CTE cursor detection
363+
// =============================================================================
364+
365+
/**
 * Describes the CTE whose body contains the cursor, as returned by
 * findCteContainingCursor.
 */
interface CteBodyContext {
  /** Unquoted name of the CTE the cursor is in */
  name: string
  /** Index of the first token of the CTE body, i.e. the token after LParen */
  bodyTokenStart: number
  /** Index of the closing RParen token, or tokens.length when it is unclosed */
  bodyTokenEnd: number
}
373+
374+
/**
 * Locate the CTE definition whose parenthesised body contains the cursor.
 *
 * Walks the `WITH name AS ( ... ) [, name AS ( ... )]*` prefix of the token
 * stream one definition at a time. The cursor counts as inside a body when it
 * lies after the opening paren and before the end of the closing paren; for a
 * body whose paren is never closed, any position after the opening paren
 * counts.
 *
 * @param tokens - Tokens of the full query, in order.
 * @param cursorOffset - Cursor position as a character offset into the query.
 * @returns The CTE name (quotes stripped for quoted identifiers) and the token
 *   range of its body, or null when the stream does not start with WITH, the
 *   prefix does not fit the expected shape, or the cursor is in no CTE body.
 */
function findCteContainingCursor(
  tokens: IToken[],
  cursorOffset: number,
): CteBodyContext | null {
  // Token-type name at an index, or undefined past the end of the stream.
  const kindAt = (idx: number): string | undefined =>
    tokens[idx]?.tokenType.name

  if (kindAt(0) !== "With") return null

  const canBeName = (kind: string | undefined): boolean =>
    kind !== undefined &&
    (kind === "Identifier" ||
      kind === "QuotedIdentifier" ||
      IDENTIFIER_KEYWORD_TOKENS.has(kind))

  let pos = 1 // first token after WITH
  while (pos < tokens.length) {
    // Each definition must read: <name> AS (
    const nameKind = kindAt(pos)
    if (!canBeName(nameKind)) return null
    const image = tokens[pos].image
    const name = nameKind === "QuotedIdentifier" ? image.slice(1, -1) : image

    if (kindAt(pos + 1) !== "As") return null
    if (kindAt(pos + 2) !== "LParen") return null
    const openOffset = tokens[pos + 2].startOffset
    const bodyTokenStart = pos + 3

    // Advance to just past the paren that balances the opening one.
    let open = 1
    let scan = bodyTokenStart
    for (; scan < tokens.length && open > 0; scan++) {
      const kind = kindAt(scan)
      if (kind === "LParen") open++
      else if (kind === "RParen") open--
    }

    const cursorAfterOpen = cursorOffset > openOffset

    if (open !== 0) {
      // Ran out of tokens before the body closed: everything after the
      // opening paren belongs to this CTE, and no further CTE can follow.
      return cursorAfterOpen
        ? { name, bodyTokenStart, bodyTokenEnd: tokens.length }
        : null
    }

    // scan sits one past the closing paren.
    const closeIdx = scan - 1
    const closeToken = tokens[closeIdx]
    const closeEndOffset = closeToken.startOffset + closeToken.image.length
    if (cursorAfterOpen && cursorOffset < closeEndOffset) {
      return { name, bodyTokenStart, bodyTokenEnd: closeIdx }
    }

    // Only a comma continues the CTE list; anything else starts the outer query.
    if (kindAt(scan) !== "Comma") return null
    pos = scan + 1
  }

  return null
}
186441

187442
/**
188443
* Try to extract tables by parsing the query.
189444
* If parsing fails, try to extract from tokens.
190445
*/
191-
function extractTables(fullSql: string, tokens: IToken[]): TableRef[] {
446+
function extractTables(fullSql: string, tokens: IToken[]): ExtractResult {
192447
// First, try to parse and extract from AST
193448
try {
194449
const { cst } = parseRaw(fullSql)
@@ -240,7 +495,17 @@ function extractTables(fullSql: string, tokens: IToken[]): TableRef[] {
240495
return { name: parts[parts.length - 1], nextIndex: i }
241496
}
242497

243-
for (let i = 0; i < tokens.length; i++) {
498+
// Extract CTE names and columns from the token stream. CTE definitions are
499+
// usually complete even when the outer query is incomplete, so parse each
500+
// inner SELECT independently.
501+
const { cteNames, cteColumns, outerQueryStart } = extractCtesFromTokens(
502+
fullSql,
503+
tokens,
504+
)
505+
506+
// Scan for FROM/JOIN table references only in the outer query (after CTEs).
507+
// This avoids leaking tables referenced inside CTE bodies into the outer scope.
508+
for (let i = outerQueryStart; i < tokens.length; i++) {
244509
if (!TABLE_PREFIX_TOKENS.has(tokens[i].tokenType.name)) continue
245510

246511
const tableNameResult = readQualifiedName(i + 1)
@@ -263,8 +528,11 @@ function extractTables(fullSql: string, tokens: IToken[]): TableRef[] {
263528
// Continue from where we consumed table/alias to avoid duplicate captures.
264529
i = alias ? aliasStart : tableNameResult.nextIndex - 1
265530
}
531+
for (const name of cteNames) {
532+
tables.push({ table: name })
533+
}
266534

267-
return tables
535+
return { tables, cteColumns }
268536
}
269537

270538
// =============================================================================
@@ -460,6 +728,7 @@ export function getContentAssist(
460728
return {
461729
nextTokenTypes: [],
462730
tablesInScope: [],
731+
cteColumns: {},
463732
tokensBefore: [],
464733
isMidWord: true,
465734
lexErrors: [],
@@ -496,8 +765,53 @@ export function getContentAssist(
496765
// This can happen with malformed input
497766
}
498767

499-
// Extract tables from the full query (reuses fullTokens from above)
500-
const tablesInScope = extractTables(fullSql, fullTokens)
768+
// Extract tables and CTE columns from the full query (reuses fullTokens from above)
769+
const { tables: tablesInScope, cteColumns } = extractTables(
770+
fullSql,
771+
fullTokens,
772+
)
773+
774+
// If cursor is inside a CTE body, exclude the CTE itself from scope
775+
// to prevent self-reference (e.g., "WITH x AS (SELECT |)" shouldn't
776+
// suggest x's own columns). Also extract tables from the CTE body so
777+
// columns from the inner FROM/JOIN are available.
778+
const cursorCte = findCteContainingCursor(fullTokens, cursorOffset)
779+
if (cursorCte) {
780+
const cteNameLower = cursorCte.name.toLowerCase()
781+
// Remove self-reference from tablesInScope
782+
for (let j = tablesInScope.length - 1; j >= 0; j--) {
783+
if (tablesInScope[j].table.toLowerCase() === cteNameLower) {
784+
tablesInScope.splice(j, 1)
785+
}
786+
}
787+
// Remove self-reference from cteColumns
788+
delete cteColumns[cteNameLower]
789+
790+
// Extract tables from the CTE body tokens so inner FROM/JOIN tables
791+
// are available for column scoping.
792+
const BODY_TABLE_PREFIXES = new Set(["From", "Join", "Update", "Into"])
793+
const seen = new Set(tablesInScope.map((t) => t.table.toLowerCase()))
794+
for (let j = cursorCte.bodyTokenStart; j < cursorCte.bodyTokenEnd; j++) {
795+
if (!BODY_TABLE_PREFIXES.has(fullTokens[j].tokenType.name)) continue
796+
const next = fullTokens[j + 1]
797+
if (
798+
next &&
799+
(next.tokenType.name === "Identifier" ||
800+
next.tokenType.name === "QuotedIdentifier" ||
801+
IDENTIFIER_KEYWORD_TOKENS.has(next.tokenType.name))
802+
) {
803+
const tableName =
804+
next.tokenType.name === "QuotedIdentifier"
805+
? next.image.slice(1, -1)
806+
: next.image
807+
const lower = tableName.toLowerCase()
808+
if (!seen.has(lower)) {
809+
seen.add(lower)
810+
tablesInScope.push({ table: tableName })
811+
}
812+
}
813+
}
814+
}
501815

502816
if (tablesInScope.length === 0) {
503817
const inferred = inferTableFromQualifiedRef(tokens, isMidWord)
@@ -507,6 +821,7 @@ export function getContentAssist(
507821
return {
508822
nextTokenTypes,
509823
tablesInScope,
824+
cteColumns,
510825
tokensBefore: tokens,
511826
isMidWord,
512827
lexErrors: lexResult.errors,

0 commit comments

Comments
 (0)