Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

import re
from datetime import datetime
from itertools import chain
from typing import Any, Literal

from haystack.errors import FilterError
from psycopg.sql import SQL, Composed
from psycopg.sql import SQL, Composable, Composed, Identifier
from psycopg.sql import Literal as SQLLiteral
from psycopg.types.json import Jsonb

# we need this mapping to cast meta values to the correct type,
Expand All @@ -21,7 +21,6 @@
}

NO_VALUE = "no_value"
SAFE_META_FIELD_RE = re.compile(r"^[a-zA-Z0-9_-]+$")


def _validate_filters(filters: dict[str, Any] | None = None) -> None:
Expand All @@ -48,13 +47,13 @@ def _convert_filters_to_where_clause_and_params(
else:
query, values = _parse_logical_condition(filters)

where_clause = SQL(f" {operator} ") + SQL(query)
where_clause = SQL(f" {operator} ") + query
params = tuple(value for value in values if value != NO_VALUE)

return where_clause, params


def _parse_logical_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]:
def _parse_logical_condition(condition: dict[str, Any]) -> tuple[Composed, list[Any]]:
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
Expand Down Expand Up @@ -84,17 +83,14 @@ def _parse_logical_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]
values = list(chain.from_iterable(values))

if operator == "AND":
sql_query = f"({' AND '.join(query_parts)})"
elif operator == "OR":
sql_query = f"({' OR '.join(query_parts)})"
else:
msg = f"Unknown logical operator '{operator}'"
raise FilterError(msg)
sql_query = SQL("(") + SQL(" AND ").join(query_parts) + SQL(")")
else: # operator == "OR"
sql_query = SQL("(") + SQL(" OR ").join(query_parts) + SQL(")")
Comment thread
anakin87 marked this conversation as resolved.

return sql_query, values


def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]:
def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[Composed, list[Any]]:
field: str = condition["field"]
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
Expand All @@ -110,34 +106,37 @@ def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[str, list[An
value: Any = condition["value"]

if field.startswith("meta."):
field = _treat_meta_field(field, value)
sql_field: Composable = _treat_meta_field(field, value)
else:
sql_field = Identifier(field)

field, value = COMPARISON_OPERATORS[operator](field, value)
return field, [value]
sql_expr, value = COMPARISON_OPERATORS[operator](sql_field, value)
return sql_expr, [value]


def _treat_meta_field(field: str, value: Any) -> str:
def _treat_meta_field(field: str, value: Any) -> Composed:
"""
Internal method that modifies the field str
to make the meta JSONB field queryable.
Internal method that returns a psycopg Composed object
to make the meta JSONB field queryable safely.

Uses psycopg.sql.Literal to embed the field name, preventing SQL injection
via metadata field names without requiring regex validation.

Use the ->> operator to access keys in the meta JSONB field.

Examples:
>>> _treat_meta_field(field="meta.number", value=9)
"(meta->>'number')::integer"
Composed([SQL('(meta->>'), Literal('number'), SQL(')::integer')])

>>> _treat_meta_field(field="meta.name", value="my_name")
"meta->>'name'"
"""
Composed([SQL('meta->>'), Literal('name')])

# use the ->> operator to access keys in the meta JSONB field
Comment thread
anakin87 marked this conversation as resolved.
"""
field_name = field.split(".", 1)[-1]
if not SAFE_META_FIELD_RE.match(field_name):
msg = (
f"Invalid metadata field name '{field_name}'. "
"Only alphanumeric characters, dashes, and underscores are allowed."
)
raise FilterError(msg)
field = f"meta->>'{field_name}'"

# Use SQLLiteral to safely embed the field name as a SQL string literal,
# preventing SQL injection via metadata field names.
composed: Composed = SQL("meta->>") + SQLLiteral(field_name)

# meta fields are stored as strings in the JSONB field,
# so we need to cast them to the correct type
Expand All @@ -146,25 +145,24 @@ def _treat_meta_field(field: str, value: Any) -> str:
type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value[0]))

if type_value:
field = f"({field})::{type_value}"
composed = SQL("(") + composed + SQL(f")::{type_value}")

return field
return composed


def _equal(field: str, value: Any) -> tuple[str, Any]:
def _equal(field: Composable, value: Any) -> tuple[Composed, Any]:
if value is None:
# NO_VALUE is a placeholder that will be removed in _convert_filters_to_where_clause_and_params
return f"{field} IS NULL", NO_VALUE
return f"{field} = %s", value
return SQL("{} IS NULL").format(field), NO_VALUE
return SQL("{} = %s").format(field), value


def _not_equal(field: str, value: Any) -> tuple[str, Any]:
def _not_equal(field: Composable, value: Any) -> tuple[Composed, Any]:
# we use IS DISTINCT FROM to correctly handle NULL values
# (not handled by !=)
return f"{field} IS DISTINCT FROM %s", value
return SQL("{} IS DISTINCT FROM %s").format(field), value


def _greater_than(field: str, value: Any) -> tuple[str, Any]:
def _greater_than(field: Composable, value: Any) -> tuple[Composed, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -177,11 +175,10 @@ def _greater_than(field: str, value: Any) -> tuple[str, Any]:
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return SQL("{} > %s").format(field), value

return f"{field} > %s", value


def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]:
def _greater_than_equal(field: Composable, value: Any) -> tuple[Composed, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -194,11 +191,10 @@ def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]:
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)

return f"{field} >= %s", value
return SQL("{} >= %s").format(field), value


def _less_than(field: str, value: Any) -> tuple[str, Any]:
def _less_than(field: Composable, value: Any) -> tuple[Composed, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -211,11 +207,10 @@ def _less_than(field: str, value: Any) -> tuple[str, Any]:
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return SQL("{} < %s").format(field), value

return f"{field} < %s", value


def _less_than_equal(field: str, value: Any) -> tuple[str, Any]:
def _less_than_equal(field: Composable, value: Any) -> tuple[Composed, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -228,39 +223,37 @@ def _less_than_equal(field: str, value: Any) -> tuple[str, Any]:
if type(value) in [list, Jsonb]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)

return f"{field} <= %s", value
return SQL("{} <= %s").format(field), value


def _not_in(field: str, value: Any) -> tuple[str, list]:
def _not_in(field: Composable, value: Any) -> tuple[Composed, list]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
raise FilterError(msg)

return f"{field} IS NULL OR {field} != ALL(%s)", [value]
return SQL("{} IS NULL OR {} != ALL(%s)").format(field, field), [value]


def _in(field: str, value: Any) -> tuple[str, list]:
def _in(field: Composable, value: Any) -> tuple[Composed, list]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
raise FilterError(msg)

# see https://www.psycopg.org/psycopg3/docs/basic/adapt.html#lists-adaptation
return f"{field} = ANY(%s)", [value]
return SQL("{} = ANY(%s)").format(field), [value]


def _like(field: str, value: Any) -> tuple[str, Any]:
def _like(field: Composable, value: Any) -> tuple[Composed, Any]:
if not isinstance(value, str):
msg = f"{field}'s value must be a str when using 'LIKE' "
raise FilterError(msg)
return f"{field} LIKE %s", value
return SQL("{} LIKE %s").format(field), value


def _not_like(field: str, value: Any) -> tuple[str, Any]:
def _not_like(field: Composable, value: Any) -> tuple[Composed, Any]:
if not isinstance(value, str):
msg = f"{field}'s value must be a str when using 'LIKE' "
raise FilterError(msg)
return f"{field} NOT LIKE %s", value
return SQL("{} NOT LIKE %s").format(field), value


COMPARISON_OPERATORS = {
Expand Down
113 changes: 93 additions & 20 deletions integrations/pgvector/tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from haystack.dataclasses.document import Document
from haystack.testing.document_store import FilterDocumentsTest
from psycopg.sql import SQL
from psycopg.sql import SQL, Composed
from psycopg.sql import Literal as SQLLiteral

from haystack_integrations.document_stores.pgvector.filters import (
FilterError,
Expand Down Expand Up @@ -129,24 +134,38 @@ def test_complex_filter(self, document_store, filterable_docs):


def test_treat_meta_field():
assert _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer"
assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == "(meta->>'number')::integer"
assert _treat_meta_field(field="meta.name", value="my_name") == "meta->>'name'"
assert _treat_meta_field(field="meta.name", value=["my_name"]) == "meta->>'name'"
assert _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real"
assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == "(meta->>'number')::real"
assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean"
assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean"
assert _treat_meta_field(field="meta.number", value=9) == SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(
")::integer"
)
assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == SQL("(") + SQL("meta->>") + SQLLiteral(
"number"
) + SQL(")::integer")
assert _treat_meta_field(field="meta.name", value="my_name") == SQL("meta->>") + SQLLiteral("name")
assert _treat_meta_field(field="meta.name", value=["my_name"]) == SQL("meta->>") + SQLLiteral("name")
assert _treat_meta_field(field="meta.number", value=1.1) == SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(
")::real"
)
assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == SQL("(") + SQL("meta->>") + SQLLiteral(
"number"
) + SQL(")::real")
assert _treat_meta_field(field="meta.bool", value=True) == SQL("(") + SQL("meta->>") + SQLLiteral("bool") + SQL(
")::boolean"
)
assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == SQL("(") + SQL("meta->>") + SQLLiteral(
"bool"
) + SQL(")::boolean")

# do not cast the field if its value is not one of the known types, an empty list or None
assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'"
assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'"
assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'"
assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == SQL("meta->>") + SQLLiteral("other")
assert _treat_meta_field(field="meta.empty_list", value=[]) == SQL("meta->>") + SQLLiteral("empty_list")
assert _treat_meta_field(field="meta.name", value=None) == SQL("meta->>") + SQLLiteral("name")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's discuss tests: I understand that the return type has changed, so we need to do something like this, which, unfortunately, is hard to read/maintain.

An alternative would be to build a psycopg cursor and use it to convert SQL to str, keeping the test similar to the previous one. An example for test_treat_meta_field:

def test_treat_meta_field(document_store):
    document_store._ensure_db_setup()
    cursor = document_store._cursor

    expr = _treat_meta_field(field="meta.number", value=9)
    assert expr.as_string(cursor) == "(meta->>'number')::integer"
    ...

The only problem with this approach: the test is no longer a pure unit test since it needs a DB connection. But we keep readability. WDYT?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about something like this:

def test_treat_meta_field():
      cast_integer = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::integer")
      no_cast_name  = SQL("meta->>") + SQLLiteral("name")
      cast_real     = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::real")
      cast_boolean  = SQL("(") + SQL("meta->>") + SQLLiteral("bool") + SQL(")::boolean")

      assert _treat_meta_field(field="meta.number", value=9) == cast_integer
      assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == cast_integer
      assert _treat_meta_field(field="meta.name", value="my_name") == no_cast_name
      assert _treat_meta_field(field="meta.name", value=["my_name"]) == no_cast_name
      assert _treat_meta_field(field="meta.number", value=1.1) == cast_real
      assert _treat_meta_field(field="meta.number", value=[1.1, 2.2]) == cast_real
      assert _treat_meta_field(field="meta.bool", value=True) == cast_boolean
      assert _treat_meta_field(field="meta.bool", value=[True, False]) == cast_boolean

it compares two Composed instances

Copy link
Copy Markdown
Member

@anakin87 anakin87 Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test remains quite understandable in any case.
test_comparison_condition_missing_operator test_logical_condition_nested is hard to read instead.

For this reason, I was advocating to resolve to string in the test.

What's your opinion?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused.

To make sure I understand your point.

You agree with the approach for the test test_treat_meta_field as is, but instead of repr() you suggest having it as a string to compare the full output - is that it?

Regarding the test_comparison_condition_missing_operator- it's not changed by this PR - but I miss what you are suggesting.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I was pointing to the wrong test. It's test_logical_condition_nested that gets a bit hard to read.

If we want to resolve to a string for easier comparison, we can use a cursor as I proposed in #2964 (comment)

In any case, I'll leave this decision up to you.



def test_treat_meta_field_rejects_unsafe_metadata_key():
with pytest.raises(FilterError, match="Invalid metadata field name"):
_treat_meta_field(field="meta.name' OR 1=1 --", value="x")
def test_treat_meta_field_sql_injection_is_safely_escaped():
# SQL injection attempts are safely escaped by SQLLiteral rather than rejected
result = _treat_meta_field(field="meta.name' OR 1=1 --", value="x")
assert isinstance(result, Composed)
assert result == SQL("meta->>") + SQLLiteral("name' OR 1=1 --")


def test_comparison_condition_missing_operator():
Expand Down Expand Up @@ -206,10 +225,38 @@ def test_logical_condition_nested():
],
}
query, values = _parse_logical_condition(condition)
assert query == (
"((meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s)) "
"AND ((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s)))"
assert isinstance(query, Composed)

domain_field = SQL("meta->>") + SQLLiteral("domain")
chapter_field = SQL("meta->>") + SQLLiteral("chapter")
number_field = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::integer")
author_field = SQL("meta->>") + SQLLiteral("author")

expected = (
SQL("(")
+ SQL(" AND ").join(
[
SQL("(")
+ SQL(" OR ").join(
[
SQL("{} IS DISTINCT FROM %s").format(domain_field),
SQL("{} = ANY(%s)").format(chapter_field),
]
)
+ SQL(")"),
SQL("(")
+ SQL(" OR ").join(
[
SQL("{} >= %s").format(number_field),
SQL("{} IS NULL OR {} != ALL(%s)").format(author_field, author_field),
]
)
+ SQL(")"),
]
)
+ SQL(")")
)
assert query == expected
assert values == ["science", [["intro", "conclusion"]], 90, [["John", "Jane"]]]


Expand All @@ -222,7 +269,20 @@ def test_convert_filters_to_where_clause_and_params():
],
}
where_clause, params = _convert_filters_to_where_clause_and_params(filters)
assert where_clause == SQL(" WHERE ") + SQL("((meta->>'number')::integer = %s AND meta->>'chapter' = %s)")

number_field = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::integer")
chapter_field = SQL("meta->>") + SQLLiteral("chapter")
expected = SQL(" WHERE ") + (
SQL("(")
+ SQL(" AND ").join(
[
SQL("{} = %s").format(number_field),
SQL("{} = %s").format(chapter_field),
]
)
+ SQL(")")
)
assert where_clause == expected
assert params == (100, "intro")


Expand All @@ -235,7 +295,20 @@ def test_convert_filters_to_where_clause_and_params_handle_null():
],
}
where_clause, params = _convert_filters_to_where_clause_and_params(filters)
assert where_clause == SQL(" WHERE ") + SQL("(meta->>'number' IS NULL AND meta->>'chapter' = %s)")

number_field = SQL("meta->>") + SQLLiteral("number")
chapter_field = SQL("meta->>") + SQLLiteral("chapter")
expected = SQL(" WHERE ") + (
SQL("(")
+ SQL(" AND ").join(
[
SQL("{} IS NULL").format(number_field),
SQL("{} = %s").format(chapter_field),
]
)
+ SQL(")")
)
assert where_clause == expected
assert params == ("intro",)


Expand Down