Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

import re
from datetime import datetime
from itertools import chain
from typing import Any, Literal

from haystack.errors import FilterError
from psycopg.sql import SQL, Composed
from psycopg.sql import SQL, Composable, Composed, Identifier
from psycopg.sql import Literal as SQLLiteral
from psycopg.types.json import Jsonb

# we need this mapping to cast meta values to the correct type,
Expand All @@ -21,7 +21,6 @@
}

NO_VALUE = "no_value"
SAFE_META_FIELD_RE = re.compile(r"^[a-zA-Z0-9_-]+$")


def _validate_filters(filters: dict[str, Any] | None = None) -> None:
Expand All @@ -48,13 +47,13 @@ def _convert_filters_to_where_clause_and_params(
else:
query, values = _parse_logical_condition(filters)

where_clause = SQL(f" {operator} ") + SQL(query)
where_clause = SQL(f" {operator} ") + query
params = tuple(value for value in values if value != NO_VALUE)

return where_clause, params


def _parse_logical_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]:
def _parse_logical_condition(condition: dict[str, Any]) -> tuple[Composed, list[Any]]:
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
Expand Down Expand Up @@ -84,17 +83,14 @@ def _parse_logical_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]
values = list(chain.from_iterable(values))

if operator == "AND":
sql_query = f"({' AND '.join(query_parts)})"
elif operator == "OR":
sql_query = f"({' OR '.join(query_parts)})"
else:
msg = f"Unknown logical operator '{operator}'"
raise FilterError(msg)
sql_query = SQL("(") + SQL(" AND ").join(query_parts) + SQL(")")
else: # operator == "OR"
sql_query = SQL("(") + SQL(" OR ").join(query_parts) + SQL(")")
Comment thread
anakin87 marked this conversation as resolved.

return sql_query, values


def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]:
def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[Composed, list[Any]]:
field: str = condition["field"]
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
Expand All @@ -110,34 +106,37 @@ def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[str, list[An
value: Any = condition["value"]

if field.startswith("meta."):
field = _treat_meta_field(field, value)
sql_field: Composable = _treat_meta_field(field, value)
else:
sql_field = Identifier(field)

field, value = COMPARISON_OPERATORS[operator](field, value)
return field, [value]
sql_expr, value = COMPARISON_OPERATORS[operator](sql_field, value)
return sql_expr, [value]


def _treat_meta_field(field: str, value: Any) -> str:
def _treat_meta_field(field: str, value: Any) -> Composed:
"""
Internal method that modifies the field str
to make the meta JSONB field queryable.
Internal method that returns a psycopg Composed object
to make the meta JSONB field queryable safely.

Uses psycopg.sql.Literal to embed the field name, preventing SQL injection
via metadata field names without requiring regex validation.

Use the ->> operator to access keys in the meta JSONB field.

Examples:
>>> _treat_meta_field(field="meta.number", value=9)
"(meta->>'number')::integer"
Composed([SQL('(meta->>'), Literal('number'), SQL(')::integer')])

>>> _treat_meta_field(field="meta.name", value="my_name")
"meta->>'name'"
"""
Composed([SQL('meta->>'), Literal('name')])

# use the ->> operator to access keys in the meta JSONB field
Comment thread
anakin87 marked this conversation as resolved.
"""
field_name = field.split(".", 1)[-1]
if not SAFE_META_FIELD_RE.match(field_name):
msg = (
f"Invalid metadata field name '{field_name}'. "
"Only alphanumeric characters, dashes, and underscores are allowed."
)
raise FilterError(msg)
field = f"meta->>'{field_name}'"

# Use SQLLiteral to safely embed the field name as a SQL string literal,
# preventing SQL injection via metadata field names.
composed: Composed = SQL("meta->>") + SQLLiteral(field_name)

# meta fields are stored as strings in the JSONB field,
# so we need to cast them to the correct type
Expand All @@ -146,25 +145,24 @@ def _treat_meta_field(field: str, value: Any) -> str:
type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value[0]))

if type_value:
field = f"({field})::{type_value}"
composed = SQL("(") + composed + SQL(f")::{type_value}")

return field
return composed


def _equal(field: Composable, value: Any) -> tuple[Composed, Any]:
    """
    Build an SQL equality clause for the given field expression.

    :param field: already safely composed SQL expression for the column or meta key.
    :param value: value to compare against; ``None`` maps to an ``IS NULL`` test.
    :returns: a ``(Composed SQL fragment, parameter)`` pair.
    """
    if value is None:
        # NO_VALUE is a placeholder that will be removed in _convert_filters_to_where_clause_and_params
        return SQL("{} IS NULL").format(field), NO_VALUE
    return SQL("{} = %s").format(field), value


def _not_equal(field: Composable, value: Any) -> tuple[Composed, Any]:
    """
    Build an SQL inequality clause for the given field expression.

    :returns: a ``(Composed SQL fragment, parameter)`` pair.
    """
    # we use IS DISTINCT FROM to correctly handle NULL values
    # (not handled by !=)
    return SQL("{} IS DISTINCT FROM %s").format(field), value


def _greater_than(field: str, value: Any) -> tuple[str, Any]:
def _greater_than(field: Composable, value: Any) -> tuple[Composed, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -177,11 +175,10 @@ def _greater_than(field: str, value: Any) -> tuple[str, Any]:
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return SQL("{} > %s").format(field), value

return f"{field} > %s", value


def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]:
def _greater_than_equal(field: Composable, value: Any) -> tuple[Composed, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -194,11 +191,10 @@ def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]:
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)

return f"{field} >= %s", value
return SQL("{} >= %s").format(field), value


def _less_than(field: str, value: Any) -> tuple[str, Any]:
def _less_than(field: Composable, value: Any) -> tuple[Composed, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -211,11 +207,10 @@ def _less_than(field: str, value: Any) -> tuple[str, Any]:
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return SQL("{} < %s").format(field), value

return f"{field} < %s", value


def _less_than_equal(field: str, value: Any) -> tuple[str, Any]:
def _less_than_equal(field: Composable, value: Any) -> tuple[Composed, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -228,39 +223,37 @@ def _less_than_equal(field: str, value: Any) -> tuple[str, Any]:
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)

return f"{field} <= %s", value
return SQL("{} <= %s").format(field), value


def _not_in(field: Composable, value: Any) -> tuple[Composed, list]:
    """
    Build a clause matching rows whose field is NULL or not among ``value``.

    :raises FilterError: if ``value`` is not a list.
    :returns: a ``(Composed SQL fragment, [parameter])`` pair.
    """
    if not isinstance(value, list):
        # NOTE(review): the message mentions Pinecone but this is the pgvector
        # integration — looks like a copy-paste leftover; confirm before changing.
        msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
        raise FilterError(msg)

    # a plain `!= ALL` never matches NULL fields, so test IS NULL explicitly
    return SQL("{} IS NULL OR {} != ALL(%s)").format(field, field), [value]


def _in(field: Composable, value: Any) -> tuple[Composed, list]:
    """
    Build a membership clause for the given field expression.

    :raises FilterError: if ``value`` is not a list.
    :returns: a ``(Composed SQL fragment, [parameter])`` pair.
    """
    if not isinstance(value, list):
        # NOTE(review): message mentions Pinecone in the pgvector integration —
        # likely a copy-paste leftover; confirm before changing.
        msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
        raise FilterError(msg)

    # the list is wrapped in another list so psycopg adapts it as one array parameter,
    # see https://www.psycopg.org/psycopg3/docs/basic/adapt.html#lists-adaptation
    return SQL("{} = ANY(%s)").format(field), [value]


def _like(field: Composable, value: Any) -> tuple[Composed, Any]:
    """
    Build a ``LIKE`` pattern-match clause for the given field expression.

    :raises FilterError: if ``value`` is not a string.
    :returns: a ``(Composed SQL fragment, parameter)`` pair.
    """
    if not isinstance(value, str):
        msg = f"{field}'s value must be a str when using 'LIKE' "
        raise FilterError(msg)
    return SQL("{} LIKE %s").format(field), value


def _not_like(field: Composable, value: Any) -> tuple[Composed, Any]:
    """
    Build a ``NOT LIKE`` pattern-match clause for the given field expression.

    :raises FilterError: if ``value`` is not a string.
    :returns: a ``(Composed SQL fragment, parameter)`` pair.
    """
    if not isinstance(value, str):
        msg = f"{field}'s value must be a str when using 'LIKE' "
        raise FilterError(msg)
    return SQL("{} NOT LIKE %s").format(field), value


COMPARISON_OPERATORS = {
Expand Down
65 changes: 44 additions & 21 deletions integrations/pgvector/tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from haystack.dataclasses.document import Document
from haystack.testing.document_store import FilterDocumentsTest
from psycopg.sql import SQL
from psycopg.adapt import Transformer
from psycopg.sql import SQL, Composed
from psycopg.sql import Literal as SQLLiteral

from haystack_integrations.document_stores.pgvector.filters import (
FilterError,
Expand All @@ -13,6 +19,11 @@
)


def _render(composed: Composed) -> str:
    """Serialize a psycopg Composed object into the plain SQL string it represents."""
    # Transformer serves as the adaptation context needed by as_string
    transformer = Transformer()
    return composed.as_string(transformer)


@pytest.mark.integration
class TestFilters(FilterDocumentsTest):
def assert_documents_are_equal(self, received: list[Document], expected: list[Document]):
Expand Down Expand Up @@ -129,24 +140,26 @@ def test_complex_filter(self, document_store, filterable_docs):


def test_treat_meta_field():
    # expected Composed fragments: a type cast is applied for int/float/bool values
    cast_integer = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::integer")
    no_cast_name = SQL("meta->>") + SQLLiteral("name")
    cast_real = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::real")
    cast_boolean = SQL("(") + SQL("meta->>") + SQLLiteral("bool") + SQL(")::boolean")

    assert _treat_meta_field(field="meta.number", value=9) == cast_integer
    assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == cast_integer
    assert _treat_meta_field(field="meta.name", value="my_name") == no_cast_name
    assert _treat_meta_field(field="meta.name", value=["my_name"]) == no_cast_name
    assert _treat_meta_field(field="meta.number", value=1.1) == cast_real
    assert _treat_meta_field(field="meta.number", value=[1.1, 2.2]) == cast_real
    assert _treat_meta_field(field="meta.bool", value=True) == cast_boolean
    assert _treat_meta_field(field="meta.bool", value=[True, False]) == cast_boolean

    # do not cast the field if its value is not one of the known types, an empty list or None
    assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == SQL("meta->>") + SQLLiteral(
        "other"
    )
    assert _treat_meta_field(field="meta.empty_list", value=[]) == SQL("meta->>") + SQLLiteral("empty_list")
    assert _treat_meta_field(field="meta.name", value=None) == no_cast_name


def test_treat_meta_field_sql_injection_is_safely_escaped():
    # SQL injection attempts are safely escaped by SQLLiteral rather than rejected
    result = _treat_meta_field(field="meta.name' OR 1=1 --", value="x")
    assert isinstance(result, Composed)
    assert result == SQL("meta->>") + SQLLiteral("name' OR 1=1 --")


def test_comparison_condition_missing_operator():
def test_logical_condition_nested():
    # two OR sub-conditions combined by a top-level AND
    condition = {
        "operator": "AND",
        "conditions": [
            {
                "operator": "OR",
                "conditions": [
                    {"field": "meta.domain", "operator": "!=", "value": "science"},
                    {"field": "meta.chapter", "operator": "in", "value": ["intro", "conclusion"]},
                ],
            },
            {
                "operator": "OR",
                "conditions": [
                    {"field": "meta.number", "operator": ">=", "value": 90},
                    {"field": "meta.author", "operator": "not in", "value": ["John", "Jane"]},
                ],
            },
        ],
    }
    query, values = _parse_logical_condition(condition)
    assert isinstance(query, Composed)

    expected_sql = (
        "("
        "(meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s))"
        " AND "
        "((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s))"
        ")"
    )
    assert _render(query) == expected_sql
    assert values == ["science", [["intro", "conclusion"]], 90, [["John", "Jane"]]]


def test_convert_filters_to_where_clause_and_params():
    filters = {
        "operator": "AND",
        "conditions": [
            {"field": "meta.number", "operator": "==", "value": 100},
            {"field": "meta.chapter", "operator": "==", "value": "intro"},
        ],
    }
    where_clause, params = _convert_filters_to_where_clause_and_params(filters)

    expected_sql = " WHERE ((meta->>'number')::integer = %s AND meta->>'chapter' = %s)"
    assert _render(where_clause) == expected_sql
    assert params == (100, "intro")


def test_convert_filters_to_where_clause_and_params_handle_null():
    # a None value becomes an IS NULL clause and contributes no query parameter
    filters = {
        "operator": "AND",
        "conditions": [
            {"field": "meta.number", "operator": "==", "value": None},
            {"field": "meta.chapter", "operator": "==", "value": "intro"},
        ],
    }
    where_clause, params = _convert_filters_to_where_clause_and_params(filters)

    expected_sql = " WHERE (meta->>'number' IS NULL AND meta->>'chapter' = %s)"
    assert _render(where_clause) == expected_sql
    assert params == ("intro",)


Expand Down