|
10 | 10 | """ |
11 | 11 |
|
12 | 12 | import time |
| 13 | +from collections.abc import Callable |
13 | 14 | from pathlib import Path |
14 | 15 | from typing import Any |
15 | 16 |
|
16 | 17 | import pytest |
17 | 18 |
|
18 | 19 | from sqlspec.core import SQL |
19 | | -from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError |
| 20 | +from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, SQLStatementNotFoundError |
20 | 21 | from sqlspec.loader import ( |
21 | 22 | NamedStatement, |
22 | 23 | SQLFile, |
@@ -510,6 +511,44 @@ def test_get_query_text_not_found() -> None: |
510 | 511 | loader.get_query_text("nonexistent") |
511 | 512 |
|
512 | 513 |
|
| 514 | +@pytest.mark.parametrize("accessor", [SQLFileLoader.get_sql, SQLFileLoader.get_query_text]) |
| 515 | +def test_missing_statement_empty_loader_message(accessor: Callable[[SQLFileLoader, str], object]) -> None: |
| 516 | + """Missing statements in an empty loader should return a bounded message.""" |
| 517 | + loader = SQLFileLoader() |
| 518 | + |
| 519 | + with pytest.raises(SQLStatementNotFoundError) as exc_info: |
| 520 | + accessor(loader, "missing-secret") |
| 521 | + |
| 522 | + exc = exc_info.value |
| 523 | + assert str(exc) == "SQL statement 'missing-secret' not found. No SQL statements are loaded." |
| 524 | + assert exc.name == "missing-secret" |
| 525 | + assert exc.normalized_name == "missing_secret" |
| 526 | + assert exc.query_count == 0 |
| 527 | + |
| 528 | + |
| 529 | +@pytest.mark.parametrize("accessor", [SQLFileLoader.get_sql, SQLFileLoader.get_query_text]) |
| 530 | +def test_missing_statement_loaded_registry_message_does_not_leak_names( |
| 531 | + accessor: Callable[[SQLFileLoader, str], object], |
| 532 | +) -> None: |
| 533 | + """Missing statements should not dump the loaded statement registry.""" |
| 534 | + loader = SQLFileLoader() |
| 535 | + loaded_query_names = [f"tenant_{index}_private_query" for index in range(20)] |
| 536 | + for query_name in loaded_query_names: |
| 537 | + loader.add_named_sql(query_name, "SELECT 1") |
| 538 | + |
| 539 | + with pytest.raises(SQLStatementNotFoundError) as exc_info: |
| 540 | + accessor(loader, "missing-secret") |
| 541 | + |
| 542 | + message = str(exc_info.value) |
| 543 | + assert message == ( |
| 544 | + "SQL statement 'missing-secret' not found. 20 SQL statements are loaded. " |
| 545 | + "Use list_queries() to inspect available statement names." |
| 546 | + ) |
| 547 | + assert "Available statements" not in message |
| 548 | + for query_name in loaded_query_names: |
| 549 | + assert query_name not in message |
| 550 | + |
| 551 | + |
513 | 552 | def test_clear_cache() -> None: |
514 | 553 | """Test clearing loader cache.""" |
515 | 554 | loader = SQLFileLoader() |
@@ -582,7 +621,7 @@ def test_get_sql_not_found() -> None: |
582 | 621 | with pytest.raises(SQLFileNotFoundError) as exc_info: |
583 | 622 | loader.get_sql("nonexistent") |
584 | 623 |
|
585 | | - assert "Statement 'nonexistent' not found" in str(exc_info.value) |
| 624 | + assert "SQL statement 'nonexistent' not found" in str(exc_info.value) |
586 | 625 |
|
587 | 626 |
|
588 | 627 | def test_get_sql_name_normalization() -> None: |
|
0 commit comments