Skip to content

Commit 31a5d60

Browse files
fix(sqlalchemy): address review comments - init_script as list, add query type validation, remove dead code, expand README
1 parent b73a19a commit 31a5d60

3 files changed

Lines changed: 98 additions & 26 deletions

File tree

integrations/sqlalchemy/README.md

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,73 @@
77

88
---
99

10+
A Haystack integration for querying SQL databases via [SQLAlchemy](https://www.sqlalchemy.org/).
11+
Provides a `SQLAlchemyTableRetriever` component that connects to any SQLAlchemy-supported database,
12+
executes a SQL query, and returns results as a pandas DataFrame and an optional Markdown table.
13+
14+
## Installation
15+
16+
```bash
17+
pip install sqlalchemy-haystack
18+
```
19+
20+
You also need to install the appropriate database driver for your backend:
21+
22+
| Backend | Driver package |
23+
|---------|---------------|
24+
| PostgreSQL | `psycopg2-binary` or `psycopg[binary]` |
25+
| MySQL / MariaDB | `pymysql` or `mysqlclient` |
26+
| SQLite | built-in (no extra package needed) |
27+
| MSSQL | `pyodbc` |
28+
| Oracle | `cx_Oracle` or `oracledb` |
29+
30+
## Usage
31+
32+
```python
33+
from haystack_integrations.components.retrievers.sqlalchemy import SQLAlchemyTableRetriever
34+
35+
# SQLite in-memory example (no driver needed)
36+
retriever = SQLAlchemyTableRetriever(
37+
drivername="sqlite",
38+
database=":memory:",
39+
init_script=[
40+
"CREATE TABLE products (id INTEGER, name TEXT, price REAL)",
41+
"INSERT INTO products VALUES (1, 'Widget', 9.99)",
42+
"INSERT INTO products VALUES (2, 'Gadget', 19.99)",
43+
],
44+
)
45+
retriever.warm_up()
46+
47+
result = retriever.run(query="SELECT * FROM products WHERE price < 15")
48+
print(result["dataframe"])
49+
print(result["table"])
50+
```
51+
52+
For PostgreSQL:
53+
54+
```python
55+
from haystack.utils import Secret
56+
57+
retriever = SQLAlchemyTableRetriever(
58+
drivername="postgresql+psycopg2",
59+
host="localhost",
60+
port=5432,
61+
database="mydb",
62+
username="myuser",
63+
password=Secret.from_env_var("DB_PASSWORD"),
64+
)
65+
```
66+
67+
## Security
68+
69+
This component executes raw SQL queries passed at runtime. Keep the following in mind:
70+
71+
- **Never pass unsanitised user input** directly as a query — this exposes you to SQL injection.
72+
- **Use a read-only database user.** This is the most effective mitigation. Even if a malicious
73+
query is executed, a read-only user cannot modify or delete data.
74+
- **Restrict database permissions** to the minimum required — specific tables and schemas only,
75+
no DDL privileges (no `CREATE`, `DROP`, `ALTER`).
76+
1077
## Contributing
1178

1279
Refer to the general [Contribution Guidelines](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CONTRIBUTING.md).

integrations/sqlalchemy/src/haystack_integrations/components/retrievers/sqlalchemy/sqlalchemy_table_retriever.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
host: str | None = None,
4949
port: int | None = None,
5050
database: str | None = None,
51-
init_script: str | None = None,
51+
init_script: list[str] | None = None,
5252
) -> None:
5353
"""
5454
Initialize SQLAlchemyTableRetriever.
@@ -60,9 +60,9 @@ def __init__(
6060
:param host: Database host.
6161
:param port: Database port.
6262
:param database: Database name or path (e.g., ``":memory:"`` for SQLite in-memory).
63-
:param init_script: Optional SQL statements executed once on ``warm_up()``
64-
(e.g., to create tables or insert seed data). Multiple statements should be
65-
separated by semicolons.
63+
:param init_script: Optional list of SQL statements executed once on ``warm_up()``
64+
(e.g., to create tables or insert seed data). Each statement should be a
65+
separate string in the list.
6666
"""
6767
self.drivername = drivername
6868
self.username = username
@@ -101,7 +101,7 @@ def warm_up(self) -> None:
101101

102102
if self.init_script:
103103
with self._engine.connect() as conn:
104-
for stmt in self.init_script.split(";"):
104+
for stmt in self.init_script:
105105
stripped = stmt.strip()
106106
if stripped:
107107
conn.execute(text(stripped))
@@ -158,16 +158,16 @@ def run(self, query: str) -> dict[str, Any]:
158158
- ``table``: A Markdown-formatted string of the results.
159159
- ``error``: An error message if the query failed, otherwise an empty string.
160160
"""
161+
if not isinstance(query, str):
162+
logger.warning("Query is not a string, returning empty DataFrame")
163+
return {"dataframe": DataFrame(), "table": "", "error": "query is not a string"}
164+
161165
if not query:
162166
return {"dataframe": DataFrame(), "table": "", "error": "empty query"}
163167

164168
if not self._warmed_up:
165169
self.warm_up()
166170

167-
if self._engine is None:
168-
msg = "Engine is not initialized. Call warm_up() first."
169-
raise RuntimeError(msg)
170-
171171
try:
172172
with self._engine.connect() as conn:
173173
result = conn.execute(text(query))

integrations/sqlalchemy/tests/test_sqlalchemy_table_retriever.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ def test_init_all_params(self):
2929
host="localhost",
3030
port=5432,
3131
database="mydb",
32-
init_script="CREATE TABLE t (x INTEGER)",
32+
init_script=["CREATE TABLE t (x INTEGER)"],
3333
)
3434
assert retriever.drivername == "postgresql+psycopg2"
3535
assert retriever.username == "user"
3636
assert retriever.password is password
3737
assert retriever.host == "localhost"
3838
assert retriever.port == 5432
3939
assert retriever.database == "mydb"
40-
assert retriever.init_script == "CREATE TABLE t (x INTEGER)"
40+
assert retriever.init_script == ["CREATE TABLE t (x INTEGER)"]
4141

4242

4343
class TestSQLAlchemyTableRetrieverSerialization:
@@ -77,17 +77,12 @@ def test_from_dict(self, monkeypatch):
7777
class TestSQLAlchemyTableRetrieverRun:
7878
@pytest.fixture()
7979
def retriever_with_data(self):
80-
init_sql = (
81-
"CREATE TABLE employees ("
82-
" id INTEGER PRIMARY KEY,"
83-
" name TEXT NOT NULL,"
84-
" department TEXT NOT NULL,"
85-
" salary INTEGER NOT NULL"
86-
");"
87-
"INSERT INTO employees VALUES (1, 'Alice', 'Engineering', 95000);"
88-
"INSERT INTO employees VALUES (2, 'Bob', 'Marketing', 72000);"
89-
"INSERT INTO employees VALUES (3, 'Carol', 'Engineering', 88000)"
90-
)
80+
init_sql = [
81+
"CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT NOT NULL, department TEXT NOT NULL, salary INTEGER NOT NULL)",
82+
"INSERT INTO employees VALUES (1, 'Alice', 'Engineering', 95000)",
83+
"INSERT INTO employees VALUES (2, 'Bob', 'Marketing', 72000)",
84+
"INSERT INTO employees VALUES (3, 'Carol', 'Engineering', 88000)",
85+
]
9186
retriever = SQLAlchemyTableRetriever(
9287
drivername="sqlite",
9388
database=":memory:",
@@ -96,6 +91,13 @@ def retriever_with_data(self):
9691
retriever.warm_up()
9792
return retriever
9893

94+
def test_run_non_string_query(self):
95+
retriever = SQLAlchemyTableRetriever(drivername="sqlite", database=":memory:")
96+
result = retriever.run(query=123) # type: ignore[arg-type]
97+
assert result["error"] == "query is not a string"
98+
assert result["dataframe"].empty
99+
assert result["table"] == ""
100+
99101
def test_run_empty_query(self):
100102
retriever = SQLAlchemyTableRetriever(drivername="sqlite", database=":memory:")
101103
result = retriever.run(query="")
@@ -132,9 +134,12 @@ def test_run_sql_error(self):
132134
assert result["dataframe"].empty
133135

134136
def test_max_row_limit(self, monkeypatch):
135-
init_sql = (
136-
"CREATE TABLE t (x INTEGER);INSERT INTO t VALUES (1);INSERT INTO t VALUES (2);INSERT INTO t VALUES (3)"
137-
)
137+
init_sql = [
138+
"CREATE TABLE t (x INTEGER)",
139+
"INSERT INTO t VALUES (1)",
140+
"INSERT INTO t VALUES (2)",
141+
"INSERT INTO t VALUES (3)",
142+
]
138143
retriever = SQLAlchemyTableRetriever(
139144
drivername="sqlite",
140145
database=":memory:",
@@ -146,7 +151,7 @@ def test_max_row_limit(self, monkeypatch):
146151
assert len(result["dataframe"]) == 2
147152

148153
def test_warm_up_with_init_script(self):
149-
init_sql = "CREATE TABLE greetings (msg TEXT); INSERT INTO greetings VALUES ('hello')"
154+
init_sql = ["CREATE TABLE greetings (msg TEXT)", "INSERT INTO greetings VALUES ('hello')"]
150155
retriever = SQLAlchemyTableRetriever(
151156
drivername="sqlite",
152157
database=":memory:",

0 commit comments

Comments
 (0)