-
Notifications
You must be signed in to change notification settings - Fork 254
fix: PgVectorDocumentStore _treat_meta_field and comparison functions now return Composed - string escaping done by psycopg
#2964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
fe55020
8a91dde
2d754e2
7421869
0912296
2fd0f10
7862bdf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,12 @@ | ||
| # SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai> | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import pytest | ||
| from haystack.dataclasses.document import Document | ||
| from haystack.testing.document_store import FilterDocumentsTest | ||
| from psycopg.sql import SQL | ||
| from psycopg.sql import SQL, Composed | ||
| from psycopg.sql import Literal as SQLLiteral | ||
|
|
||
| from haystack_integrations.document_stores.pgvector.filters import ( | ||
| FilterError, | ||
|
|
@@ -129,24 +134,38 @@ def test_complex_filter(self, document_store, filterable_docs): | |
|
|
||
|
|
||
| def test_treat_meta_field(): | ||
| assert _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer" | ||
| assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == "(meta->>'number')::integer" | ||
| assert _treat_meta_field(field="meta.name", value="my_name") == "meta->>'name'" | ||
| assert _treat_meta_field(field="meta.name", value=["my_name"]) == "meta->>'name'" | ||
| assert _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real" | ||
| assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == "(meta->>'number')::real" | ||
| assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean" | ||
| assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean" | ||
| assert _treat_meta_field(field="meta.number", value=9) == SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL( | ||
| ")::integer" | ||
| ) | ||
| assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == SQL("(") + SQL("meta->>") + SQLLiteral( | ||
| "number" | ||
| ) + SQL(")::integer") | ||
| assert _treat_meta_field(field="meta.name", value="my_name") == SQL("meta->>") + SQLLiteral("name") | ||
| assert _treat_meta_field(field="meta.name", value=["my_name"]) == SQL("meta->>") + SQLLiteral("name") | ||
| assert _treat_meta_field(field="meta.number", value=1.1) == SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL( | ||
| ")::real" | ||
| ) | ||
| assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == SQL("(") + SQL("meta->>") + SQLLiteral( | ||
| "number" | ||
| ) + SQL(")::real") | ||
| assert _treat_meta_field(field="meta.bool", value=True) == SQL("(") + SQL("meta->>") + SQLLiteral("bool") + SQL( | ||
| ")::boolean" | ||
| ) | ||
| assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == SQL("(") + SQL("meta->>") + SQLLiteral( | ||
| "bool" | ||
| ) + SQL(")::boolean") | ||
|
|
||
| # do not cast the field if its value is not one of the known types, an empty list or None | ||
| assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'" | ||
| assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'" | ||
| assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'" | ||
| assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == SQL("meta->>") + SQLLiteral("other") | ||
| assert _treat_meta_field(field="meta.empty_list", value=[]) == SQL("meta->>") + SQLLiteral("empty_list") | ||
| assert _treat_meta_field(field="meta.name", value=None) == SQL("meta->>") + SQLLiteral("name") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's discuss tests: I understand that the return type has changed, so we need to do something like this, which, unfortunately, is hard to read/maintain. An alternative would be to build a psycopg def test_treat_meta_field(document_store):
document_store._ensure_db_setup()
cursor = document_store._cursor
expr = _treat_meta_field(field="meta.number", value=9)
assert expr.as_string(cursor) == "(meta->>'number')::integer"
...The only problem with this approach: the test is no longer a pure unit test since it needs a DB connection. But we keep readability. WDYT?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about something like this: def test_treat_meta_field():
cast_integer = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::integer")
no_cast_name = SQL("meta->>") + SQLLiteral("name")
cast_real = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::real")
cast_boolean = SQL("(") + SQL("meta->>") + SQLLiteral("bool") + SQL(")::boolean")
assert _treat_meta_field(field="meta.number", value=9) == cast_integer
assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == cast_integer
assert _treat_meta_field(field="meta.name", value="my_name") == no_cast_name
assert _treat_meta_field(field="meta.name", value=["my_name"]) == no_cast_name
assert _treat_meta_field(field="meta.number", value=1.1) == cast_real
assert _treat_meta_field(field="meta.number", value=[1.1, 2.2]) == cast_real
assert _treat_meta_field(field="meta.bool", value=True) == cast_boolean
assert _treat_meta_field(field="meta.bool", value=[True, False]) == cast_booleanit compares two
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test remains quite understandable in any case. For this reason, I was advocating to resolve to string in the test. What's your opinion?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm confused. To make sure I understand your point. You agree with the approach for the test Regarding the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I was pointing to the wrong test. It's If we want to resolve to a string for easier comparison, we can use a cursor as I proposed in #2964 (comment) In any case, I'll leave this decision up to you. |
||
|
|
||
|
|
||
| def test_treat_meta_field_rejects_unsafe_metadata_key(): | ||
| with pytest.raises(FilterError, match="Invalid metadata field name"): | ||
| _treat_meta_field(field="meta.name' OR 1=1 --", value="x") | ||
| def test_treat_meta_field_sql_injection_is_safely_escaped(): | ||
| # SQL injection attempts are safely escaped by SQLLiteral rather than rejected | ||
| result = _treat_meta_field(field="meta.name' OR 1=1 --", value="x") | ||
| assert isinstance(result, Composed) | ||
| assert result == SQL("meta->>") + SQLLiteral("name' OR 1=1 --") | ||
|
|
||
|
|
||
| def test_comparison_condition_missing_operator(): | ||
|
|
@@ -206,10 +225,38 @@ def test_logical_condition_nested(): | |
| ], | ||
| } | ||
| query, values = _parse_logical_condition(condition) | ||
| assert query == ( | ||
| "((meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s)) " | ||
| "AND ((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s)))" | ||
| assert isinstance(query, Composed) | ||
|
|
||
| domain_field = SQL("meta->>") + SQLLiteral("domain") | ||
| chapter_field = SQL("meta->>") + SQLLiteral("chapter") | ||
| number_field = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::integer") | ||
| author_field = SQL("meta->>") + SQLLiteral("author") | ||
|
|
||
| expected = ( | ||
| SQL("(") | ||
| + SQL(" AND ").join( | ||
| [ | ||
| SQL("(") | ||
| + SQL(" OR ").join( | ||
| [ | ||
| SQL("{} IS DISTINCT FROM %s").format(domain_field), | ||
| SQL("{} = ANY(%s)").format(chapter_field), | ||
| ] | ||
| ) | ||
| + SQL(")"), | ||
| SQL("(") | ||
| + SQL(" OR ").join( | ||
| [ | ||
| SQL("{} >= %s").format(number_field), | ||
| SQL("{} IS NULL OR {} != ALL(%s)").format(author_field, author_field), | ||
| ] | ||
| ) | ||
| + SQL(")"), | ||
| ] | ||
| ) | ||
| + SQL(")") | ||
| ) | ||
| assert query == expected | ||
| assert values == ["science", [["intro", "conclusion"]], 90, [["John", "Jane"]]] | ||
|
|
||
|
|
||
|
|
@@ -222,7 +269,20 @@ def test_convert_filters_to_where_clause_and_params(): | |
| ], | ||
| } | ||
| where_clause, params = _convert_filters_to_where_clause_and_params(filters) | ||
| assert where_clause == SQL(" WHERE ") + SQL("((meta->>'number')::integer = %s AND meta->>'chapter' = %s)") | ||
|
|
||
| number_field = SQL("(") + SQL("meta->>") + SQLLiteral("number") + SQL(")::integer") | ||
| chapter_field = SQL("meta->>") + SQLLiteral("chapter") | ||
| expected = SQL(" WHERE ") + ( | ||
| SQL("(") | ||
| + SQL(" AND ").join( | ||
| [ | ||
| SQL("{} = %s").format(number_field), | ||
| SQL("{} = %s").format(chapter_field), | ||
| ] | ||
| ) | ||
| + SQL(")") | ||
| ) | ||
| assert where_clause == expected | ||
| assert params == (100, "intro") | ||
|
|
||
|
|
||
|
|
@@ -235,7 +295,20 @@ def test_convert_filters_to_where_clause_and_params_handle_null(): | |
| ], | ||
| } | ||
| where_clause, params = _convert_filters_to_where_clause_and_params(filters) | ||
| assert where_clause == SQL(" WHERE ") + SQL("(meta->>'number' IS NULL AND meta->>'chapter' = %s)") | ||
|
|
||
| number_field = SQL("meta->>") + SQLLiteral("number") | ||
| chapter_field = SQL("meta->>") + SQLLiteral("chapter") | ||
| expected = SQL(" WHERE ") + ( | ||
| SQL("(") | ||
| + SQL(" AND ").join( | ||
| [ | ||
| SQL("{} IS NULL").format(number_field), | ||
| SQL("{} = %s").format(chapter_field), | ||
| ] | ||
| ) | ||
| + SQL(")") | ||
| ) | ||
| assert where_clause == expected | ||
| assert params == ("intro",) | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.