Skip to content

Commit ed55ef1

Browse files
fix: PgVectorDocumentStore _treat_meta_field and comparison functions now return Composed - string escaping done by psycopg (#2964)
* _treat_meta_field and comparison functions now return Composed * using Composable in function sigs, base class of both SQL and Composed * removing all str safeguards + adding back a comment * replacing test_treat_meta_field by a more readable version * improve filter test readability using SQL string assertions with psycopg.adapt.Transformer * formatting
1 parent ed939e4 commit ed55ef1

2 files changed

Lines changed: 95 additions & 79 deletions

File tree

integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py

Lines changed: 51 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import re
65
from datetime import datetime
76
from itertools import chain
87
from typing import Any, Literal
98

109
from haystack.errors import FilterError
11-
from psycopg.sql import SQL, Composed
10+
from psycopg.sql import SQL, Composable, Composed, Identifier
11+
from psycopg.sql import Literal as SQLLiteral
1212
from psycopg.types.json import Jsonb
1313

1414
# we need this mapping to cast meta values to the correct type,
@@ -21,7 +21,6 @@
2121
}
2222

2323
NO_VALUE = "no_value"
24-
SAFE_META_FIELD_RE = re.compile(r"^[a-zA-Z0-9_-]+$")
2524

2625

2726
def _validate_filters(filters: dict[str, Any] | None = None) -> None:
@@ -48,13 +47,13 @@ def _convert_filters_to_where_clause_and_params(
4847
else:
4948
query, values = _parse_logical_condition(filters)
5049

51-
where_clause = SQL(f" {operator} ") + SQL(query)
50+
where_clause = SQL(f" {operator} ") + query
5251
params = tuple(value for value in values if value != NO_VALUE)
5352

5453
return where_clause, params
5554

5655

57-
def _parse_logical_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]:
56+
def _parse_logical_condition(condition: dict[str, Any]) -> tuple[Composed, list[Any]]:
5857
if "operator" not in condition:
5958
msg = f"'operator' key missing in {condition}"
6059
raise FilterError(msg)
@@ -84,17 +83,14 @@ def _parse_logical_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]
8483
values = list(chain.from_iterable(values))
8584

8685
if operator == "AND":
87-
sql_query = f"({' AND '.join(query_parts)})"
88-
elif operator == "OR":
89-
sql_query = f"({' OR '.join(query_parts)})"
90-
else:
91-
msg = f"Unknown logical operator '{operator}'"
92-
raise FilterError(msg)
86+
sql_query = SQL("(") + SQL(" AND ").join(query_parts) + SQL(")")
87+
else: # operator == "OR"
88+
sql_query = SQL("(") + SQL(" OR ").join(query_parts) + SQL(")")
9389

9490
return sql_query, values
9591

9692

97-
def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]:
93+
def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[Composed, list[Any]]:
9894
field: str = condition["field"]
9995
if "operator" not in condition:
10096
msg = f"'operator' key missing in {condition}"
@@ -110,34 +106,37 @@ def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[str, list[An
110106
value: Any = condition["value"]
111107

112108
if field.startswith("meta."):
113-
field = _treat_meta_field(field, value)
109+
sql_field: Composable = _treat_meta_field(field, value)
110+
else:
111+
sql_field = Identifier(field)
114112

115-
field, value = COMPARISON_OPERATORS[operator](field, value)
116-
return field, [value]
113+
sql_expr, value = COMPARISON_OPERATORS[operator](sql_field, value)
114+
return sql_expr, [value]
117115

118116

119-
def _treat_meta_field(field: str, value: Any) -> str:
117+
def _treat_meta_field(field: str, value: Any) -> Composed:
120118
"""
121-
Internal method that modifies the field str
122-
to make the meta JSONB field queryable.
119+
Internal method that returns a psycopg Composed object
120+
to make the meta JSONB field queryable safely.
121+
122+
Uses psycopg.sql.Literal to embed the field name, preventing SQL injection
123+
via metadata field names without requiring regex validation.
124+
125+
Use the ->> operator to access keys in the meta JSONB field.
123126
124127
Examples:
125128
>>> _treat_meta_field(field="meta.number", value=9)
126-
"(meta->>'number')::integer"
129+
Composed([SQL('(meta->>'), Literal('number'), SQL(')::integer')])
127130
128131
>>> _treat_meta_field(field="meta.name", value="my_name")
129-
"meta->>'name'"
130-
"""
132+
Composed([SQL('meta->>'), Literal('name')])
131133
132-
# use the ->> operator to access keys in the meta JSONB field
134+
"""
133135
field_name = field.split(".", 1)[-1]
134-
if not SAFE_META_FIELD_RE.match(field_name):
135-
msg = (
136-
f"Invalid metadata field name '{field_name}'. "
137-
"Only alphanumeric characters, dashes, and underscores are allowed."
138-
)
139-
raise FilterError(msg)
140-
field = f"meta->>'{field_name}'"
136+
137+
# Use SQLLiteral to safely embed the field name as a SQL string literal,
138+
# preventing SQL injection via metadata field names.
139+
composed: Composed = SQL("meta->>") + SQLLiteral(field_name)
141140

142141
# meta fields are stored as strings in the JSONB field,
143142
# so we need to cast them to the correct type
@@ -146,25 +145,24 @@ def _treat_meta_field(field: str, value: Any) -> str:
146145
type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value[0]))
147146

148147
if type_value:
149-
field = f"({field})::{type_value}"
148+
composed = SQL("(") + composed + SQL(f")::{type_value}")
150149

151-
return field
150+
return composed
152151

153152

154-
def _equal(field: str, value: Any) -> tuple[str, Any]:
153+
def _equal(field: Composable, value: Any) -> tuple[Composed, Any]:
155154
if value is None:
156-
# NO_VALUE is a placeholder that will be removed in _convert_filters_to_where_clause_and_params
157-
return f"{field} IS NULL", NO_VALUE
158-
return f"{field} = %s", value
155+
return SQL("{} IS NULL").format(field), NO_VALUE
156+
return SQL("{} = %s").format(field), value
159157

160158

161-
def _not_equal(field: str, value: Any) -> tuple[str, Any]:
159+
def _not_equal(field: Composable, value: Any) -> tuple[Composed, Any]:
162160
# we use IS DISTINCT FROM to correctly handle NULL values
163161
# (not handled by !=)
164-
return f"{field} IS DISTINCT FROM %s", value
162+
return SQL("{} IS DISTINCT FROM %s").format(field), value
165163

166164

167-
def _greater_than(field: str, value: Any) -> tuple[str, Any]:
165+
def _greater_than(field: Composable, value: Any) -> tuple[Composed, Any]:
168166
if isinstance(value, str):
169167
try:
170168
datetime.fromisoformat(value)
@@ -177,11 +175,10 @@ def _greater_than(field: str, value: Any) -> tuple[str, Any]:
177175
if type(value) in [list, Jsonb]:
178176
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
179177
raise FilterError(msg)
178+
return SQL("{} > %s").format(field), value
180179

181-
return f"{field} > %s", value
182180

183-
184-
def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]:
181+
def _greater_than_equal(field: Composable, value: Any) -> tuple[Composed, Any]:
185182
if isinstance(value, str):
186183
try:
187184
datetime.fromisoformat(value)
@@ -194,11 +191,10 @@ def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]:
194191
if type(value) in [list, Jsonb]:
195192
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
196193
raise FilterError(msg)
197-
198-
return f"{field} >= %s", value
194+
return SQL("{} >= %s").format(field), value
199195

200196

201-
def _less_than(field: str, value: Any) -> tuple[str, Any]:
197+
def _less_than(field: Composable, value: Any) -> tuple[Composed, Any]:
202198
if isinstance(value, str):
203199
try:
204200
datetime.fromisoformat(value)
@@ -211,11 +207,10 @@ def _less_than(field: str, value: Any) -> tuple[str, Any]:
211207
if type(value) in [list, Jsonb]:
212208
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
213209
raise FilterError(msg)
210+
return SQL("{} < %s").format(field), value
214211

215-
return f"{field} < %s", value
216212

217-
218-
def _less_than_equal(field: str, value: Any) -> tuple[str, Any]:
213+
def _less_than_equal(field: Composable, value: Any) -> tuple[Composed, Any]:
219214
if isinstance(value, str):
220215
try:
221216
datetime.fromisoformat(value)
@@ -228,39 +223,37 @@ def _less_than_equal(field: str, value: Any) -> tuple[str, Any]:
228223
if type(value) in [list, Jsonb]:
229224
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
230225
raise FilterError(msg)
231-
232-
return f"{field} <= %s", value
226+
return SQL("{} <= %s").format(field), value
233227

234228

235-
def _not_in(field: str, value: Any) -> tuple[str, list]:
229+
def _not_in(field: Composable, value: Any) -> tuple[Composed, list]:
236230
if not isinstance(value, list):
237231
msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
238232
raise FilterError(msg)
239-
240-
return f"{field} IS NULL OR {field} != ALL(%s)", [value]
233+
return SQL("{} IS NULL OR {} != ALL(%s)").format(field, field), [value]
241234

242235

243-
def _in(field: str, value: Any) -> tuple[str, list]:
236+
def _in(field: Composable, value: Any) -> tuple[Composed, list]:
244237
if not isinstance(value, list):
245238
msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
246239
raise FilterError(msg)
247240

248241
# see https://www.psycopg.org/psycopg3/docs/basic/adapt.html#lists-adaptation
249-
return f"{field} = ANY(%s)", [value]
242+
return SQL("{} = ANY(%s)").format(field), [value]
250243

251244

252-
def _like(field: str, value: Any) -> tuple[str, Any]:
245+
def _like(field: Composable, value: Any) -> tuple[Composed, Any]:
253246
if not isinstance(value, str):
254247
msg = f"{field}'s value must be a str when using 'LIKE' "
255248
raise FilterError(msg)
256-
return f"{field} LIKE %s", value
249+
return SQL("{} LIKE %s").format(field), value
257250

258251

259-
def _not_like(field: str, value: Any) -> tuple[str, Any]:
252+
def _not_like(field: Composable, value: Any) -> tuple[Composed, Any]:
260253
if not isinstance(value, str):
261254
msg = f"{field}'s value must be a str when using 'LIKE' "
262255
raise FilterError(msg)
263-
return f"{field} NOT LIKE %s", value
256+
return SQL("{} NOT LIKE %s").format(field), value
264257

265258

266259
COMPARISON_OPERATORS = {

integrations/pgvector/tests/test_filters.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
15
import pytest
26
from haystack.dataclasses.document import Document
37
from haystack.testing.document_store import FilterDocumentsTest
4-
from psycopg.sql import SQL
8+
from psycopg.adapt import Transformer
9+
from psycopg.sql import SQL, Composed
10+
from psycopg.sql import Literal as SQLLiteral
511

612
from haystack_integrations.document_stores.pgvector.filters import (
713
FilterError,
@@ -13,6 +19,11 @@
1319
)
1420

1521

22+
def _render(composed: Composed) -> str:
23+
"""Render a psycopg Composed object to a plain SQL string for assertions."""
24+
return composed.as_string(Transformer())
25+
26+
1627
@pytest.mark.integration
1728
class TestFilters(FilterDocumentsTest):
1829
def assert_documents_are_equal(self, received: list[Document], expected: list[Document]):
@@ -129,24 +140,26 @@ def test_complex_filter(self, document_store, filterable_docs):
129140

130141

131142
def test_treat_meta_field():
132-
assert _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer"
133-
assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == "(meta->>'number')::integer"
134-
assert _treat_meta_field(field="meta.name", value="my_name") == "meta->>'name'"
135-
assert _treat_meta_field(field="meta.name", value=["my_name"]) == "meta->>'name'"
136-
assert _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real"
137-
assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == "(meta->>'number')::real"
138-
assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean"
139-
assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean"
143+
cast_integer = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::integer")
144+
no_cast_name = SQL("meta->>") + SQLLiteral("name")
145+
cast_real = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::real")
146+
cast_boolean = SQL("(") + SQL("meta->>") + SQLLiteral("bool") + SQL(")::boolean")
140147

141-
# do not cast the field if its value is not one of the known types, an empty list or None
142-
assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'"
143-
assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'"
144-
assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'"
148+
assert _treat_meta_field(field="meta.number", value=9) == cast_integer
149+
assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == cast_integer
150+
assert _treat_meta_field(field="meta.name", value="my_name") == no_cast_name
151+
assert _treat_meta_field(field="meta.name", value=["my_name"]) == no_cast_name
152+
assert _treat_meta_field(field="meta.number", value=1.1) == cast_real
153+
assert _treat_meta_field(field="meta.number", value=[1.1, 2.2]) == cast_real
154+
assert _treat_meta_field(field="meta.bool", value=True) == cast_boolean
155+
assert _treat_meta_field(field="meta.bool", value=[True, False]) == cast_boolean
145156

146157

147-
def test_treat_meta_field_rejects_unsafe_metadata_key():
148-
with pytest.raises(FilterError, match="Invalid metadata field name"):
149-
_treat_meta_field(field="meta.name' OR 1=1 --", value="x")
158+
def test_treat_meta_field_sql_injection_is_safely_escaped():
159+
# SQL injection attempts are safely escaped by SQLLiteral rather than rejected
160+
result = _treat_meta_field(field="meta.name' OR 1=1 --", value="x")
161+
assert isinstance(result, Composed)
162+
assert result == SQL("meta->>") + SQLLiteral("name' OR 1=1 --")
150163

151164

152165
def test_comparison_condition_missing_operator():
@@ -206,10 +219,16 @@ def test_logical_condition_nested():
206219
],
207220
}
208221
query, values = _parse_logical_condition(condition)
209-
assert query == (
210-
"((meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s)) "
211-
"AND ((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s)))"
222+
assert isinstance(query, Composed)
223+
224+
expected_sql = (
225+
"("
226+
"(meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s))"
227+
" AND "
228+
"((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s))"
229+
")"
212230
)
231+
assert _render(query) == expected_sql
213232
assert values == ["science", [["intro", "conclusion"]], 90, [["John", "Jane"]]]
214233

215234

@@ -222,7 +241,9 @@ def test_convert_filters_to_where_clause_and_params():
222241
],
223242
}
224243
where_clause, params = _convert_filters_to_where_clause_and_params(filters)
225-
assert where_clause == SQL(" WHERE ") + SQL("((meta->>'number')::integer = %s AND meta->>'chapter' = %s)")
244+
245+
expected_sql = " WHERE ((meta->>'number')::integer = %s AND meta->>'chapter' = %s)"
246+
assert _render(where_clause) == expected_sql
226247
assert params == (100, "intro")
227248

228249

@@ -235,7 +256,9 @@ def test_convert_filters_to_where_clause_and_params_handle_null():
235256
],
236257
}
237258
where_clause, params = _convert_filters_to_where_clause_and_params(filters)
238-
assert where_clause == SQL(" WHERE ") + SQL("(meta->>'number' IS NULL AND meta->>'chapter' = %s)")
259+
260+
expected_sql = " WHERE (meta->>'number' IS NULL AND meta->>'chapter' = %s)"
261+
assert _render(where_clause) == expected_sql
239262
assert params == ("intro",)
240263

241264

0 commit comments

Comments
 (0)