|
1 | | - |
2 | | -import unittest |
3 | | -from unittest.mock import MagicMock |
| 1 | +from unittest.mock import MagicMock, ANY |
| 2 | +import pytest |
4 | 3 | from nl2sql.pipeline.nodes.aggregator.node import AggregatorNode |
5 | | -from nl2sql.pipeline.nodes.aggregator.schemas import AggregatedResponse |
6 | 4 | from nl2sql.pipeline.state import GraphState |
7 | | -from nl2sql.common.errors import ErrorCode |
| 5 | +from nl2sql.common.errors import PipelineError, ErrorSeverity, ErrorCode |
8 | 6 |
|
9 | | -class TestAggregatorNode(unittest.TestCase): |
10 | | - def setUp(self): |
11 | | - self.mock_llm = MagicMock() |
12 | | - self.node = AggregatorNode(self.mock_llm) |
13 | | - self.node.chain = self.mock_llm # Bypass prompt chain |
| 7 | +class TestAggregatorNode: |
| 8 | + """Unit tests for the AggregatorNode.""" |
14 | 9 |
|
15 | | - def test_fast_path(self): |
16 | | - """Test direct data return for single result with output_mode='data'.""" |
17 | | - state = GraphState( |
18 | | - user_query="q", |
19 | | - intermediate_results=[{"id": 1, "val": "A"}], |
20 | | - output_mode="data" |
21 | | - ) |
22 | | - |
23 | | - result = self.node(state) |
24 | | - |
25 | | - self.assertEqual(result["final_answer"], {"id": 1, "val": "A"}) |
26 | | - self.assertIn("Fast path", result["reasoning"][0]["content"]) |
| 10 | + @pytest.fixture |
| 11 | + def mock_llm(self): |
| 12 | + """Creates a mock LLM runnable.""" |
| 13 | + mock = MagicMock() |
| 14 | + mock.invoke.return_value = MagicMock(summary="Summary", content="Content", format_type="text") |
| 15 | + return mock |
27 | 16 |
|
28 | | - def test_slow_path_llm(self): |
29 | | - """Test LLM synthesis for complex or multiple results.""" |
30 | | - state = GraphState( |
31 | | - user_query="q", |
32 | | - intermediate_results=[{"id": 1}], |
33 | | - output_mode="synthesis" |
| 17 | + def test_sanitization_of_sensitive_errors(self, mock_llm): |
| 18 | + """Verifies that sensitive database errors are sanitized before reaching the LLM.""" |
| 19 | + # Setup |
| 20 | + node = AggregatorNode(llm=mock_llm) |
| 21 | + |
| 22 | + # Create a state with a sensitive DB error |
| 23 | + secret_message = "Syntax error in table 'confidential_users', column 'ssn'" |
| 24 | + error = PipelineError( |
| 25 | + node="executor", |
| 26 | + message=secret_message, |
| 27 | + severity=ErrorSeverity.ERROR, |
| 28 | + error_code=ErrorCode.DB_EXECUTION_ERROR, |
| 29 | + stack_trace="Traceback: ..." |
34 | 30 | ) |
35 | 31 |
|
36 | | - # Mock LLM Response |
37 | | - self.mock_llm.invoke.return_value = AggregatedResponse( |
38 | | - summary="Found 1 item.", |
39 | | - content="Item details...", |
40 | | - format_type="text" |
| 32 | + state = GraphState( |
| 33 | + user_query="SELECT * FROM users", |
| 34 | + intermediate_results=[], |
| 35 | + errors=[error] |
41 | 36 | ) |
| 37 | + |
| 38 | + # Mock the chain directly to avoid LangChain internals complexity |
| 39 | + node.chain = MagicMock() |
| 40 | + node.chain.invoke.return_value = MagicMock(summary="Safe", content="Safe", format_type="text") |
| 41 | + |
| 42 | + # Execute internal method that prepares prompt |
| 43 | + node._display_result_with_llm(state) |
| 44 | + |
| 45 | + # Verify CHAIN invoke arguments (input dict) |
| 46 | + call_args = node.chain.invoke.call_args[0][0] |
| 47 | + intermediate_res_str = call_args["intermediate_results"] |
42 | 48 |
|
43 | | - result = self.node(state) |
| 49 | + # Assertion: Secrets should NOT be present |
| 50 | + assert "confidential_users" not in intermediate_res_str |
| 51 | + assert "ssn" not in intermediate_res_str |
44 | 52 |
|
45 | | - self.assertIn("Found 1 item", result["final_answer"]) |
46 | | - self.assertIn("LLM Aggregation used", result["reasoning"][0]["content"]) |
| 53 | + # Assertion: Safe message SHOULD be present |
| 54 | + assert "An internal database error occurred" in intermediate_res_str |
47 | 55 |
|
48 | | - def test_slow_path_multiple_results(self): |
49 | | - """Test that multiple results force LLM path even if mode is data (actually, does it?). |
50 | | - Code says: if len(results) == 1 and not errors and mode == data -> Fast. |
51 | | - So 2 results -> Slow. |
52 | | - """ |
53 | | - state = GraphState( |
54 | | - user_query="q", |
55 | | - intermediate_results=[{"a": 1}, {"b": 2}], |
56 | | - output_mode="data" |
| 56 | + def test_pass_through_of_safe_errors(self, mock_llm): |
| 57 | + """Verifies that non-sensitive errors are passed through safely.""" |
| 58 | + node = AggregatorNode(llm=mock_llm) |
| 59 | + node.chain = MagicMock() |
| 60 | + node.chain.invoke.return_value = MagicMock(summary="Safe", content="Safe", format_type="text") |
| 61 | + |
| 62 | + safe_message = "I could not find a plan for this query." |
| 63 | + error = PipelineError( |
| 64 | + node="planner", |
| 65 | + message=safe_message, |
| 66 | + severity=ErrorSeverity.WARNING, |
| 67 | + error_code=ErrorCode.PLANNING_FAILURE |
57 | 68 | ) |
58 | 69 |
|
59 | | - self.mock_llm.invoke.return_value = AggregatedResponse(summary="Multi", content="Multi", format_type="text") |
60 | | - |
61 | | - result = self.node(state) |
| 70 | + state = GraphState( |
| 71 | + user_query="Help", |
| 72 | + intermediate_results=[], |
| 73 | + errors=[error] |
| 74 | + ) |
62 | 75 |
|
63 | | - self.assertIn("LLM Aggregation used", result["reasoning"][0]["content"]) |
64 | | - |
65 | | - def test_error_handling(self): |
66 | | - """Test that exception behaves correctly.""" |
67 | | - self.mock_llm.invoke.side_effect = Exception("Boom") |
| 76 | + node._display_result_with_llm(state) |
68 | 77 |
|
69 | | - state = GraphState(user_query="q", intermediate_results=[], output_mode="synthesis") |
70 | | - result = self.node(state) |
| 78 | + call_args = node.chain.invoke.call_args[0][0] |
| 79 | + intermediate_res_str = call_args["intermediate_results"] |
71 | 80 |
|
72 | | - self.assertEqual(len(result["errors"]), 1) |
73 | | - self.assertEqual(result["errors"][0].error_code, ErrorCode.AGGREGATOR_FAILED) |
74 | | - self.assertIn("Boom", result["final_answer"]) |
75 | | - |
76 | | -if __name__ == "__main__": |
77 | | - unittest.main() |
| 81 | + assert safe_message in intermediate_res_str |
0 commit comments