# Patch summary: adds unit tests and two mock-connection fixtures for the
# pgvector document store test suite (conftest.py, test_document_store.py,
# test_document_store_async.py).
#
# NOTE(review): line breaks were restored from a whitespace-collapsed copy of
# this patch. Every hunk matches the line counts in its @@ header, but leading
# indentation and the spacing before inline comments on context lines are
# reconstructed -- confirm against the target files before applying.
diff --git a/integrations/pgvector/tests/conftest.py b/integrations/pgvector/tests/conftest.py
index d6db33e411..7d2866fd1d 100644
--- a/integrations/pgvector/tests/conftest.py
+++ b/integrations/pgvector/tests/conftest.py
@@ -1,6 +1,7 @@
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
+from psycopg import AsyncConnection, Connection
 
 from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
 
@@ -110,3 +111,15 @@ def mock_store(patches_for_unit_tests, monkeypatch):  # noqa: ARG001  patches ar
     )
 
     yield store
+
+
+@pytest.fixture
+def mock_store_with_mock_connection(mock_store):
+    mock_store._connection = Mock(spec=Connection)
+    return mock_store
+
+
+@pytest.fixture
+def mock_store_with_mock_async_connection(mock_store):
+    mock_store._async_connection = Mock(spec=AsyncConnection)
+    return mock_store
diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py
index 2b7e9613eb..62fdd559bd 100644
--- a/integrations/pgvector/tests/test_document_store.py
+++ b/integrations/pgvector/tests/test_document_store.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import psycopg
 import pytest
@@ -24,6 +24,8 @@
     WriteDocumentsTest,
 )
 from haystack.utils import Secret
+from psycopg import Connection, Cursor, Error
+from psycopg.sql import SQL
 
 from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
 
@@ -134,6 +136,18 @@ def test_init(monkeypatch):
     assert document_store.keyword_index_name == "my_keyword_index"
 
 
+@pytest.mark.usefixtures("patches_for_unit_tests")
+def test_init_invalid_vector_type():
+    with pytest.raises(ValueError, match=r"vector_type must be one of.*"):
+        PgvectorDocumentStore(vector_type="invalid")
+
+
+@pytest.mark.usefixtures("patches_for_unit_tests")
+def test_init_invalid_vector_function():
+    with pytest.raises(ValueError, match=r"vector_function must be one of.*"):
+        PgvectorDocumentStore(vector_function="invalid")
+
+
 @pytest.mark.usefixtures("patches_for_unit_tests")
 def test_to_dict(monkeypatch):
     monkeypatch.setenv("PG_CONN_STR", "some_connection_string")
@@ -175,6 +189,128 @@ def test_to_dict(monkeypatch):
     }
 
 
+def test_connection_is_valid_returns_true_on_success():
+    mock_connection = Mock(spec=Connection)
+    mock_connection.execute.return_value = None
+
+    assert PgvectorDocumentStore._connection_is_valid(mock_connection) is True
+
+
+def test_connection_is_valid_returns_false_when_execute_raises():
+    mock_connection = Mock(spec=Connection)
+    mock_connection.execute.side_effect = Error("connection dropped")
+
+    assert PgvectorDocumentStore._connection_is_valid(mock_connection) is False
+
+
+@pytest.mark.parametrize(
+    "cursor",
+    [None, Mock(spec=Cursor)],
+    ids=["cursor_is_none", "connection_is_none"],
+)
+def test_execute_sql_raises_when_not_initialized(mock_store, cursor):
+    assert mock_store._connection is None
+    with pytest.raises(ValueError, match="cursor or the connection is not initialized"):
+        mock_store._execute_sql(cursor=cursor, sql_query=SQL("SELECT 1"))
+
+
+def test_execute_sql_converts_psycopg_error_to_document_store_error(mock_store_with_mock_connection):
+    mock_cursor = Mock(spec=Cursor)
+    original_error = Error("connection broken")
+    mock_cursor.execute.side_effect = original_error
+
+    with pytest.raises(DocumentStoreError, match="write failed") as exc_info:
+        mock_store_with_mock_connection._execute_sql(
+            cursor=mock_cursor,
+            sql_query=SQL("SELECT 1"),
+            error_msg="write failed",
+        )
+
+    mock_store_with_mock_connection._connection.rollback.assert_called_once_with()
+    assert exc_info.value.__cause__ is original_error
+
+
+def test_write_documents_rejects_non_document_items(mock_store):
+    with pytest.raises(ValueError, match="must contain a list of objects of type Document"):
+        mock_store.write_documents([{"not": "a document"}])
+
+
+def test_count_unique_metadata_by_filter_rejects_empty_fields(mock_store):
+    with pytest.raises(ValueError, match="metadata_fields must be a non-empty list"):
+        mock_store.count_unique_metadata_by_filter(filters={}, metadata_fields=[])
+
+
+def test_check_and_build_embedding_retrieval_query_rejects_invalid_vector_function(mock_store):
+    with pytest.raises(ValueError, match="vector_function must be one of"):
+        mock_store._check_and_build_embedding_retrieval_query(
+            query_embedding=[0.1] * mock_store.embedding_dimension,
+            vector_function="invalid",
+            top_k=5,
+        )
+
+
+@pytest.mark.parametrize(
+    "bad_field",
+    [
+        "field@invalid",
+        "field with spaces",
+        "field;drop",
+        "meta.field!",
+        "field' OR '1'='1",
+        "../etc/passwd",
+        "field/* comment */",
+    ],
+)
+def test_normalize_metadata_field_name_rejects_invalid_chars(bad_field):
+    with pytest.raises(ValueError, match="Invalid metadata field name"):
+        PgvectorDocumentStore._normalize_metadata_field_name(bad_field)
+
+
+@pytest.mark.parametrize(
+    "value, expected_type",
+    [
+        (True, "boolean"),
+        (42, "integer"),
+        (3.14, "real"),
+        ("hello", "text"),
+        (["list", "fallback"], "text"),
+    ],
+)
+def test_infer_metadata_field_type(value, expected_type):
+    assert PgvectorDocumentStore._infer_metadata_field_type(value) == expected_type
+
+
+def test_analyze_metadata_fields_skips_non_dict_meta():
+    records = [{"meta": "not a dict"}, {"meta": None}]
+    assert PgvectorDocumentStore._analyze_metadata_fields_from_records(records) == {"content": {"type": "text"}}
+
+
+def test_analyze_metadata_fields_defaults_null_first_value_to_text():
+    records = [{"meta": {"tag": None}}, {"meta": {"tag": 42}}]
+    result = PgvectorDocumentStore._analyze_metadata_fields_from_records(records)
+    assert result == {"content": {"type": "text"}, "tag": {"type": "text"}}
+
+
+@pytest.mark.parametrize(
+    "result",
+    [None, {"min_value": None, "max_value": None}],
+    ids=["result_is_none", "values_are_none"],
+)
+def test_process_min_max_result_raises_when_no_values(result):
+    with pytest.raises(ValueError, match="Metadata field 'priority' has no values"):
+        PgvectorDocumentStore._process_min_max_result("priority", result)
+
+
+def test_process_count_unique_metadata_result_returns_zero_dict_when_result_none():
+    counts = PgvectorDocumentStore._process_count_unique_metadata_result(None, ["category", "language"])
+    assert counts == {"category": 0, "language": 0}
+
+
+def test_process_count_unique_metadata_result_uses_zero_for_missing_keys():
+    counts = PgvectorDocumentStore._process_count_unique_metadata_result({"category": 5}, ["category", "language"])
+    assert counts == {"category": 5, "language": 0}
+
+
 @pytest.mark.integration
 def test_halfvec_hnsw_write_documents(document_store_w_halfvec_hnsw_index: PgvectorDocumentStore):
     documents = [
@@ -295,6 +431,22 @@ def test_delete_table_first_call(document_store):
     document_store.delete_table()  # if throw error, test fails
 
 
+@pytest.mark.integration
+def test_delete_documents_empty_list_is_noop(document_store: PgvectorDocumentStore):
+    docs = [
+        Document(id="1", content="hello"),
+        Document(id="2", content="world"),
+    ]
+    document_store.write_documents(docs)
+    before = sorted(document_store.filter_documents(), key=lambda d: d.id)
+    assert len(before) == 2
+
+    document_store.delete_documents([])
+
+    after = sorted(document_store.filter_documents(), key=lambda d: d.id)
+    assert after == before
+
+
 @pytest.mark.integration
 def test_update_by_filter_empty_meta_raises_error(document_store: PgvectorDocumentStore):
     docs = [Document(content="Doc 1", meta={"category": "A"})]
diff --git a/integrations/pgvector/tests/test_document_store_async.py b/integrations/pgvector/tests/test_document_store_async.py
index 64b4a49bbb..fdf3d1a1b4 100644
--- a/integrations/pgvector/tests/test_document_store_async.py
+++ b/integrations/pgvector/tests/test_document_store_async.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import psycopg
 import pytest
@@ -10,6 +10,8 @@
 from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
 from haystack.document_stores.types import DuplicatePolicy
 from haystack.utils import Secret
+from psycopg import AsyncConnection, Error
+from psycopg.cursor_async import AsyncCursor
 from psycopg.sql import SQL
 
 from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
@@ -797,3 +799,62 @@ async def test_get_metadata_field_unique_values_async(document_store: PgvectorDo
     )
     assert set(unique_priorities_filtered) == {"1"}
     assert total_priorities_filtered == 1
+
+
+@pytest.mark.asyncio
+async def test_connection_is_valid_async_returns_true_on_success():
+    mock_connection = Mock(spec=AsyncConnection)
+    mock_connection.execute.return_value = None
+
+    assert await PgvectorDocumentStore._connection_is_valid_async(mock_connection) is True
+
+
+@pytest.mark.asyncio
+async def test_connection_is_valid_async_returns_false_when_execute_raises():
+    mock_connection = Mock(spec=AsyncConnection)
+    mock_connection.execute.side_effect = Error("connection dropped")
+
+    assert await PgvectorDocumentStore._connection_is_valid_async(mock_connection) is False
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "cursor",
+    [None, Mock(spec=AsyncCursor)],
+    ids=["cursor_is_none", "connection_is_none"],
+)
+async def test_execute_sql_async_raises_when_not_initialized(mock_store, cursor):
+    assert mock_store._async_connection is None
+    with pytest.raises(ValueError, match="cursor or the connection is not initialized"):
+        await mock_store._execute_sql_async(cursor=cursor, sql_query=SQL("SELECT 1"))
+
+
+@pytest.mark.asyncio
+async def test_execute_sql_async_converts_psycopg_error_to_document_store_error(
+    mock_store_with_mock_async_connection,
+):
+    mock_cursor = Mock(spec=AsyncCursor)
+    original_error = Error("connection broken")
+    mock_cursor.execute.side_effect = original_error
+
+    with pytest.raises(DocumentStoreError, match="write failed") as exc_info:
+        await mock_store_with_mock_async_connection._execute_sql_async(
+            cursor=mock_cursor,
+            sql_query=SQL("SELECT 1"),
+            error_msg="write failed",
+        )
+
+    mock_store_with_mock_async_connection._async_connection.rollback.assert_awaited_once_with()
+    assert exc_info.value.__cause__ is original_error
+
+
+@pytest.mark.asyncio
+async def test_write_documents_async_rejects_non_document_items(mock_store):
+    with pytest.raises(ValueError, match="must contain a list of objects of type Document"):
+        await mock_store.write_documents_async([{"not": "a document"}])
+
+
+@pytest.mark.asyncio
+async def test_count_unique_metadata_by_filter_async_rejects_empty_fields(mock_store):
+    with pytest.raises(ValueError, match="metadata_fields must be a non-empty list"):
+        await mock_store.count_unique_metadata_by_filter_async(filters={}, metadata_fields=[])