diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py index 9bb76562c5..92b3d42bd6 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py @@ -2,13 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 -import re from datetime import datetime from itertools import chain from typing import Any, Literal from haystack.errors import FilterError -from psycopg.sql import SQL, Composed +from psycopg.sql import SQL, Composable, Composed, Identifier +from psycopg.sql import Literal as SQLLiteral from psycopg.types.json import Jsonb # we need this mapping to cast meta values to the correct type, @@ -21,7 +21,6 @@ } NO_VALUE = "no_value" -SAFE_META_FIELD_RE = re.compile(r"^[a-zA-Z0-9_-]+$") def _validate_filters(filters: dict[str, Any] | None = None) -> None: @@ -48,13 +47,13 @@ def _convert_filters_to_where_clause_and_params( else: query, values = _parse_logical_condition(filters) - where_clause = SQL(f" {operator} ") + SQL(query) + where_clause = SQL(f" {operator} ") + query params = tuple(value for value in values if value != NO_VALUE) return where_clause, params -def _parse_logical_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]: +def _parse_logical_condition(condition: dict[str, Any]) -> tuple[Composed, list[Any]]: if "operator" not in condition: msg = f"'operator' key missing in {condition}" raise FilterError(msg) @@ -84,17 +83,14 @@ def _parse_logical_condition(condition: dict[str, Any]) -> tuple[str, list[Any]] values = list(chain.from_iterable(values)) if operator == "AND": - sql_query = f"({' AND '.join(query_parts)})" - elif operator == "OR": - sql_query = f"({' OR '.join(query_parts)})" - else: - msg = f"Unknown logical operator '{operator}'" - raise FilterError(msg) + sql_query = SQL("(") + SQL(" AND ").join(query_parts) + SQL(")") + else: # operator == "OR" + sql_query = SQL("(") + SQL(" OR ").join(query_parts) + SQL(")") return sql_query, values -def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]: +def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[Composed, list[Any]]: field: str = condition["field"] if "operator" not in condition: msg = f"'operator' key missing in {condition}" @@ -110,34 +106,37 @@ def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[str, list[An value: Any = condition["value"] if field.startswith("meta."): - field = _treat_meta_field(field, value) + sql_field: Composable = _treat_meta_field(field, value) + else: + sql_field = Identifier(field) - field, value = COMPARISON_OPERATORS[operator](field, value) - return field, [value] + sql_expr, value = COMPARISON_OPERATORS[operator](sql_field, value) + return sql_expr, [value] -def _treat_meta_field(field: str, value: Any) -> str: +def _treat_meta_field(field: str, value: Any) -> Composed: """ - Internal method that modifies the field str - to make the meta JSONB field queryable. + Internal method that returns a psycopg Composed object + to make the meta JSONB field queryable safely. + + Uses psycopg.sql.Literal to embed the field name, preventing SQL injection + via metadata field names without requiring regex validation. + + Use the ->> operator to access keys in the meta JSONB field. Examples: >>> _treat_meta_field(field="meta.number", value=9) - "(meta->>'number')::integer" + Composed([SQL('(meta->>'), Literal('number'), SQL(')::integer')]) >>> _treat_meta_field(field="meta.name", value="my_name") - "meta->>'name'" - """ + Composed([SQL('meta->>'), Literal('name')]) - # use the ->> operator to access keys in the meta JSONB field + """ field_name = field.split(".", 1)[-1] - if not SAFE_META_FIELD_RE.match(field_name): - msg = ( - f"Invalid metadata field name '{field_name}'. " - "Only alphanumeric characters, dashes, and underscores are allowed." - ) - raise FilterError(msg) - field = f"meta->>'{field_name}'" + + # Use SQLLiteral to safely embed the field name as a SQL string literal, + # preventing SQL injection via metadata field names. + composed: Composed = SQL("meta->>") + SQLLiteral(field_name) # meta fields are stored as strings in the JSONB field, # so we need to cast them to the correct type @@ -146,25 +145,24 @@ def _treat_meta_field(field: str, value: Any) -> str: type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value[0])) if type_value: - field = f"({field})::{type_value}" + composed = SQL("(") + composed + SQL(f")::{type_value}") - return field + return composed -def _equal(field: str, value: Any) -> tuple[str, Any]: +def _equal(field: Composable, value: Any) -> tuple[Composed, Any]: if value is None: - # NO_VALUE is a placeholder that will be removed in _convert_filters_to_where_clause_and_params - return f"{field} IS NULL", NO_VALUE - return f"{field} = %s", value + return SQL("{} IS NULL").format(field), NO_VALUE + return SQL("{} = %s").format(field), value -def _not_equal(field: str, value: Any) -> tuple[str, Any]: +def _not_equal(field: Composable, value: Any) -> tuple[Composed, Any]: # we use IS DISTINCT FROM to correctly handle NULL values # (not handled by !=) - return f"{field} IS DISTINCT FROM %s", value + return SQL("{} IS DISTINCT FROM %s").format(field), value -def _greater_than(field: str, value: Any) -> tuple[str, Any]: +def _greater_than(field: Composable, value: Any) -> tuple[Composed, Any]: if isinstance(value, str): try: datetime.fromisoformat(value) @@ -177,11 +175,10 @@ def _greater_than(field: str, value: Any) -> tuple[str, Any]: if type(value) in [list, Jsonb]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) + return SQL("{} > %s").format(field), value - return f"{field} > %s", value - -def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]: +def _greater_than_equal(field: Composable, value: Any) -> tuple[Composed, Any]: if isinstance(value, str): try: datetime.fromisoformat(value) @@ -194,11 +191,10 @@ def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]: if type(value) in [list, Jsonb]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) - - return f"{field} >= %s", value + return SQL("{} >= %s").format(field), value -def _less_than(field: str, value: Any) -> tuple[str, Any]: +def _less_than(field: Composable, value: Any) -> tuple[Composed, Any]: if isinstance(value, str): try: datetime.fromisoformat(value) @@ -211,11 +207,10 @@ def _less_than(field: str, value: Any) -> tuple[str, Any]: if type(value) in [list, Jsonb]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) + return SQL("{} < %s").format(field), value - return f"{field} < %s", value - -def _less_than_equal(field: str, value: Any) -> tuple[str, Any]: +def _less_than_equal(field: Composable, value: Any) -> tuple[Composed, Any]: if isinstance(value, str): try: datetime.fromisoformat(value) @@ -228,39 +223,37 @@ def _less_than_equal(field: str, value: Any) -> tuple[str, Any]: if type(value) in [list, Jsonb]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) - - return f"{field} <= %s", value + return SQL("{} <= %s").format(field), value -def _not_in(field: str, value: Any) -> tuple[str, list]: +def _not_in(field: Composable, value: Any) -> tuple[Composed, list]: if not isinstance(value, list): msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone" raise FilterError(msg) - - return f"{field} IS NULL OR {field} != ALL(%s)", [value] + return SQL("{} IS NULL OR {} != ALL(%s)").format(field, field), [value] -def _in(field: str, value: Any) -> tuple[str, list]: +def _in(field: Composable, value: Any) -> tuple[Composed, list]: if not isinstance(value, list): msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone" raise FilterError(msg) # see https://www.psycopg.org/psycopg3/docs/basic/adapt.html#lists-adaptation - return f"{field} = ANY(%s)", [value] + return SQL("{} = ANY(%s)").format(field), [value] -def _like(field: str, value: Any) -> tuple[str, Any]: +def _like(field: Composable, value: Any) -> tuple[Composed, Any]: if not isinstance(value, str): msg = f"{field}'s value must be a str when using 'LIKE' " raise FilterError(msg) - return f"{field} LIKE %s", value + return SQL("{} LIKE %s").format(field), value -def _not_like(field: str, value: Any) -> tuple[str, Any]: +def _not_like(field: Composable, value: Any) -> tuple[Composed, Any]: if not isinstance(value, str): msg = f"{field}'s value must be a str when using 'LIKE' " raise FilterError(msg) - return f"{field} NOT LIKE %s", value + return SQL("{} NOT LIKE %s").format(field), value COMPARISON_OPERATORS = { diff --git a/integrations/pgvector/tests/test_filters.py b/integrations/pgvector/tests/test_filters.py index ee62eb1f62..6787448239 100644 --- a/integrations/pgvector/tests/test_filters.py +++ b/integrations/pgvector/tests/test_filters.py @@ -1,7 +1,13 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import pytest from haystack.dataclasses.document import Document from haystack.testing.document_store import FilterDocumentsTest -from psycopg.sql import SQL +from psycopg.adapt import Transformer +from psycopg.sql import SQL, Composed +from psycopg.sql import Literal as SQLLiteral from haystack_integrations.document_stores.pgvector.filters import ( FilterError, @@ -13,6 +19,11 @@ ) +def _render(composed: Composed) -> str: + """Render a psycopg Composed object to a plain SQL string for assertions.""" + return composed.as_string(Transformer()) + + @pytest.mark.integration class TestFilters(FilterDocumentsTest): def assert_documents_are_equal(self, received: list[Document], expected: list[Document]): @@ -129,24 +140,26 @@ def test_complex_filter(self, document_store, filterable_docs): def test_treat_meta_field(): - assert _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer" - assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == "(meta->>'number')::integer" - assert _treat_meta_field(field="meta.name", value="my_name") == "meta->>'name'" - assert _treat_meta_field(field="meta.name", value=["my_name"]) == "meta->>'name'" - assert _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real" - assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == "(meta->>'number')::real" - assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean" - assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean" + cast_integer = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::integer") + no_cast_name = SQL("meta->>") + SQLLiteral("name") + cast_real = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::real") + cast_boolean = SQL("(") + SQL("meta->>") + SQLLiteral("bool") + SQL(")::boolean") - # do not cast the field if its value is not one of the known types, an empty list or None - assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'" - assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'" - assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'" + assert _treat_meta_field(field="meta.number", value=9) == cast_integer + assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == cast_integer + assert _treat_meta_field(field="meta.name", value="my_name") == no_cast_name + assert _treat_meta_field(field="meta.name", value=["my_name"]) == no_cast_name + assert _treat_meta_field(field="meta.number", value=1.1) == cast_real + assert _treat_meta_field(field="meta.number", value=[1.1, 2.2]) == cast_real + assert _treat_meta_field(field="meta.bool", value=True) == cast_boolean + assert _treat_meta_field(field="meta.bool", value=[True, False]) == cast_boolean -def test_treat_meta_field_rejects_unsafe_metadata_key(): - with pytest.raises(FilterError, match="Invalid metadata field name"): - _treat_meta_field(field="meta.name' OR 1=1 --", value="x") +def test_treat_meta_field_sql_injection_is_safely_escaped(): + # SQL injection attempts are safely escaped by SQLLiteral rather than rejected + result = _treat_meta_field(field="meta.name' OR 1=1 --", value="x") + assert isinstance(result, Composed) + assert result == SQL("meta->>") + SQLLiteral("name' OR 1=1 --") def test_comparison_condition_missing_operator(): @@ -206,10 +219,16 @@ def test_logical_condition_nested(): ], } query, values = _parse_logical_condition(condition) - assert query == ( - "((meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s)) " - "AND ((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s)))" + assert isinstance(query, Composed) + + expected_sql = ( + "(" + "(meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s))" + " AND " + "((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s))" + ")" ) + assert _render(query) == expected_sql assert values == ["science", [["intro", "conclusion"]], 90, [["John", "Jane"]]] @@ -222,7 +241,9 @@ def test_convert_filters_to_where_clause_and_params(): ], } where_clause, params = _convert_filters_to_where_clause_and_params(filters) - assert where_clause == SQL(" WHERE ") + SQL("((meta->>'number')::integer = %s AND meta->>'chapter' = %s)") + + expected_sql = " WHERE ((meta->>'number')::integer = %s AND meta->>'chapter' = %s)" + assert _render(where_clause) == expected_sql assert params == (100, "intro") @@ -235,7 +256,9 @@ def test_convert_filters_to_where_clause_and_params_handle_null(): ], } where_clause, params = _convert_filters_to_where_clause_and_params(filters) - assert where_clause == SQL(" WHERE ") + SQL("(meta->>'number' IS NULL AND meta->>'chapter' = %s)") + + expected_sql = " WHERE (meta->>'number' IS NULL AND meta->>'chapter' = %s)" + assert _render(where_clause) == expected_sql assert params == ("intro",)