Skip to content

Commit b73a19a

Browse files
committed
improving the tests
1 parent 1b1259a commit b73a19a

1 file changed

Lines changed: 41 additions & 11 deletions

File tree

integrations/sqlalchemy/tests/test_sqlalchemy_table_retriever.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import pytest
56
from haystack.utils import Secret
67

78
import haystack_integrations.components.retrievers.sqlalchemy.sqlalchemy_table_retriever as module
@@ -74,25 +75,54 @@ def test_from_dict(self, monkeypatch):
7475

7576

7677
class TestSQLAlchemyTableRetrieverRun:
78+
@pytest.fixture()
79+
def retriever_with_data(self):
80+
init_sql = (
81+
"CREATE TABLE employees ("
82+
" id INTEGER PRIMARY KEY,"
83+
" name TEXT NOT NULL,"
84+
" department TEXT NOT NULL,"
85+
" salary INTEGER NOT NULL"
86+
");"
87+
"INSERT INTO employees VALUES (1, 'Alice', 'Engineering', 95000);"
88+
"INSERT INTO employees VALUES (2, 'Bob', 'Marketing', 72000);"
89+
"INSERT INTO employees VALUES (3, 'Carol', 'Engineering', 88000)"
90+
)
91+
retriever = SQLAlchemyTableRetriever(
92+
drivername="sqlite",
93+
database=":memory:",
94+
init_script=init_sql,
95+
)
96+
retriever.warm_up()
97+
return retriever
98+
7799
def test_run_empty_query(self):
78100
retriever = SQLAlchemyTableRetriever(drivername="sqlite", database=":memory:")
79101
result = retriever.run(query="")
80102
assert result["error"] == "empty query"
81103
assert result["dataframe"].empty
82104
assert result["table"] == ""
83105

84-
def test_run_returns_dataframe(self):
85-
retriever = SQLAlchemyTableRetriever(drivername="sqlite", database=":memory:")
86-
retriever.warm_up()
87-
result = retriever.run(query="SELECT 1 AS value")
88-
assert not result["dataframe"].empty
106+
def test_run_returns_dataframe(self, retriever_with_data):
107+
result = retriever_with_data.run(query="SELECT * FROM employees ORDER BY id")
108+
df = result["dataframe"]
89109
assert result["error"] == ""
90-
91-
def test_run_returns_markdown(self):
92-
retriever = SQLAlchemyTableRetriever(drivername="sqlite", database=":memory:")
93-
retriever.warm_up()
94-
result = retriever.run(query="SELECT 1 AS value")
95-
assert "|" in result["table"]
110+
assert list(df.columns) == ["id", "name", "department", "salary"]
111+
assert len(df) == 3
112+
assert df.iloc[0]["name"] == "Alice"
113+
assert df.iloc[1]["department"] == "Marketing"
114+
assert df.iloc[2]["salary"] == 88000
115+
116+
def test_run_returns_markdown(self, retriever_with_data):
117+
result = retriever_with_data.run(query="SELECT * FROM employees ORDER BY id")
118+
expected = (
119+
"| id | name | department | salary |\n"
120+
"| --- | --- | --- | --- |\n"
121+
"| 1 | Alice | Engineering | 95000 |\n"
122+
"| 2 | Bob | Marketing | 72000 |\n"
123+
"| 3 | Carol | Engineering | 88000 |"
124+
)
125+
assert result["table"] == expected
96126

97127
def test_run_sql_error(self):
98128
retriever = SQLAlchemyTableRetriever(drivername="sqlite", database=":memory:")

0 commit comments

Comments
 (0)