|
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
5 | | -from unittest.mock import patch |
| 5 | +from unittest.mock import Mock, patch |
6 | 6 |
|
7 | 7 | import psycopg |
8 | 8 | import pytest |
|
24 | 24 | WriteDocumentsTest, |
25 | 25 | ) |
26 | 26 | from haystack.utils import Secret |
| 27 | +from psycopg import Connection, Cursor, Error |
| 28 | +from psycopg.sql import SQL |
27 | 29 |
|
28 | 30 | from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore |
29 | 31 |
|
@@ -101,6 +103,8 @@ def test_invalid_connection_string(self, monkeypatch): |
101 | 103 | assert "Failed to connect to PostgreSQL database" in str(e) |
102 | 104 |
|
103 | 105 |
|
| 106 | + |
| 107 | + |
104 | 108 | @pytest.mark.usefixtures("patches_for_unit_tests") |
105 | 109 | def test_init(monkeypatch): |
106 | 110 | monkeypatch.setenv("PG_CONN_STR", "some_connection_string") |
@@ -134,6 +138,16 @@ def test_init(monkeypatch): |
134 | 138 | assert document_store.keyword_index_name == "my_keyword_index" |
135 | 139 |
|
136 | 140 |
|
| 141 | +@pytest.mark.usefixtures("patches_for_unit_tests") |
| 142 | +def test_init_invalid_vector_type(): |
| 143 | + with pytest.raises(ValueError, match=r"vector_type must be one of.*"): |
| 144 | + PgvectorDocumentStore(vector_type="invalid") |
| 145 | + |
| 146 | +@pytest.mark.usefixtures("patches_for_unit_tests") |
| 147 | +def test_init_invalid_vector_function(): |
| 148 | + with pytest.raises(ValueError, match=r"vector_function must be one of.*"): |
| 149 | + PgvectorDocumentStore(vector_function="invalid") |
| 150 | + |
137 | 151 | @pytest.mark.usefixtures("patches_for_unit_tests") |
138 | 152 | def test_to_dict(monkeypatch): |
139 | 153 | monkeypatch.setenv("PG_CONN_STR", "some_connection_string") |
@@ -175,6 +189,130 @@ def test_to_dict(monkeypatch): |
175 | 189 | } |
176 | 190 |
|
177 | 191 |
|
| 192 | +def test_connection_is_valid_returns_true_on_success(): |
| 193 | + mock_connection = Mock(spec=Connection) |
| 194 | + mock_connection.execute.return_value = None |
| 195 | + |
| 196 | + assert PgvectorDocumentStore._connection_is_valid(mock_connection) is True |
| 197 | + |
| 198 | + |
| 199 | +def test_connection_is_valid_returns_false_when_execute_raises(): |
| 200 | + mock_connection = Mock(spec=Connection) |
| 201 | + mock_connection.execute.side_effect = Error("connection dropped") |
| 202 | + |
| 203 | + assert PgvectorDocumentStore._connection_is_valid(mock_connection) is False |
| 204 | + |
| 205 | + |
| 206 | +@pytest.mark.parametrize( |
| 207 | + "cursor", |
| 208 | + [None, Mock(spec=Cursor)], |
| 209 | + ids=["cursor_is_none", "connection_is_none"], |
| 210 | +) |
| 211 | +def test_execute_sql_raises_when_not_initialized(mock_store, cursor): |
| 212 | + assert mock_store._connection is None |
| 213 | + with pytest.raises(ValueError, match="cursor or the connection is not initialized"): |
| 214 | + mock_store._execute_sql(cursor=cursor, sql_query=SQL("SELECT 1")) |
| 215 | + |
| 216 | + |
| 217 | +def test_execute_sql_converts_psycopg_error_to_document_store_error(mock_store_with_mock_connection): |
| 218 | + mock_cursor = Mock(spec=Cursor) |
| 219 | + original_error = Error("connection broken") |
| 220 | + mock_cursor.execute.side_effect = original_error |
| 221 | + |
| 222 | + with pytest.raises(DocumentStoreError, match="write failed") as exc_info: |
| 223 | + mock_store_with_mock_connection._execute_sql( |
| 224 | + cursor=mock_cursor, |
| 225 | + sql_query=SQL("SELECT 1"), |
| 226 | + error_msg="write failed", |
| 227 | + ) |
| 228 | + |
| 229 | + mock_store_with_mock_connection._connection.rollback.assert_called_once_with() |
| 230 | + assert exc_info.value.__cause__ is original_error |
| 231 | + |
| 232 | + |
| 233 | +def test_write_documents_rejects_non_document_items(mock_store): |
| 234 | + with pytest.raises(ValueError, match="must contain a list of objects of type Document"): |
| 235 | + mock_store.write_documents([{"not": "a document"}]) |
| 236 | + |
| 237 | + |
| 238 | +def test_count_unique_metadata_by_filter_rejects_empty_fields(mock_store): |
| 239 | + with pytest.raises(ValueError, match="metadata_fields must be a non-empty list"): |
| 240 | + mock_store.count_unique_metadata_by_filter(filters={}, metadata_fields=[]) |
| 241 | + |
| 242 | + |
| 243 | +def test_check_and_build_embedding_retrieval_query_rejects_invalid_vector_function(mock_store): |
| 244 | + with pytest.raises(ValueError, match="vector_function must be one of"): |
| 245 | + mock_store._check_and_build_embedding_retrieval_query( |
| 246 | + query_embedding=[0.1] * mock_store.embedding_dimension, |
| 247 | + vector_function="invalid", |
| 248 | + top_k=5, |
| 249 | + ) |
| 250 | + |
| 251 | + |
| 252 | +@pytest.mark.parametrize( |
| 253 | + "bad_field", |
| 254 | + [ |
| 255 | + "field@invalid", |
| 256 | + "field with spaces", |
| 257 | + "field;drop", |
| 258 | + "meta.field!", |
| 259 | + "field' OR '1'='1", |
| 260 | + "../etc/passwd", |
| 261 | + "field/* comment */", |
| 262 | + ], |
| 263 | +) |
| 264 | +def test_normalize_metadata_field_name_rejects_invalid_chars(bad_field): |
| 265 | + with pytest.raises(ValueError, match="Invalid metadata field name"): |
| 266 | + PgvectorDocumentStore._normalize_metadata_field_name(bad_field) |
| 267 | + |
| 268 | + |
| 269 | +@pytest.mark.parametrize( |
| 270 | + "value, expected_type", |
| 271 | + [ |
| 272 | + (True, "boolean"), |
| 273 | + (42, "integer"), |
| 274 | + (3.14, "real"), |
| 275 | + ("hello", "text"), |
| 276 | + (["list", "fallback"], "text"), |
| 277 | + ], |
| 278 | +) |
| 279 | +def test_infer_metadata_field_type(value, expected_type): |
| 280 | + assert PgvectorDocumentStore._infer_metadata_field_type(value) == expected_type |
| 281 | + |
| 282 | + |
| 283 | +def test_analyze_metadata_fields_skips_non_dict_meta(): |
| 284 | + records = [{"meta": "not a dict"}, {"meta": None}] |
| 285 | + assert PgvectorDocumentStore._analyze_metadata_fields_from_records(records) == {"content": {"type": "text"}} |
| 286 | + |
| 287 | + |
| 288 | +def test_analyze_metadata_fields_defaults_null_first_value_to_text(): |
| 289 | + records = [{"meta": {"tag": None}}, {"meta": {"tag": 42}}] |
| 290 | + result = PgvectorDocumentStore._analyze_metadata_fields_from_records(records) |
| 291 | + assert result == {"content": {"type": "text"}, "tag": {"type": "text"}} |
| 292 | + |
| 293 | + |
| 294 | +@pytest.mark.parametrize( |
| 295 | + "result", |
| 296 | + [None, {"min_value": None, "max_value": None}], |
| 297 | + ids=["result_is_none", "values_are_none"], |
| 298 | +) |
| 299 | +def test_process_min_max_result_raises_when_no_values(result): |
| 300 | + with pytest.raises(ValueError, match="Metadata field 'priority' has no values"): |
| 301 | + PgvectorDocumentStore._process_min_max_result("priority", result) |
| 302 | + |
| 303 | + |
| 304 | +def test_process_count_unique_metadata_result_returns_zero_dict_when_result_none(): |
| 305 | + counts = PgvectorDocumentStore._process_count_unique_metadata_result(None, ["category", "language"]) |
| 306 | + assert counts == {"category": 0, "language": 0} |
| 307 | + |
| 308 | + |
| 309 | +def test_process_count_unique_metadata_result_uses_zero_for_missing_keys(): |
| 310 | + counts = PgvectorDocumentStore._process_count_unique_metadata_result( |
| 311 | + {"category": 5}, ["category", "language"] |
| 312 | + ) |
| 313 | + assert counts == {"category": 5, "language": 0} |
| 314 | + |
| 315 | + |
178 | 316 | @pytest.mark.integration |
179 | 317 | def test_halfvec_hnsw_write_documents(document_store_w_halfvec_hnsw_index: PgvectorDocumentStore): |
180 | 318 | documents = [ |
@@ -295,6 +433,22 @@ def test_delete_table_first_call(document_store): |
295 | 433 | document_store.delete_table() # if throw error, test fails |
296 | 434 |
|
297 | 435 |
|
| 436 | +@pytest.mark.integration |
| 437 | +def test_delete_documents_empty_list_is_noop(document_store: PgvectorDocumentStore): |
| 438 | + docs = [ |
| 439 | + Document(id="1", content="hello"), |
| 440 | + Document(id="2", content="world"), |
| 441 | + ] |
| 442 | + document_store.write_documents(docs) |
| 443 | + before = sorted(document_store.filter_documents(), key=lambda d: d.id) |
| 444 | + assert len(before) == 2 |
| 445 | + |
| 446 | + document_store.delete_documents([]) |
| 447 | + |
| 448 | + after = sorted(document_store.filter_documents(), key=lambda d: d.id) |
| 449 | + assert after == before |
| 450 | + |
| 451 | + |
298 | 452 | @pytest.mark.integration |
299 | 453 | def test_update_by_filter_empty_meta_raises_error(document_store: PgvectorDocumentStore): |
300 | 454 | docs = [Document(content="Doc 1", meta={"category": "A"})] |
|
0 commit comments