Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion integrations/pgvector/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
from psycopg import AsyncConnection, Connection

from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore

Expand Down Expand Up @@ -110,3 +111,15 @@ def mock_store(patches_for_unit_tests, monkeypatch): # noqa: ARG001 patches ar
)

yield store


@pytest.fixture
def mock_store_with_mock_connection(mock_store):
    """A unit-test store whose synchronous connection is replaced by a spec'd Mock."""
    store = mock_store
    store._connection = Mock(spec=Connection)
    return store


@pytest.fixture
def mock_store_with_mock_async_connection(mock_store):
    """A unit-test store whose asynchronous connection is replaced by a spec'd Mock."""
    store = mock_store
    store._async_connection = Mock(spec=AsyncConnection)
    return store
154 changes: 153 additions & 1 deletion integrations/pgvector/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch
from unittest.mock import Mock, patch

import psycopg
import pytest
Expand All @@ -24,6 +24,8 @@
WriteDocumentsTest,
)
from haystack.utils import Secret
from psycopg import Connection, Cursor, Error
from psycopg.sql import SQL

from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore

Expand Down Expand Up @@ -134,6 +136,18 @@ def test_init(monkeypatch):
assert document_store.keyword_index_name == "my_keyword_index"


@pytest.mark.usefixtures("patches_for_unit_tests")
def test_init_invalid_vector_type():
    """Constructing the store with an unknown vector_type raises ValueError."""
    expected = r"vector_type must be one of.*"
    with pytest.raises(ValueError, match=expected):
        PgvectorDocumentStore(vector_type="invalid")


@pytest.mark.usefixtures("patches_for_unit_tests")
def test_init_invalid_vector_function():
    """Constructing the store with an unknown vector_function raises ValueError."""
    expected = r"vector_function must be one of.*"
    with pytest.raises(ValueError, match=expected):
        PgvectorDocumentStore(vector_function="invalid")


@pytest.mark.usefixtures("patches_for_unit_tests")
def test_to_dict(monkeypatch):
monkeypatch.setenv("PG_CONN_STR", "some_connection_string")
Expand Down Expand Up @@ -175,6 +189,128 @@ def test_to_dict(monkeypatch):
}


def test_connection_is_valid_returns_true_on_success():
    """A connection whose execute() call succeeds is reported as valid."""
    conn = Mock(spec=Connection)
    conn.execute.return_value = None

    is_valid = PgvectorDocumentStore._connection_is_valid(conn)

    assert is_valid is True


def test_connection_is_valid_returns_false_when_execute_raises():
    """A psycopg Error raised by execute() marks the connection as invalid."""
    conn = Mock(spec=Connection)
    conn.execute.side_effect = Error("connection dropped")

    is_valid = PgvectorDocumentStore._connection_is_valid(conn)

    assert is_valid is False


@pytest.mark.parametrize("cursor", [None, Mock(spec=Cursor)], ids=["cursor_is_none", "connection_is_none"])
def test_execute_sql_raises_when_not_initialized(mock_store, cursor):
    """_execute_sql fails fast if either the cursor or the connection is missing."""
    # The bare mock_store fixture has no real connection established.
    assert mock_store._connection is None

    with pytest.raises(ValueError, match="cursor or the connection is not initialized"):
        mock_store._execute_sql(cursor=cursor, sql_query=SQL("SELECT 1"))


def test_execute_sql_converts_psycopg_error_to_document_store_error(mock_store_with_mock_connection):
    """A psycopg Error during execute triggers a rollback and is re-raised as DocumentStoreError."""
    store = mock_store_with_mock_connection
    cursor = Mock(spec=Cursor)
    cause = Error("connection broken")
    cursor.execute.side_effect = cause

    with pytest.raises(DocumentStoreError, match="write failed") as excinfo:
        store._execute_sql(cursor=cursor, sql_query=SQL("SELECT 1"), error_msg="write failed")

    # The failed statement must be rolled back and the original error chained as the cause.
    store._connection.rollback.assert_called_once_with()
    assert excinfo.value.__cause__ is cause


def test_write_documents_rejects_non_document_items(mock_store):
    """write_documents only accepts Document instances, not arbitrary dicts."""
    bad_items = [{"not": "a document"}]
    with pytest.raises(ValueError, match="must contain a list of objects of type Document"):
        mock_store.write_documents(bad_items)


def test_count_unique_metadata_by_filter_rejects_empty_fields(mock_store):
    """An empty metadata_fields list is rejected with a ValueError."""
    no_fields = []
    with pytest.raises(ValueError, match="metadata_fields must be a non-empty list"):
        mock_store.count_unique_metadata_by_filter(filters={}, metadata_fields=no_fields)


def test_check_and_build_embedding_retrieval_query_rejects_invalid_vector_function(mock_store):
    """Building a retrieval query with an unknown vector_function raises ValueError."""
    embedding = [0.1] * mock_store.embedding_dimension
    with pytest.raises(ValueError, match="vector_function must be one of"):
        mock_store._check_and_build_embedding_retrieval_query(
            query_embedding=embedding,
            vector_function="invalid",
            top_k=5,
        )


@pytest.mark.parametrize(
    "bad_field",
    [
        # Characters that could enable SQL injection or are simply not valid identifiers.
        "field@invalid",
        "field with spaces",
        "field;drop",
        "meta.field!",
        "field' OR '1'='1",
        "../etc/passwd",
        "field/* comment */",
    ],
)
def test_normalize_metadata_field_name_rejects_invalid_chars(bad_field):
    """Metadata field names containing unsafe characters must be rejected."""
    with pytest.raises(ValueError, match="Invalid metadata field name"):
        PgvectorDocumentStore._normalize_metadata_field_name(bad_field)


@pytest.mark.parametrize(
    "value, expected_type",
    [
        (True, "boolean"),
        (42, "integer"),
        (3.14, "real"),
        ("hello", "text"),
        # Unsupported Python types fall back to text.
        (["list", "fallback"], "text"),
    ],
)
def test_infer_metadata_field_type(value, expected_type):
    """Each Python value maps to the expected Postgres column type."""
    inferred = PgvectorDocumentStore._infer_metadata_field_type(value)
    assert inferred == expected_type


def test_analyze_metadata_fields_skips_non_dict_meta():
    """Records whose meta is not a dict contribute no metadata fields."""
    records = [{"meta": "not a dict"}, {"meta": None}]
    analyzed = PgvectorDocumentStore._analyze_metadata_fields_from_records(records)
    # Only the always-present content field remains.
    assert analyzed == {"content": {"type": "text"}}


def test_analyze_metadata_fields_defaults_null_first_value_to_text():
    """When the first observed value of a field is None, the field type defaults to text."""
    records = [{"meta": {"tag": None}}, {"meta": {"tag": 42}}]
    analyzed = PgvectorDocumentStore._analyze_metadata_fields_from_records(records)
    assert analyzed == {"content": {"type": "text"}, "tag": {"type": "text"}}


@pytest.mark.parametrize(
    "result",
    [None, {"min_value": None, "max_value": None}],
    ids=["result_is_none", "values_are_none"],
)
def test_process_min_max_result_raises_when_no_values(result):
    """A missing or all-NULL min/max query result raises a descriptive ValueError."""
    field = "priority"
    with pytest.raises(ValueError, match="Metadata field 'priority' has no values"):
        PgvectorDocumentStore._process_min_max_result(field, result)


def test_process_count_unique_metadata_result_returns_zero_dict_when_result_none():
    """A None query result yields a zero count for every requested field."""
    fields = ["category", "language"]
    counts = PgvectorDocumentStore._process_count_unique_metadata_result(None, fields)
    assert counts == {"category": 0, "language": 0}


def test_process_count_unique_metadata_result_uses_zero_for_missing_keys():
    """Fields absent from the query result default to a count of zero."""
    partial_result = {"category": 5}
    counts = PgvectorDocumentStore._process_count_unique_metadata_result(partial_result, ["category", "language"])
    assert counts == {"category": 5, "language": 0}


@pytest.mark.integration
def test_halfvec_hnsw_write_documents(document_store_w_halfvec_hnsw_index: PgvectorDocumentStore):
documents = [
Expand Down Expand Up @@ -295,6 +431,22 @@ def test_delete_table_first_call(document_store):
document_store.delete_table() # if throw error, test fails


@pytest.mark.integration
def test_delete_documents_empty_list_is_noop(document_store: PgvectorDocumentStore):
    """Deleting with an empty id list must leave the stored documents untouched."""
    document_store.write_documents(
        [
            Document(id="1", content="hello"),
            Document(id="2", content="world"),
        ]
    )
    snapshot = sorted(document_store.filter_documents(), key=lambda d: d.id)
    assert len(snapshot) == 2

    document_store.delete_documents([])

    assert sorted(document_store.filter_documents(), key=lambda d: d.id) == snapshot


@pytest.mark.integration
def test_update_by_filter_empty_meta_raises_error(document_store: PgvectorDocumentStore):
docs = [Document(content="Doc 1", meta={"category": "A"})]
Expand Down
63 changes: 62 additions & 1 deletion integrations/pgvector/tests/test_document_store_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
#
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch
from unittest.mock import Mock, patch

import psycopg
import pytest
from haystack.dataclasses.document import ByteStream, Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import Secret
from psycopg import AsyncConnection, Error
from psycopg.cursor_async import AsyncCursor
from psycopg.sql import SQL

from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
Expand Down Expand Up @@ -797,3 +799,62 @@ async def test_get_metadata_field_unique_values_async(document_store: PgvectorDo
)
assert set(unique_priorities_filtered) == {"1"}
assert total_priorities_filtered == 1


@pytest.mark.asyncio
async def test_connection_is_valid_async_returns_true_on_success():
    """An async connection whose execute() call succeeds is reported as valid."""
    conn = Mock(spec=AsyncConnection)
    conn.execute.return_value = None

    is_valid = await PgvectorDocumentStore._connection_is_valid_async(conn)

    assert is_valid is True


@pytest.mark.asyncio
async def test_connection_is_valid_async_returns_false_when_execute_raises():
    """A psycopg Error raised by the async execute() marks the connection as invalid."""
    conn = Mock(spec=AsyncConnection)
    conn.execute.side_effect = Error("connection dropped")

    is_valid = await PgvectorDocumentStore._connection_is_valid_async(conn)

    assert is_valid is False


@pytest.mark.asyncio
@pytest.mark.parametrize("cursor", [None, Mock(spec=AsyncCursor)], ids=["cursor_is_none", "connection_is_none"])
async def test_execute_sql_async_raises_when_not_initialized(mock_store, cursor):
    """_execute_sql_async fails fast if either the cursor or the async connection is missing."""
    # The bare mock_store fixture has no async connection established.
    assert mock_store._async_connection is None

    with pytest.raises(ValueError, match="cursor or the connection is not initialized"):
        await mock_store._execute_sql_async(cursor=cursor, sql_query=SQL("SELECT 1"))


@pytest.mark.asyncio
async def test_execute_sql_async_converts_psycopg_error_to_document_store_error(
    mock_store_with_mock_async_connection,
):
    """A psycopg Error during async execute triggers a rollback and is re-raised as DocumentStoreError."""
    store = mock_store_with_mock_async_connection
    cursor = Mock(spec=AsyncCursor)
    cause = Error("connection broken")
    cursor.execute.side_effect = cause

    with pytest.raises(DocumentStoreError, match="write failed") as excinfo:
        await store._execute_sql_async(cursor=cursor, sql_query=SQL("SELECT 1"), error_msg="write failed")

    # The failed statement must be rolled back (awaited) and the original error chained as the cause.
    store._async_connection.rollback.assert_awaited_once_with()
    assert excinfo.value.__cause__ is cause


@pytest.mark.asyncio
async def test_write_documents_async_rejects_non_document_items(mock_store):
    """write_documents_async only accepts Document instances, not arbitrary dicts."""
    bad_items = [{"not": "a document"}]
    with pytest.raises(ValueError, match="must contain a list of objects of type Document"):
        await mock_store.write_documents_async(bad_items)


@pytest.mark.asyncio
async def test_count_unique_metadata_by_filter_async_rejects_empty_fields(mock_store):
    """An empty metadata_fields list is rejected with a ValueError on the async path."""
    no_fields = []
    with pytest.raises(ValueError, match="metadata_fields must be a non-empty list"):
        await mock_store.count_unique_metadata_by_filter_async(filters={}, metadata_fields=no_fields)
Loading