Skip to content

Commit 00d30f7

Browse files
Feat(lsp): Add rename functionality for CTEs in vscode (#4718)
1 parent e593492 commit 00d30f7

5 files changed

Lines changed: 736 additions & 0 deletions

File tree

sqlmesh/lsp/main.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
get_references,
5353
get_all_references,
5454
)
55+
from sqlmesh.lsp.rename import prepare_rename, rename_symbol, get_document_highlights
5556
from sqlmesh.lsp.uri import URI
5657
from web.server.api.endpoints.lineage import column_lineage, model_lineage
5758
from web.server.api.endpoints.models import get_models
@@ -435,6 +436,59 @@ def find_references(
435436
ls.show_message(f"Error getting locations: {e}", types.MessageType.Error)
436437
return None
437438

439+
@self.server.feature(types.TEXT_DOCUMENT_PREPARE_RENAME)
440+
def prepare_rename_handler(
441+
ls: LanguageServer, params: types.PrepareRenameParams
442+
) -> t.Optional[types.PrepareRenameResult]:
443+
"""Prepare for rename operation by checking if the symbol can be renamed."""
444+
try:
445+
uri = URI(params.text_document.uri)
446+
self._ensure_context_for_document(uri)
447+
if self.lsp_context is None:
448+
raise RuntimeError(f"No context found for document: {uri}")
449+
450+
result = prepare_rename(self.lsp_context, uri, params.position)
451+
return result
452+
except Exception as e:
453+
ls.log_trace(f"Error preparing rename: {e}")
454+
return None
455+
456+
@self.server.feature(types.TEXT_DOCUMENT_RENAME)
457+
def rename_handler(
458+
ls: LanguageServer, params: types.RenameParams
459+
) -> t.Optional[types.WorkspaceEdit]:
460+
"""Perform rename operation on the symbol at the given position."""
461+
try:
462+
uri = URI(params.text_document.uri)
463+
self._ensure_context_for_document(uri)
464+
if self.lsp_context is None:
465+
raise RuntimeError(f"No context found for document: {uri}")
466+
467+
workspace_edit = rename_symbol(
468+
self.lsp_context, uri, params.position, params.new_name
469+
)
470+
return workspace_edit
471+
except Exception as e:
472+
ls.show_message(f"Error performing rename: {e}", types.MessageType.Error)
473+
return None
474+
475+
@self.server.feature(types.TEXT_DOCUMENT_DOCUMENT_HIGHLIGHT)
476+
def document_highlight_handler(
477+
ls: LanguageServer, params: types.DocumentHighlightParams
478+
) -> t.Optional[t.List[types.DocumentHighlight]]:
479+
"""Highlight all occurrences of the symbol at the given position."""
480+
try:
481+
uri = URI(params.text_document.uri)
482+
self._ensure_context_for_document(uri)
483+
if self.lsp_context is None:
484+
raise RuntimeError(f"No context found for document: {uri}")
485+
486+
highlights = get_document_highlights(self.lsp_context, uri, params.position)
487+
return highlights
488+
except Exception as e:
489+
ls.log_trace(f"Error getting document highlights: {e}")
490+
return None
491+
438492
@self.server.feature(types.TEXT_DOCUMENT_DIAGNOSTIC)
439493
def diagnostic(
440494
ls: LanguageServer, params: types.DocumentDiagnosticParams

sqlmesh/lsp/rename.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import typing as t
2+
from lsprotocol.types import (
3+
Position,
4+
TextEdit,
5+
WorkspaceEdit,
6+
PrepareRenameResult_Type1,
7+
DocumentHighlight,
8+
DocumentHighlightKind,
9+
)
10+
11+
from sqlmesh.lsp.context import LSPContext
12+
from sqlmesh.lsp.reference import (
13+
_position_within_range,
14+
get_cte_references,
15+
LSPCteReference,
16+
)
17+
from sqlmesh.lsp.uri import URI
18+
19+
20+
def prepare_rename(
21+
lsp_context: LSPContext, document_uri: URI, position: Position
22+
) -> t.Optional[PrepareRenameResult_Type1]:
23+
"""
24+
Prepare for rename operation by checking if the symbol at the position can be renamed.
25+
26+
Args:
27+
lsp_context: The LSP context
28+
document_uri: The URI of the document
29+
position: The position in the document
30+
31+
Returns:
32+
PrepareRenameResult if the symbol can be renamed, None otherwise
33+
"""
34+
# Check if there's a CTE at this position
35+
cte_references = get_cte_references(lsp_context, document_uri, position)
36+
if cte_references:
37+
# Find the target CTE definition to get its range
38+
target_range = None
39+
for ref in cte_references:
40+
# Check if cursor is on a CTE usage
41+
if _position_within_range(position, ref.range):
42+
target_range = ref.target_range
43+
break
44+
# Check if cursor is on the CTE definition
45+
elif _position_within_range(position, ref.target_range):
46+
target_range = ref.target_range
47+
break
48+
if target_range:
49+
return PrepareRenameResult_Type1(range=target_range, placeholder="cte_name")
50+
51+
# For now, only CTEs are supported
52+
return None
53+
54+
55+
def rename_symbol(
56+
lsp_context: LSPContext, document_uri: URI, position: Position, new_name: str
57+
) -> t.Optional[WorkspaceEdit]:
58+
"""
59+
Perform rename operation on the symbol at the given position.
60+
61+
Args:
62+
lsp_context: The LSP context
63+
document_uri: The URI of the document
64+
position: The position in the document
65+
new_name: The new name for the symbol
66+
67+
Returns:
68+
WorkspaceEdit with the changes, or None if no symbol to rename
69+
"""
70+
# Check if there's a CTE at this position
71+
cte_references = get_cte_references(lsp_context, document_uri, position)
72+
if cte_references:
73+
return _rename_cte(cte_references, new_name)
74+
75+
# For now, only CTEs are supported
76+
return None
77+
78+
79+
def _rename_cte(cte_references: t.List[LSPCteReference], new_name: str) -> WorkspaceEdit:
80+
"""
81+
Create a WorkspaceEdit for renaming a CTE.
82+
83+
Args:
84+
cte_references: List of CTE references (definition and usages)
85+
new_name: The new name for the CTE
86+
87+
Returns:
88+
WorkspaceEdit with the text edits for renaming the CTE
89+
"""
90+
changes: t.Dict[str, t.List[TextEdit]] = {}
91+
92+
for ref in cte_references:
93+
uri = ref.uri
94+
if uri not in changes:
95+
changes[uri] = []
96+
97+
# Create a text edit for this reference
98+
text_edit = TextEdit(range=ref.range, new_text=new_name)
99+
changes[uri].append(text_edit)
100+
101+
return WorkspaceEdit(changes=changes)
102+
103+
104+
def get_document_highlights(
105+
lsp_context: LSPContext, document_uri: URI, position: Position
106+
) -> t.Optional[t.List[DocumentHighlight]]:
107+
"""
108+
Get document highlights for all occurrences of the symbol at the given position.
109+
110+
This function finds all occurrences of a symbol (CTE) within the current document
111+
and returns them as DocumentHighlight objects for "Change All Occurrences" feature.
112+
113+
Args:
114+
lsp_context: The LSP context
115+
document_uri: The URI of the document
116+
position: The position in the document to find highlights for
117+
118+
Returns:
119+
List of DocumentHighlight objects or None if no symbol found
120+
"""
121+
# Check if there's a CTE at this position
122+
cte_references = get_cte_references(lsp_context, document_uri, position)
123+
if cte_references:
124+
highlights = []
125+
for ref in cte_references:
126+
# Determine the highlight kind based on whether it's a definition or usage
127+
kind = (
128+
DocumentHighlightKind.Write
129+
if ref.range == ref.target_range
130+
else DocumentHighlightKind.Read
131+
)
132+
133+
highlights.append(DocumentHighlight(range=ref.range, kind=kind))
134+
return highlights
135+
136+
# For now, only CTEs are supported
137+
return None
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from lsprotocol.types import Position, DocumentHighlightKind
2+
3+
from sqlmesh.core.context import Context
4+
from sqlmesh.lsp.context import LSPContext, ModelTarget
5+
from sqlmesh.lsp.rename import get_document_highlights
6+
from sqlmesh.lsp.uri import URI
7+
from tests.lsp.test_reference_cte import find_ranges_from_regex
8+
9+
10+
def test_get_document_highlights_cte():
11+
context = Context(paths=["examples/sushi"])
12+
lsp_context = LSPContext(context)
13+
14+
# Use the existing customers.sql model which has CTEs
15+
sushi_customers_path = next(
16+
path
17+
for path, info in lsp_context.map.items()
18+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
19+
)
20+
21+
with open(sushi_customers_path, "r", encoding="utf-8") as file:
22+
read_file = file.readlines()
23+
24+
test_uri = URI.from_path(sushi_customers_path)
25+
26+
# Find the ranges for "current_marketing" CTE (not outer one)
27+
ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)")
28+
assert len(ranges) >= 2 # Should have definition + usage
29+
30+
# Test highlighting CTE definition - position on "current_marketing" definition
31+
position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4)
32+
highlights = get_document_highlights(lsp_context, test_uri, position)
33+
34+
assert highlights is not None
35+
assert len(highlights) >= 2 # Definition + at least 1 usage
36+
37+
# Check that we have both definition (Write) and usage (Read) highlights
38+
highlight_kinds = [h.kind for h in highlights]
39+
assert DocumentHighlightKind.Write in highlight_kinds # CTE definition
40+
assert DocumentHighlightKind.Read in highlight_kinds # CTE usage
41+
42+
# Test highlighting CTE usage - position on "current_marketing" usage
43+
position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4)
44+
highlights = get_document_highlights(lsp_context, test_uri, position)
45+
46+
assert highlights is not None
47+
assert len(highlights) >= 2 # Should find the same references
48+
49+
50+
def test_get_document_highlights_no_symbol():
51+
context = Context(paths=["examples/sushi"])
52+
lsp_context = LSPContext(context)
53+
54+
# Use the existing customers.sql model
55+
sushi_customers_path = next(
56+
path
57+
for path, info in lsp_context.map.items()
58+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
59+
)
60+
61+
test_uri = URI.from_path(sushi_customers_path)
62+
63+
# Test position not on any CTE symbol - just on a random keyword
64+
position = Position(line=5, character=5)
65+
highlights = get_document_highlights(lsp_context, test_uri, position)
66+
67+
assert highlights is None
68+
69+
70+
def test_get_document_highlights_multiple_ctes():
71+
context = Context(paths=["examples/sushi"])
72+
lsp_context = LSPContext(context)
73+
74+
# Use the existing customers.sql model which has both outer and inner CTEs
75+
sushi_customers_path = next(
76+
path
77+
for path, info in lsp_context.map.items()
78+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
79+
)
80+
81+
with open(sushi_customers_path, "r", encoding="utf-8") as file:
82+
read_file = file.readlines()
83+
84+
test_uri = URI.from_path(sushi_customers_path)
85+
86+
# Test the outer CTE - "current_marketing_outer"
87+
outer_ranges = find_ranges_from_regex(read_file, r"current_marketing_outer")
88+
assert len(outer_ranges) >= 2 # Should have definition + usage
89+
90+
# Test highlighting outer CTE - should only highlight that CTE
91+
position = Position(
92+
line=outer_ranges[0].start.line, character=outer_ranges[0].start.character + 4
93+
)
94+
highlights = get_document_highlights(lsp_context, test_uri, position)
95+
96+
assert highlights is not None
97+
assert len(highlights) == len(outer_ranges) # Should match all occurrences of outer CTE
98+
99+
# Test the inner CTE - "current_marketing" (not outer)
100+
inner_ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)")
101+
assert len(inner_ranges) >= 2 # Should have definition + usage
102+
103+
# Test highlighting inner CTE - should only highlight that CTE, not the outer one
104+
position = Position(
105+
line=inner_ranges[0].start.line, character=inner_ranges[0].start.character + 4
106+
)
107+
highlights = get_document_highlights(lsp_context, test_uri, position)
108+
109+
# This should return the column usages as well
110+
assert highlights is not None
111+
assert len(highlights) == 4

0 commit comments

Comments
 (0)