Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/sushi/models/customers.sql
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ LEFT JOIN (
@ADD_ONE(1) AS another_column,
FROM current_marketing_outer
)
SELECT * FROM current_marketing
SELECT current_marketing.* FROM current_marketing WHERE current_marketing.customer_id != 100
) AS m
ON o.customer_id = m.customer_id
LEFT JOIN raw.demographics AS d
ON o.customer_id = d.customer_id
WHERE sushi.orders.customer_id > 0
134 changes: 134 additions & 0 deletions sqlmesh/lsp/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,17 @@ def get_model_definitions_for_a_path(
target_range=target_range,
)
)

column_references = _process_column_references(
scope=scope,
reference_name=table.name,
read_file=read_file,
referenced_model_uri=document_uri,
description="",
reference_type="cte",
cte_target_range=target_range,
)
references.extend(column_references)
continue

# For non-CTE tables, process as before (external model references)
Expand Down Expand Up @@ -276,6 +287,19 @@ def get_model_definitions_for_a_path(
target_range=yaml_target_range,
)
)

column_references = _process_column_references(
scope=scope,
reference_name=normalized_reference_name,
read_file=read_file,
referenced_model_uri=referenced_model_uri,
description=description,
yaml_target_range=yaml_target_range,
reference_type="external_model",
default_catalog=lint_context.context.default_catalog,
dialect=dialect,
)
references.extend(column_references)
else:
references.append(
LSPModelReference(
Expand All @@ -288,6 +312,18 @@ def get_model_definitions_for_a_path(
)
)

column_references = _process_column_references(
scope=scope,
reference_name=normalized_reference_name,
read_file=read_file,
referenced_model_uri=referenced_model_uri,
description=description,
reference_type="model",
default_catalog=lint_context.context.default_catalog,
dialect=dialect,
)
references.extend(column_references)

return references


Expand Down Expand Up @@ -735,6 +771,104 @@ def _position_within_range(position: Position, range: Range) -> bool:
)


def _get_column_table_range(column: exp.Column, read_file: t.List[str]) -> Range:
"""
Get the range for a column's table reference, handling both simple and qualified table names.

Args:
column: The column expression
read_file: The file content as list of lines

Returns:
The Range covering the table reference in the column
"""

table_parts = column.parts[:-1]

start_range = TokenPositionDetails.from_meta(table_parts[0].meta).to_range(read_file)
end_range = TokenPositionDetails.from_meta(table_parts[-1].meta).to_range(read_file)

return Range(
start=to_lsp_position(start_range.start),
end=to_lsp_position(end_range.end),
)


def _process_column_references(
scope: t.Any,
reference_name: str,
read_file: t.List[str],
referenced_model_uri: URI,
description: t.Optional[str] = None,
yaml_target_range: t.Optional[Range] = None,
reference_type: t.Literal["model", "external_model", "cte"] = "model",
default_catalog: t.Optional[str] = None,
dialect: t.Optional[str] = None,
cte_target_range: t.Optional[Range] = None,
) -> t.List[Reference]:
"""
Process column references for a given table and create appropriate reference objects.

Args:
scope: The SQL scope to search for columns
reference_name: The full reference name (may include database/catalog)
read_file: The file content as list of lines
referenced_model_uri: URI of the referenced model
description: Markdown description for the reference
yaml_target_range: Target range for external models (YAML files)
reference_type: Type of reference - "model", "external_model", or "cte"
default_catalog: Default catalog for normalization
dialect: SQL dialect for normalization
cte_target_range: Target range for CTE references

Returns:
List of table references for column usages
"""

references: t.List[Reference] = []
for column in scope.find_all(exp.Column):
if column.table:
if reference_type == "cte":
if column.table == reference_name:
table_range = _get_column_table_range(column, read_file)
references.append(
LSPCteReference(
uri=referenced_model_uri.value,
range=table_range,
target_range=cte_target_range,
)
)
else:
table_parts = [part.sql(dialect) for part in column.parts[:-1]]
table_ref = ".".join(table_parts)
normalized_reference_name = normalize_model_name(
table_ref,
default_catalog=default_catalog,
dialect=dialect,
)
if normalized_reference_name == reference_name:
table_range = _get_column_table_range(column, read_file)
if reference_type == "external_model":
references.append(
LSPExternalModelReference(
uri=referenced_model_uri.value,
range=table_range,
markdown_description=description,
target_range=yaml_target_range,
)
)
else:
references.append(
LSPModelReference(
uri=referenced_model_uri.value,
range=table_range,
markdown_description=description,
)
)

return references


def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
"""
Find the range of a specific model block in a YAML file.
Expand Down
9 changes: 4 additions & 5 deletions tests/lsp/test_reference_cte_find_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@ def test_cte_find_all_references():

# Test finding all references of "current_marketing"
ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)")
assert len(ranges) == 2
assert len(ranges) == 2 # regex finds 2 occurrences (definition and FROM clause)

# Click on the CTE definition
position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4)
references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position)

# Should find both the definition and the usage
assert len(references) == 2
# Should find the definition, FROM clause, and column prefix usages
assert len(references) == 4 # definition + FROM + 2 column prefix uses
assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references)

reference_ranges = [ref.range for ref in references]
Expand All @@ -46,7 +45,7 @@ def test_cte_find_all_references():
references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position)

# Should find the same references
assert len(references) == 2
assert len(references) == 4 # definition + FROM + 2 column prefix uses
assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references)

reference_ranges = [ref.range for ref in references]
Expand Down
Loading