diff --git a/examples/sushi/models/customers.sql b/examples/sushi/models/customers.sql index 24b3aaa208..f91f1166e8 100644 --- a/examples/sushi/models/customers.sql +++ b/examples/sushi/models/customers.sql @@ -37,8 +37,9 @@ LEFT JOIN ( @ADD_ONE(1) AS another_column, FROM current_marketing_outer ) - SELECT * FROM current_marketing + SELECT current_marketing.* FROM current_marketing WHERE current_marketing.customer_id != 100 ) AS m ON o.customer_id = m.customer_id LEFT JOIN raw.demographics AS d ON o.customer_id = d.customer_id +WHERE sushi.orders.customer_id > 0 \ No newline at end of file diff --git a/sqlmesh/lsp/reference.py b/sqlmesh/lsp/reference.py index ac4d5374b6..96db4dc63d 100644 --- a/sqlmesh/lsp/reference.py +++ b/sqlmesh/lsp/reference.py @@ -208,6 +208,17 @@ def get_model_definitions_for_a_path( target_range=target_range, ) ) + + column_references = _process_column_references( + scope=scope, + reference_name=table.name, + read_file=read_file, + referenced_model_uri=document_uri, + description="", + reference_type="cte", + cte_target_range=target_range, + ) + references.extend(column_references) continue # For non-CTE tables, process as before (external model references) @@ -276,6 +287,19 @@ def get_model_definitions_for_a_path( target_range=yaml_target_range, ) ) + + column_references = _process_column_references( + scope=scope, + reference_name=normalized_reference_name, + read_file=read_file, + referenced_model_uri=referenced_model_uri, + description=description, + yaml_target_range=yaml_target_range, + reference_type="external_model", + default_catalog=lint_context.context.default_catalog, + dialect=dialect, + ) + references.extend(column_references) else: references.append( LSPModelReference( @@ -288,6 +312,18 @@ def get_model_definitions_for_a_path( ) ) + column_references = _process_column_references( + scope=scope, + reference_name=normalized_reference_name, + read_file=read_file, + referenced_model_uri=referenced_model_uri, + 
description=description, + reference_type="model", + default_catalog=lint_context.context.default_catalog, + dialect=dialect, + ) + references.extend(column_references) + return references @@ -735,6 +771,104 @@ def _position_within_range(position: Position, range: Range) -> bool: ) +def _get_column_table_range(column: exp.Column, read_file: t.List[str]) -> Range: + """ + Get the range for a column's table reference, handling both simple and qualified table names. + + Args: + column: The column expression + read_file: The file content as list of lines + + Returns: + The Range covering the table reference in the column + """ + + table_parts = column.parts[:-1] + + start_range = TokenPositionDetails.from_meta(table_parts[0].meta).to_range(read_file) + end_range = TokenPositionDetails.from_meta(table_parts[-1].meta).to_range(read_file) + + return Range( + start=to_lsp_position(start_range.start), + end=to_lsp_position(end_range.end), + ) + + +def _process_column_references( + scope: t.Any, + reference_name: str, + read_file: t.List[str], + referenced_model_uri: URI, + description: t.Optional[str] = None, + yaml_target_range: t.Optional[Range] = None, + reference_type: t.Literal["model", "external_model", "cte"] = "model", + default_catalog: t.Optional[str] = None, + dialect: t.Optional[str] = None, + cte_target_range: t.Optional[Range] = None, +) -> t.List[Reference]: + """ + Process column references for a given table and create appropriate reference objects. 
+ + Args: + scope: The SQL scope to search for columns + reference_name: The full reference name (may include database/catalog) + read_file: The file content as list of lines + referenced_model_uri: URI of the referenced model + description: Markdown description for the reference + yaml_target_range: Target range for external models (YAML files) + reference_type: Type of reference - "model", "external_model", or "cte" + default_catalog: Default catalog for normalization + dialect: SQL dialect for normalization + cte_target_range: Target range for CTE references + + Returns: + List of table references for column usages + """ + + references: t.List[Reference] = [] + for column in scope.find_all(exp.Column): + if column.table: + if reference_type == "cte": + if column.table == reference_name: + table_range = _get_column_table_range(column, read_file) + references.append( + LSPCteReference( + uri=referenced_model_uri.value, + range=table_range, + target_range=cte_target_range, + ) + ) + else: + table_parts = [part.sql(dialect) for part in column.parts[:-1]] + table_ref = ".".join(table_parts) + normalized_reference_name = normalize_model_name( + table_ref, + default_catalog=default_catalog, + dialect=dialect, + ) + if normalized_reference_name == reference_name: + table_range = _get_column_table_range(column, read_file) + if reference_type == "external_model": + references.append( + LSPExternalModelReference( + uri=referenced_model_uri.value, + range=table_range, + markdown_description=description, + target_range=yaml_target_range, + ) + ) + else: + references.append( + LSPModelReference( + uri=referenced_model_uri.value, + range=table_range, + markdown_description=description, + ) + ) + + return references + + def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]: """ Find the range of a specific model block in a YAML file. 
diff --git a/tests/lsp/test_reference_cte_find_all.py b/tests/lsp/test_reference_cte_find_all.py index d57c996a6a..6a29224e75 100644 --- a/tests/lsp/test_reference_cte_find_all.py +++ b/tests/lsp/test_reference_cte_find_all.py @@ -21,14 +21,13 @@ def test_cte_find_all_references(): # Test finding all references of "current_marketing" ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)") - assert len(ranges) == 2 + assert len(ranges) == 2 # regex finds 2 occurrences (definition and FROM clause) # Click on the CTE definition position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4) references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position) - - # Should find both the definition and the usage - assert len(references) == 2 + # Should find the definition, FROM clause, and column prefix usages + assert len(references) == 4 # definition + FROM + 2 column prefix uses assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references) reference_ranges = [ref.range for ref in references] @@ -46,7 +45,7 @@ def test_cte_find_all_references(): references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position) # Should find the same references - assert len(references) == 2 + assert len(references) == 4 # definition + FROM + 2 column prefix uses assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references) reference_ranges = [ref.range for ref in references] diff --git a/tests/lsp/test_reference_model_column_prefix.py b/tests/lsp/test_reference_model_column_prefix.py new file mode 100644 index 0000000000..88be689810 --- /dev/null +++ b/tests/lsp/test_reference_model_column_prefix.py @@ -0,0 +1,212 @@ +from pathlib import Path + +from lsprotocol.types import Position +from sqlmesh.cli.example_project import init_example_project +from sqlmesh.core.context import Context +from sqlmesh.lsp.context import LSPContext, ModelTarget +from 
sqlmesh.lsp.reference import get_all_references +from sqlmesh.lsp.uri import URI +from tests.lsp.test_reference_cte import find_ranges_from_regex + + +def test_model_reference_with_column_prefix(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Test finding references for "sushi.orders" + ranges = find_ranges_from_regex(read_file, r"sushi\.orders") + + # Click on the table reference in FROM clause (should be the second occurrence) + from_clause_range = None + for r in ranges: + line_content = read_file[r.start.line].strip() + if "FROM" in line_content: + from_clause_range = r + break + + assert from_clause_range is not None, "Should find FROM clause with sushi.orders" + + position = Position( + line=from_clause_range.start.line, character=from_clause_range.start.character + 6 + ) + + model_refs = get_all_references(lsp_context, URI.from_path(sushi_customers_path), position) + + assert len(model_refs) >= 7 + + # Verify that we have the FROM clause reference + assert any(ref.range.start.line == from_clause_range.start.line for ref in model_refs), ( + "Should find FROM clause reference" + ) + + +def test_column_prefix_references_are_found(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Find all occurrences of sushi.orders in the file + ranges = find_ranges_from_regex(read_file, r"sushi\.orders") + + # Should find exactly 2: FROM clause and WHERE clause with column prefix + assert 
len(ranges) == 2, f"Expected 2 occurrences of 'sushi.orders', found {len(ranges)}" + + # Verify we have the expected lines + line_contents = [read_file[r.start.line].strip() for r in ranges] + + # Should find FROM clause + assert any("FROM sushi.orders" in content for content in line_contents), ( + "Should find FROM clause with sushi.orders" + ) + + # Should find customer_id in WHERE clause with column prefix + assert any("WHERE sushi.orders.customer_id" in content for content in line_contents), ( + "Should find WHERE clause with sushi.orders.customer_id" + ) + + +def test_quoted_uppercase_table_and_column_references(tmp_path: Path): + # Initialize example project in temporary directory with case sensitive normalization + init_example_project(tmp_path, dialect="duckdb,normalization_strategy=case_sensitive") + + # Create a model with quoted uppercase schema and table names + models_dir = tmp_path / "models" + + # First, create the uppercase SUSHI.orders model that will be referenced + uppercase_orders_path = models_dir / "uppercase_orders.sql" + uppercase_orders_path.write_text("""MODEL ( + name "SUSHI".orders, + kind FULL +); + +SELECT + 1 as id, + 1 as customer_id, + 1 as item_id""") + + # Second, create the lowercase sushi.orders model that will be referenced + lowercase_orders_path = models_dir / "lowercase_orders.sql" + lowercase_orders_path.write_text("""MODEL ( + name sushi.orders, + kind FULL +); + +SELECT + 1 as id, + 1 as customer_id""") + + quoted_test_path = models_dir / "quoted_test.sql" + quoted_test_path.write_text("""MODEL ( + name "SUSHI".quoted_test, + kind FULL +); + +SELECT + o.id, + o.customer_id, + o.item_id, + c.item_id as c_item_id +FROM "SUSHI".orders AS o, sushi.orders as c +WHERE "SUSHI".orders.id > 0 + AND "SUSHI".orders.customer_id IS NOT NULL + AND sushi.orders.id > 0""") + + context = Context(paths=tmp_path) + lsp_context = LSPContext(context) + + # Find the quoted test model + quoted_test_model_path = next( + path + for path, info in 
lsp_context.map.items() + if isinstance(info, ModelTarget) and '"SUSHI".quoted_test' in info.names + ) + + with open(quoted_test_model_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Test finding references for quoted "SUSHI".orders + ranges = find_ranges_from_regex(read_file, r'"SUSHI"\.orders') + + # Should find 3 occurrences: FROM clause and 2 in WHERE clause with column prefix + assert len(ranges) == 3, f"Expected 3 occurrences of '\"SUSHI\".orders', found {len(ranges)}" + + # Click on the table reference in FROM clause + from_clause_range = None + for r in ranges: + line_content = read_file[r.start.line].strip() + if "FROM" in line_content: + from_clause_range = r + break + + assert from_clause_range is not None, 'Should find FROM clause with "SUSHI".orders' + + position = Position( + line=from_clause_range.start.line, character=from_clause_range.start.character + 5 + ) + + model_refs = get_all_references(lsp_context, URI.from_path(quoted_test_model_path), position) + + # Should find only references to "SUSHI".orders (4 total: FROM clause and 2 column prefixes in WHERE, plus presumably the model definition itself) + # The lowercase sushi.orders should NOT be included if case sensitivity is working + assert len(model_refs) == 4, ( + f'Expected exactly 4 references for "SUSHI".orders, found {len(model_refs)}' + ) + + # Verify that we have all 3 references + ref_lines = [ref.range.start.line for ref in model_refs] + + # Count how many references are on each line + from_line = from_clause_range.start.line + where_lines = [r.start.line for r in ranges if r.start.line != from_line] + + assert from_line in ref_lines, "Should find FROM clause reference" + for where_line in where_lines: + assert where_line in ref_lines, f"Should find WHERE clause reference on line {where_line}" + + # Now test that lowercase sushi.orders references are separate + lowercase_ranges = find_ranges_from_regex(read_file, r"sushi\.orders") + + # Should find 2 occurrences: FROM clause and 1 in WHERE clause +
assert len(lowercase_ranges) == 2, ( + f"Expected 2 occurrences of 'sushi.orders', found {len(lowercase_ranges)}" + ) + + # Click on the lowercase table reference + lowercase_from_range = None + for r in lowercase_ranges: + line_content = read_file[r.start.line].strip() + if "FROM" in line_content: + lowercase_from_range = r + break + + assert lowercase_from_range is not None, "Should find FROM clause with sushi.orders" + + lowercase_position = Position( + line=lowercase_from_range.start.line, character=lowercase_from_range.start.character + 5 + ) + + lowercase_refs = get_all_references( + lsp_context, URI.from_path(quoted_test_model_path), lowercase_position + ) + + # Should find only references to lowercase sushi.orders, NOT the uppercase ones + assert len(lowercase_refs) == 3, ( + f"Expected exactly 3 references for sushi.orders, found {len(lowercase_refs)}" + ) diff --git a/tests/lsp/test_reference_model_find_all.py b/tests/lsp/test_reference_model_find_all.py index c494ef7af3..7bb998150f 100644 --- a/tests/lsp/test_reference_model_find_all.py +++ b/tests/lsp/test_reference_model_find_all.py @@ -30,8 +30,8 @@ def test_find_references_for_model_usages(): # Click on the model reference position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 6) references = get_model_find_all_references(lsp_context, URI.from_path(customers_path), position) - assert len(references) >= 6, ( - f"Expected at least 6 references to sushi.orders, found {len(references)}" + assert len(references) >= 7, ( + f"Expected at least 7 references to sushi.orders (including column prefix), found {len(references)}" ) # Verify expected files are present @@ -50,15 +50,18 @@ def test_find_references_for_model_usages(): ) # Verify exact ranges for each reference pattern + # Note: customers file has multiple references due to column prefix support expected_ranges = { - "orders": (0, 0, 0, 0), # the start for the model itself - "customers": (30, 7, 30, 19), -
"waiter_revenue_by_day": (19, 5, 19, 17), - "customer_revenue_lifetime": (38, 7, 38, 19), - "customer_revenue_by_day": (33, 5, 33, 17), - "latest_order": (12, 5, 12, 17), + "orders": [(0, 0, 0, 0)], # the start for the model itself + "customers": [(30, 7, 30, 19), (44, 6, 44, 18)], # FROM clause and WHERE clause + "waiter_revenue_by_day": [(19, 5, 19, 17)], + "customer_revenue_lifetime": [(38, 7, 38, 19)], + "customer_revenue_by_day": [(33, 5, 33, 17)], + "latest_order": [(12, 5, 12, 17)], } + # Group references by file pattern + refs_by_pattern = {} for ref in references: matched_pattern = None for pattern in expected_patterns: @@ -66,28 +69,43 @@ def test_find_references_for_model_usages(): matched_pattern = pattern break - assert matched_pattern is not None, ( - f"Reference URI {ref.uri} doesn't match any expected pattern" - ) + if matched_pattern: + if matched_pattern not in refs_by_pattern: + refs_by_pattern[matched_pattern] = [] + refs_by_pattern[matched_pattern].append(ref) - # Get expected range for this model - expected_start_line, expected_start_char, expected_end_line, expected_end_char = ( - expected_ranges[matched_pattern] - ) + # Verify each pattern has the expected references + for pattern, expected_range_list in expected_ranges.items(): + assert pattern in refs_by_pattern, f"Missing references for pattern '{pattern}'" - # Assert exact range match - assert ref.range.start.line == expected_start_line, ( - f"Expected {matched_pattern} reference start line {expected_start_line}, found {ref.range.start.line}" - ) - assert ref.range.start.character == expected_start_char, ( - f"Expected {matched_pattern} reference start character {expected_start_char}, found {ref.range.start.character}" + actual_refs = refs_by_pattern[pattern] + assert len(actual_refs) == len(expected_range_list), ( + f"Expected {len(expected_range_list)} references for {pattern}, found {len(actual_refs)}" ) - assert ref.range.end.line == expected_end_line, ( - f"Expected 
{matched_pattern} reference end line {expected_end_line}, found {ref.range.end.line}" - ) - assert ref.range.end.character == expected_end_char, ( - f"Expected {matched_pattern} reference end character {expected_end_char}, found {ref.range.end.character}" + + # Sort both actual and expected by line number for consistent comparison + actual_refs_sorted = sorted( + actual_refs, key=lambda r: (r.range.start.line, r.range.start.character) ) + expected_sorted = sorted(expected_range_list, key=lambda r: (r[0], r[1])) + + for i, (ref, expected_range) in enumerate(zip(actual_refs_sorted, expected_sorted)): + expected_start_line, expected_start_char, expected_end_line, expected_end_char = ( + expected_range + ) + + assert ref.range.start.line == expected_start_line, ( + f"Expected {pattern} reference #{i + 1} start line {expected_start_line}, found {ref.range.start.line}" + ) + assert ref.range.start.character == expected_start_char, ( + f"Expected {pattern} reference #{i + 1} start character {expected_start_char}, found {ref.range.start.character}" + ) + assert ref.range.end.line == expected_end_line, ( + f"Expected {pattern} reference #{i + 1} end line {expected_end_line}, found {ref.range.end.line}" + ) + assert ref.range.end.character == expected_end_char, ( + f"Expected {pattern} reference #{i + 1} end character {expected_end_char}, found {ref.range.end.character}" + ) def test_find_references_for_marketing_model(): diff --git a/tests/test_forking.py b/tests/test_forking.py index 1cd50d9dec..d11379a158 100644 --- a/tests/test_forking.py +++ b/tests/test_forking.py @@ -55,10 +55,14 @@ def test_parallel_load(assert_exp_eq, mocker): "current_marketing"."status" AS "status", "current_marketing"."another_column" AS "another_column" FROM "current_marketing" AS "current_marketing" + WHERE + "current_marketing"."customer_id" <> 100 ) AS "m" ON "m"."customer_id" = "o"."customer_id" LEFT JOIN "memory"."raw"."demographics" AS "d" ON "d"."customer_id" = "o"."customer_id" + WHERE + 
"o"."customer_id" > 0 """, ) diff --git a/tests/web/test_lineage.py b/tests/web/test_lineage.py index 1ed40431ef..0cffd3ecc3 100644 --- a/tests/web/test_lineage.py +++ b/tests/web/test_lineage.py @@ -47,7 +47,7 @@ def test_get_lineage(client: TestClient, web_sushi_context: Context) -> None: "customer_id": { "expression": 'CAST("o"."customer_id" AS INT) AS "customer_id" /* this comment should not be registered */', "models": {'"memory"."sushi"."orders"': ["customer_id"]}, - "source": '''WITH "current_marketing_outer" AS ( + "source": """WITH "current_marketing_outer" AS ( SELECT "marketing"."customer_id" AS "customer_id", "marketing"."status" AS "status" @@ -71,10 +71,14 @@ def test_get_lineage(client: TestClient, web_sushi_context: Context) -> None: "current_marketing"."status" AS "status", "current_marketing"."another_column" AS "another_column" FROM "current_marketing" AS "current_marketing" + WHERE + "current_marketing"."customer_id" <> 100 ) AS "m" ON "m"."customer_id" = "o"."customer_id" LEFT JOIN "memory"."raw"."demographics" AS "d" - ON "d"."customer_id" = "o"."customer_id"''', + ON "d"."customer_id" = "o"."customer_id" +WHERE + "o"."customer_id" > 0""", } }, '"memory"."sushi"."orders"': {