Skip to content

Commit 7ba0c10

Browse files
authored
fix(vscode): use paths instead of uris (#4448)
1 parent e24c9da commit 7ba0c10

7 files changed

Lines changed: 48 additions & 48 deletions

File tree

sqlmesh/lsp/completions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.
3333
all_models.update(file_info.names)
3434

3535
# Remove models from the current file
36-
if file_uri is not None and file_uri in context.map:
37-
file_info = context.map[file_uri]
36+
path = file_uri.to_path() if file_uri is not None else None
37+
if path is not None and path in context.map:
38+
file_info = context.map[path]
3839
if isinstance(file_info, ModelTarget):
3940
for model in file_info.names:
4041
all_models.discard(model)
@@ -53,8 +54,8 @@ def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) ->
5354
If both a context and a file_uri are provided, returns the keywords
5455
for the dialect of the model that the file belongs to.
5556
"""
56-
if file_uri is not None and context is not None and file_uri in context.map:
57-
file_info = context.map[file_uri]
57+
if file_uri is not None and context is not None and file_uri.to_path() in context.map:
58+
file_info = context.map[file_uri.to_path()]
5859

5960
# Handle ModelInfo objects
6061
if isinstance(file_info, ModelTarget) and file_info.names:

sqlmesh/lsp/context.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from dataclasses import dataclass
2+
from pathlib import Path
23
from sqlmesh.core.context import Context
34
import typing as t
45

5-
from sqlmesh.lsp.uri import URI
6-
76

87
@dataclass
98
class ModelTarget:
@@ -29,24 +28,24 @@ def __init__(self, context: Context) -> None:
2928
self.context = context
3029

3130
# Add models to the map
32-
model_map: t.Dict[URI, ModelTarget] = {}
31+
model_map: t.Dict[Path, ModelTarget] = {}
3332
for model in context.models.values():
3433
if model._path is not None:
35-
uri = URI.from_path(model._path)
34+
uri = model._path
3635
if uri in model_map:
3736
model_map[uri].names.append(model.name)
3837
else:
3938
model_map[uri] = ModelTarget(names=[model.name])
4039

4140
# Add standalone audits to the map
42-
audit_map: t.Dict[URI, AuditTarget] = {}
41+
audit_map: t.Dict[Path, AuditTarget] = {}
4342
for audit in context.standalone_audits.values():
4443
if audit._path is not None:
45-
uri = URI.from_path(audit._path)
44+
uri = audit._path
4645
if uri not in audit_map:
4746
audit_map[uri] = AuditTarget(name=audit.name)
4847

49-
self.map: t.Dict[URI, t.Union[ModelTarget, AuditTarget]] = {
48+
self.map: t.Dict[Path, t.Union[ModelTarget, AuditTarget]] = {
5049
**model_map,
5150
**audit_map,
5251
}

sqlmesh/lsp/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> Non
116116
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
117117
)
118118
return
119-
models = context.map[uri]
119+
models = context.map[uri.to_path()]
120120
if models is None:
121121
return
122122
if not isinstance(models, ModelTarget):
@@ -134,7 +134,7 @@ def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> Non
134134
def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
135135
uri = URI(params.text_document.uri)
136136
context = self._context_get_or_load(uri)
137-
models = context.map[uri]
137+
models = context.map[uri.to_path()]
138138
if models is None:
139139
return
140140
if not isinstance(models, ModelTarget):
@@ -152,7 +152,7 @@ def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) ->
152152
def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None:
153153
uri = URI(params.text_document.uri)
154154
context = self._context_get_or_load(uri)
155-
models = context.map[uri]
155+
models = context.map[uri.to_path()]
156156
if models is None:
157157
return
158158
if not isinstance(models, ModelTarget):

sqlmesh/lsp/reference.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,14 @@ def get_model_definitions_for_a_path(
8888
- Try get_model before normalization
8989
- Match to models that the model refers to
9090
"""
91-
# Ensure the path is a sql file
92-
if document_uri.to_path().suffix != ".sql":
91+
path = document_uri.to_path()
92+
if path.suffix != ".sql":
9393
return []
9494
# Get the file info from the context map
95-
if document_uri not in lint_context.map:
95+
if path not in lint_context.map:
9696
return []
9797

98-
file_info = lint_context.map[document_uri]
99-
98+
file_info = lint_context.map[path]
10099
# Process based on whether it's a model or standalone audit
101100
if isinstance(file_info, ModelTarget):
102101
# It's a model

tests/lsp/test_completions.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from sqlmesh.core.context import Context
44
from sqlmesh.lsp.completions import get_keywords_from_tokenizer, get_sql_completions
55
from sqlmesh.lsp.context import LSPContext
6+
from sqlmesh.lsp.uri import URI
67

78

89
TOKENIZER_KEYWORDS = set(Tokenizer.KEYWORDS.keys())
@@ -36,9 +37,7 @@ def test_get_sql_completions_with_context_and_file_uri():
3637
context = Context(paths=["examples/sushi"])
3738
lsp_context = LSPContext(context)
3839

39-
file_uri = next(
40-
key for key in lsp_context.map.keys() if str(key.to_path()).endswith("active_customers.sql")
41-
)
42-
completions = get_sql_completions(lsp_context, file_uri)
40+
file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
41+
completions = get_sql_completions(lsp_context, URI.from_path(file_uri))
4342
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
4443
assert "sushi.active_customers" not in completions.models

tests/lsp/test_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_lsp_context():
1414

1515
# find one model in the map
1616
active_customers_key = next(
17-
key for key in lsp_context.map.keys() if str(key.to_path()).endswith("active_customers.sql")
17+
key for key in lsp_context.map.keys() if key.name == "active_customers.sql"
1818
)
1919

2020
# Check that the value is a ModelInfo with the expected model name

tests/lsp/test_reference.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from sqlmesh.core.context import Context
44
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
55
from sqlmesh.lsp.reference import get_model_definitions_for_a_path, by_position
6+
from sqlmesh.lsp.uri import URI
67

78

89
@pytest.mark.fast
@@ -11,21 +12,22 @@ def test_reference() -> None:
1112
lsp_context = LSPContext(context)
1213

1314
# Find model URIs
14-
active_customers_uri = next(
15-
uri
16-
for uri, info in lsp_context.map.items()
15+
active_customers_path = next(
16+
path
17+
for path, info in lsp_context.map.items()
1718
if isinstance(info, ModelTarget) and "sushi.active_customers" in info.names
1819
)
19-
sushi_customers_uri = next(
20-
uri
21-
for uri, info in lsp_context.map.items()
20+
sushi_customers_path = next(
21+
path
22+
for path, info in lsp_context.map.items()
2223
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
2324
)
2425

26+
active_customers_uri = URI.from_path(active_customers_path)
2527
references = get_model_definitions_for_a_path(lsp_context, active_customers_uri)
2628

2729
assert len(references) == 1
28-
assert references[0].uri == sushi_customers_uri.value
30+
assert URI(references[0].uri) == URI.from_path(sushi_customers_path)
2931

3032
# Check that the reference in the correct range is sushi.customers
3133
path = active_customers_uri.to_path()
@@ -42,17 +44,18 @@ def test_reference_with_alias() -> None:
4244
context = Context(paths=["examples/sushi"])
4345
lsp_context = LSPContext(context)
4446

45-
waiter_revenue_by_day_uri = next(
47+
waiter_revenue_by_day_path = next(
4648
uri
4749
for uri, info in lsp_context.map.items()
4850
if isinstance(info, ModelTarget) and "sushi.waiter_revenue_by_day" in info.names
4951
)
5052

51-
references = get_model_definitions_for_a_path(lsp_context, waiter_revenue_by_day_uri)
53+
references = get_model_definitions_for_a_path(
54+
lsp_context, URI.from_path(waiter_revenue_by_day_path)
55+
)
5256
assert len(references) == 3
5357

54-
path = waiter_revenue_by_day_uri.to_path()
55-
with open(path, "r") as file:
58+
with open(waiter_revenue_by_day_path, "r") as file:
5659
read_file = file.readlines()
5760

5861
assert references[0].uri.endswith("orders.py")
@@ -70,27 +73,25 @@ def test_standalone_audit_reference() -> None:
7073
lsp_context = LSPContext(context)
7174

7275
# Find the standalone audit URI
73-
audit_uri = next(
76+
audit_path = next(
7477
uri
7578
for uri, info in lsp_context.map.items()
7679
if isinstance(info, AuditTarget) and info.name == "assert_item_price_above_zero"
7780
)
78-
7981
# Find the items model URI
80-
items_uri = next(
82+
items_path = next(
8183
uri
8284
for uri, info in lsp_context.map.items()
8385
if isinstance(info, ModelTarget) and "sushi.items" in info.names
8486
)
8587

86-
references = get_model_definitions_for_a_path(lsp_context, audit_uri)
88+
references = get_model_definitions_for_a_path(lsp_context, URI.from_path(audit_path))
8789

8890
assert len(references) == 1
89-
assert references[0].uri == items_uri.value
91+
assert references[0].uri == URI.from_path(items_path).value
9092

9193
# Check that the reference in the correct range is sushi.items
92-
path = audit_uri.to_path()
93-
with open(path, "r") as file:
94+
with open(audit_path, "r") as file:
9495
read_file = file.readlines()
9596
referenced_text = get_string_from_range(read_file, references[0].range)
9697
assert referenced_text == "sushi.items"
@@ -123,19 +124,20 @@ def test_filter_references_by_position() -> None:
123124
lsp_context = LSPContext(context)
124125

125126
# Use a file with multiple references (waiter_revenue_by_day)
126-
waiter_revenue_by_day_uri = next(
127-
uri
128-
for uri, info in lsp_context.map.items()
127+
waiter_revenue_by_day_path = next(
128+
path
129+
for path, info in lsp_context.map.items()
129130
if isinstance(info, ModelTarget) and "sushi.waiter_revenue_by_day" in info.names
130131
)
131132

132133
# Get all references in the file
133-
all_references = get_model_definitions_for_a_path(lsp_context, waiter_revenue_by_day_uri)
134+
all_references = get_model_definitions_for_a_path(
135+
lsp_context, URI.from_path(waiter_revenue_by_day_path)
136+
)
134137
assert len(all_references) == 3
135138

136139
# Get file contents to locate positions for testing
137-
path = waiter_revenue_by_day_uri.to_path()
138-
with open(path, "r") as file:
140+
with open(waiter_revenue_by_day_path, "r") as file:
139141
read_file = file.readlines()
140142

141143
# Test positions for each reference

0 commit comments

Comments
 (0)