|
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
| 5 | +import pytest |
5 | 6 | from haystack.utils import Secret |
6 | 7 |
|
7 | 8 | import haystack_integrations.components.retrievers.sqlalchemy.sqlalchemy_table_retriever as module |
@@ -74,25 +75,54 @@ def test_from_dict(self, monkeypatch): |
74 | 75 |
|
75 | 76 |
|
76 | 77 | class TestSQLAlchemyTableRetrieverRun: |
| 78 | + @pytest.fixture() |
| 79 | + def retriever_with_data(self): |
| 80 | + init_sql = ( |
| 81 | + "CREATE TABLE employees (" |
| 82 | + " id INTEGER PRIMARY KEY," |
| 83 | + " name TEXT NOT NULL," |
| 84 | + " department TEXT NOT NULL," |
| 85 | + " salary INTEGER NOT NULL" |
| 86 | + ");" |
| 87 | + "INSERT INTO employees VALUES (1, 'Alice', 'Engineering', 95000);" |
| 88 | + "INSERT INTO employees VALUES (2, 'Bob', 'Marketing', 72000);" |
| 89 | + "INSERT INTO employees VALUES (3, 'Carol', 'Engineering', 88000)" |
| 90 | + ) |
| 91 | + retriever = SQLAlchemyTableRetriever( |
| 92 | + drivername="sqlite", |
| 93 | + database=":memory:", |
| 94 | + init_script=init_sql, |
| 95 | + ) |
| 96 | + retriever.warm_up() |
| 97 | + return retriever |
| 98 | + |
77 | 99 | def test_run_empty_query(self): |
78 | 100 | retriever = SQLAlchemyTableRetriever(drivername="sqlite", database=":memory:") |
79 | 101 | result = retriever.run(query="") |
80 | 102 | assert result["error"] == "empty query" |
81 | 103 | assert result["dataframe"].empty |
82 | 104 | assert result["table"] == "" |
83 | 105 |
|
84 | | - def test_run_returns_dataframe(self): |
85 | | - retriever = SQLAlchemyTableRetriever(drivername="sqlite", database=":memory:") |
86 | | - retriever.warm_up() |
87 | | - result = retriever.run(query="SELECT 1 AS value") |
88 | | - assert not result["dataframe"].empty |
| 106 | + def test_run_returns_dataframe(self, retriever_with_data): |
| 107 | + result = retriever_with_data.run(query="SELECT * FROM employees ORDER BY id") |
| 108 | + df = result["dataframe"] |
89 | 109 | assert result["error"] == "" |
90 | | - |
91 | | - def test_run_returns_markdown(self): |
92 | | - retriever = SQLAlchemyTableRetriever(drivername="sqlite", database=":memory:") |
93 | | - retriever.warm_up() |
94 | | - result = retriever.run(query="SELECT 1 AS value") |
95 | | - assert "|" in result["table"] |
| 110 | + assert list(df.columns) == ["id", "name", "department", "salary"] |
| 111 | + assert len(df) == 3 |
| 112 | + assert df.iloc[0]["name"] == "Alice" |
| 113 | + assert df.iloc[1]["department"] == "Marketing" |
| 114 | + assert df.iloc[2]["salary"] == 88000 |
| 115 | + |
| 116 | + def test_run_returns_markdown(self, retriever_with_data): |
| 117 | + result = retriever_with_data.run(query="SELECT * FROM employees ORDER BY id") |
| 118 | + expected = ( |
| 119 | + "| id | name | department | salary |\n" |
| 120 | + "| --- | --- | --- | --- |\n" |
| 121 | + "| 1 | Alice | Engineering | 95000 |\n" |
| 122 | + "| 2 | Bob | Marketing | 72000 |\n" |
| 123 | + "| 3 | Carol | Engineering | 88000 |" |
| 124 | + ) |
| 125 | + assert result["table"] == expected |
96 | 126 |
|
97 | 127 | def test_run_sql_error(self): |
98 | 128 | retriever = SQLAlchemyTableRetriever(drivername="sqlite", database=":memory:") |
|
0 commit comments