diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 27f2dc8202..83e148e45c 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -42,53 +42,7 @@ def clear_chroma_system_cache(): SharedSystemClient.clear_system_cache() -class TestDocumentStore( - CountDocumentsTest, - DeleteDocumentsTest, - FilterDocumentsTest, - FilterableDocsFixtureMixin, - UpdateByFilterTest, - DeleteAllTest, - DeleteByFilterTest, - CountDocumentsByFilterTest, - CountUniqueMetadataByFilterTest, - GetMetadataFieldsInfoTest, - GetMetadataFieldMinMaxTest, - GetMetadataFieldUniqueValuesTest, -): - """ - Common test cases will be provided by `DocumentStoreBaseTests` but - you can add more to this class. - """ - - @pytest.fixture - def document_store(self, embedding_function) -> ChromaDocumentStore: - """ - This is the most basic requirement for the child class: provide - an instance of this document store so the base class can use it. - """ - with mock.patch( - "haystack_integrations.document_stores.chroma.document_store.get_embedding_function" - ) as get_func: - get_func.return_value = embedding_function - return ChromaDocumentStore(embedding_function="test_function", collection_name=str(uuid.uuid1())) - - def assert_documents_are_equal(self, received: list[Document], expected: list[Document]): - """ - Assert that two lists of Documents are equal. - This is used in every test, if a Document Store implementation has a different behaviour - it should override this method. - - This can happen for example when the Document Store sets a score to returned Documents. - Since we can't know what the score will be, we can't compare the Documents reliably. - """ - received.sort(key=operator.attrgetter("id")) - expected.sort(key=operator.attrgetter("id")) - - for doc_received, doc_expected in zip(received, expected, strict=True): - assert doc_received.content == doc_expected.content - assert doc_received.meta == doc_expected.meta - +class TestDocumentStoreUnit: def test_init_in_memory(self): store = ChromaDocumentStore() @@ -122,17 +76,6 @@ def test_invalid_initialization_both_host_and_persist_path(self): store = ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost") store._ensure_initialized() - def test_client_settings_applied(self, clear_chroma_system_cache): - """ - Chroma's in-memory client uses a singleton pattern with an internal cache. - Once a client is created with certain settings, Chroma rejects creating another - with different settings in the same process. We clear the cache before and after - this test to avoid conflicts with other tests that use default settings. - """ - store = ChromaDocumentStore(client_settings={"anonymized_telemetry": False}) - store._ensure_initialized() - assert store._client.get_settings().anonymized_telemetry is False - def test_to_dict(self, request): ds = ChromaDocumentStore( collection_name=request.node.name, @@ -182,6 +125,66 @@ def test_same_collection_name_reinitialization(self): ChromaDocumentStore("test_1") ChromaDocumentStore("test_1") + +@pytest.mark.integration +class TestDocumentStore( + CountDocumentsTest, + DeleteDocumentsTest, + FilterDocumentsTest, + FilterableDocsFixtureMixin, + UpdateByFilterTest, + DeleteAllTest, + DeleteByFilterTest, + CountDocumentsByFilterTest, + CountUniqueMetadataByFilterTest, + GetMetadataFieldsInfoTest, + GetMetadataFieldMinMaxTest, + GetMetadataFieldUniqueValuesTest, +): + """ + Common test cases will be provided by `DocumentStoreBaseTests` but + you can add more to this class. + """ + + @pytest.fixture + def document_store(self, embedding_function) -> ChromaDocumentStore: + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + with mock.patch( + "haystack_integrations.document_stores.chroma.document_store.get_embedding_function" + ) as get_func: + get_func.return_value = embedding_function + return ChromaDocumentStore(embedding_function="test_function", collection_name=str(uuid.uuid1())) + + def assert_documents_are_equal(self, received: list[Document], expected: list[Document]): + """ + Assert that two lists of Documents are equal. + This is used in every test, if a Document Store implementation has a different behaviour + it should override this method. + + This can happen for example when the Document Store sets a score to returned Documents. + Since we can't know what the score will be, we can't compare the Documents reliably. + """ + received.sort(key=operator.attrgetter("id")) + expected.sort(key=operator.attrgetter("id")) + + for doc_received, doc_expected in zip(received, expected, strict=True): + assert doc_received.content == doc_expected.content + assert doc_received.meta == doc_expected.meta + + def test_client_settings_applied(self, clear_chroma_system_cache): + """ + Chroma's in-memory client uses a singleton pattern with an internal cache. + Once a client is created with certain settings, Chroma rejects creating another + with different settings in the same process. We clear the cache before and after + this test to avoid conflicts with other tests that use default settings. + """ + store = ChromaDocumentStore(client_settings={"anonymized_telemetry": False}) + store._ensure_initialized() + assert store._client.get_settings().anonymized_telemetry is False + def test_distance_metric_initialization(self): store = ChromaDocumentStore("test_2", distance_function="cosine") store._ensure_initialized() @@ -445,7 +448,6 @@ def test_comparison_less_than_equal_with_none(self, document_store, filterable_d def test_not_operator(self, document_store, filterable_docs): pass - @pytest.mark.integration def test_search(self): document_store = ChromaDocumentStore() documents = [ @@ -491,7 +493,6 @@ def test_delete_all_documents_index_recreation(self, document_store: ChromaDocum document_store.write_documents(docs) assert document_store.count_documents() == 2 - @pytest.mark.integration def test_search_embeddings(self, document_store: ChromaDocumentStore): query_embedding = TEST_EMBEDDING_1 documents = [ @@ -515,6 +516,7 @@ def test_search_embeddings(self, document_store: ChromaDocumentStore): assert len(result_empty_filters[0]) == 2 +@pytest.mark.integration class TestMetadataOperations: """Test new metadata query operations for ChromaDocumentStore""" diff --git a/integrations/chroma/tests/test_document_store_async.py b/integrations/chroma/tests/test_document_store_async.py index 77559d52e4..ed285a6f5e 100644 --- a/integrations/chroma/tests/test_document_store_async.py +++ b/integrations/chroma/tests/test_document_store_async.py @@ -264,7 +264,6 @@ async def test_update_by_filter_async_no_matches(self, document_store: ChromaDoc assert updated_count == 0 assert await document_store.count_documents_async() == 2 - @pytest.mark.integration async def test_search_embeddings_async(self, document_store: ChromaDocumentStore): query_embedding = TEST_EMBEDDING_1 documents = [ diff --git a/integrations/elasticsearch/tests/test_document_store.py b/integrations/elasticsearch/tests/test_document_store.py index 59c93a0dce..993472dc11 100644 --- a/integrations/elasticsearch/tests/test_document_store.py +++ b/integrations/elasticsearch/tests/test_document_store.py @@ -232,6 +232,37 @@ def test_client_initialization_with_api_key_string(_mock_async_es, _mock_es): assert async_call_args[1]["api_key"] == "test_api_key" +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") +def test_init_with_custom_mapping(mock_elasticsearch): + custom_mapping = { + "properties": { + "embedding": {"type": "dense_vector", "index": True, "similarity": "dot_product"}, + "content": {"type": "text"}, + }, + "dynamic_templates": [ + { + "strings": { + "path_match": "*", + "match_mapping_type": "string", + "mapping": { + "type": "keyword", + }, + } + } + ], + } + mock_client = Mock( + indices=Mock(create=Mock(), exists=Mock(return_value=False)), + ) + mock_elasticsearch.return_value = mock_client + + _ = ElasticsearchDocumentStore(hosts="http://testhost:9200", custom_mapping=custom_mapping).client + mock_client.indices.create.assert_called_once_with( + index="default", + mappings=custom_mapping, + ) + + @pytest.mark.integration class TestDocumentStore( DocumentStoreBaseExtendedTests, @@ -476,36 +507,6 @@ def test_write_documents_different_embedding_sizes_fail(self, document_store: El with pytest.raises(DocumentStoreError): document_store.write_documents(docs) - @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") - def test_init_with_custom_mapping(self, mock_elasticsearch): - custom_mapping = { - "properties": { - "embedding": {"type": "dense_vector", "index": True, "similarity": "dot_product"}, - "content": {"type": "text"}, - }, - "dynamic_templates": [ - { - "strings": { - "path_match": "*", - "match_mapping_type": "string", - "mapping": { - "type": "keyword", - }, - } - } - ], - } - mock_client = Mock( - indices=Mock(create=Mock(), exists=Mock(return_value=False)), - ) - mock_elasticsearch.return_value = mock_client - - _ = ElasticsearchDocumentStore(hosts="http://testhost:9200", custom_mapping=custom_mapping).client - mock_client.indices.create.assert_called_once_with( - index="default", - mappings=custom_mapping, - ) - def test_delete_all_documents_index_recreation(self, document_store: ElasticsearchDocumentStore): # populate the index with some documents docs = [Document(id="1", content="A first document"), Document(id="2", content="Second document")] diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 8a98d9d64b..2c75bf3168 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -272,7 +272,6 @@ def test_get_metadata_field_unique_values(self, document_store: MongoDBAtlasDocu assert len(values_page) == 1 assert values_page[0] in ["alpha", "beta", "gamma"] - @pytest.mark.integration def test_custom_embedding_field(self): """Test that the custom embedding field is correctly used in the document store.""" # Create a document store with a custom embedding field @@ -315,7 +314,6 @@ def test_custom_embedding_field(self): finally: database[collection_name].drop() - @pytest.mark.integration def test_custom_content_field(self): """Test that the custom content field is correctly used in the document store.""" # Create a document store with a custom content field diff --git a/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py b/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py index 9e8a60229b..cf390aeed5 100644 --- a/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py @@ -4,7 +4,7 @@ import os from time import sleep -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from haystack import Document @@ -13,9 +13,11 @@ from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore -def get_document_store(**kwargs): +def get_document_store(connection_string=None, **kwargs): + if connection_string is None: + connection_string = Secret.from_env_var("MONGO_CONNECTION_STRING_2") return MongoDBAtlasDocumentStore( - mongo_connection_string=Secret.from_env_var("MONGO_CONNECTION_STRING_2"), + mongo_connection_string=connection_string, database_name="haystack_test", collection_name="test_collection", vector_search_index="cosine_index", @@ -24,36 +26,12 @@ def get_document_store(**kwargs): ) -@pytest.mark.skipif( - not os.environ.get("MONGO_CONNECTION_STRING_2"), - reason="No MongoDB Atlas connection string provided", +@patch( + "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore._ensure_connection_setup" ) -@pytest.mark.integration -class TestFullTextRetrieval: - @pytest.fixture(scope="class") - def document_store(self) -> MongoDBAtlasDocumentStore: - return get_document_store() - - @pytest.fixture(autouse=True, scope="class") - def setup_teardown(self, document_store): - document_store._ensure_connection_setup() - document_store._collection.delete_many({}) - document_store.write_documents( - [ - Document(content="The quick brown fox chased the dog", meta={"meta_field": "right_value"}), - Document(content="The fox was brown", meta={"meta_field": "right_value"}), - Document(content="The lazy dog"), - Document(content="fox fox fox"), - ] - ) - - # Wait for documents to be indexed - sleep(5) - - yield - - def test_pipeline_correctly_passes_parameters(self, document_store): - document_store = get_document_store() +class TestFullTextRetrievalUnit: + def test_pipeline_correctly_passes_parameters(self, _mock_setup): + document_store = get_document_store(connection_string=Secret.from_token("test")) mock_collection = MagicMock() document_store._collection = mock_collection mock_collection.aggregate.return_value = [] @@ -98,9 +76,9 @@ def test_pipeline_correctly_passes_parameters(self, document_store): # Explicitly verify that the path in the text search is using the content_field assert actual_pipeline[0]["$search"]["compound"]["must"][0]["text"]["path"] == document_store.content_field - def test_pipeline_with_custom_content_field(self, document_store): + def test_pipeline_with_custom_content_field(self, _mock_setup): # Create a document store with a custom content field - document_store = get_document_store(content_field="custom_text") + document_store = get_document_store(connection_string=Secret.from_token("test"), content_field="custom_text") mock_collection = MagicMock() document_store._collection = mock_collection mock_collection.aggregate.return_value = [] @@ -125,6 +103,35 @@ def test_pipeline_with_custom_content_field(self, document_store): assert "$addFields" in actual_pipeline[3] assert "$project" in actual_pipeline[4] + +@pytest.mark.skipif( + not os.environ.get("MONGO_CONNECTION_STRING_2"), + reason="No MongoDB Atlas connection string provided", +) +@pytest.mark.integration +class TestFullTextRetrieval: + @pytest.fixture(scope="class") + def document_store(self) -> MongoDBAtlasDocumentStore: + return get_document_store() + + @pytest.fixture(autouse=True, scope="class") + def setup_teardown(self, document_store): + document_store._ensure_connection_setup() + document_store._collection.delete_many({}) + document_store.write_documents( + [ + Document(content="The quick brown fox chased the dog", meta={"meta_field": "right_value"}), + Document(content="The fox was brown", meta={"meta_field": "right_value"}), + Document(content="The lazy dog"), + Document(content="fox fox fox"), + ] + ) + + # Wait for documents to be indexed + sleep(5) + + yield + def test_query_retrieval(self, document_store: MongoDBAtlasDocumentStore): results = document_store._fulltext_retrieval(query="fox", top_k=2) assert len(results) == 2 diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index 28a8f277f9..7c6200d4ba 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -857,7 +857,6 @@ def test_get_metadata_field_unique_values(self, document_store: OpenSearchDocume unique_priorities_filtered, _ = document_store.get_metadata_field_unique_values("meta.priority", "Doc 1", 10) assert set(unique_priorities_filtered) == {"1"} - @pytest.mark.integration def test_write_with_routing(self, document_store: OpenSearchDocumentStore): """Test writing documents with routing metadata""" docs = [ @@ -884,7 +883,6 @@ def test_write_with_routing(self, document_store: OpenSearchDocumentStore): assert retrieved_by_id["3"].meta == {} - @pytest.mark.integration def test_delete_with_routing(self, document_store: OpenSearchDocumentStore): """Test deleting documents with routing""" docs = [ @@ -899,7 +897,6 @@ def test_delete_with_routing(self, document_store: OpenSearchDocumentStore): assert document_store.count_documents() == 1 - @pytest.mark.integration def test_metadata_search_fuzzy_mode(self, document_store: OpenSearchDocumentStore): """Test metadata search in fuzzy mode.""" docs = [ @@ -928,7 +925,6 @@ def test_metadata_search_fuzzy_mode(self, document_store: OpenSearchDocumentStor categories = [row.get("category", "").lower() for row in result] assert any("python" in cat for cat in categories) - @pytest.mark.integration def test_metadata_search_strict_mode(self, document_store: OpenSearchDocumentStore): """Test metadata search in strict mode.""" docs = [ @@ -951,7 +947,6 @@ def test_metadata_search_strict_mode(self, document_store: OpenSearchDocumentSto assert all(isinstance(row, dict) for row in result) assert all("category" in row for row in result) - @pytest.mark.integration def test_metadata_search_multiple_fields(self, document_store: OpenSearchDocumentStore): """Test metadata search across multiple fields.""" docs = [ @@ -976,7 +971,6 @@ def test_metadata_search_multiple_fields(self, document_store: OpenSearchDocumen for row in result: assert all(key in ["category", "status"] for key in row.keys()) - @pytest.mark.integration def test_metadata_search_comma_separated_query(self, document_store: OpenSearchDocumentStore): """Test metadata search with comma-separated query parts.""" docs = [ @@ -998,7 +992,6 @@ def test_metadata_search_comma_separated_query(self, document_store: OpenSearchD assert len(result) > 0 assert all(isinstance(row, dict) for row in result) - @pytest.mark.integration def test_metadata_search_top_k(self, document_store: OpenSearchDocumentStore): """Test metadata search respects top_k parameter.""" docs = [Document(content=f"Doc {i}", meta={"category": "Python", "index": i}) for i in range(15)] @@ -1015,7 +1008,6 @@ def test_metadata_search_top_k(self, document_store: OpenSearchDocumentStore): assert isinstance(result, list) assert len(result) <= 5 - @pytest.mark.integration def test_metadata_search_with_filters(self, document_store: OpenSearchDocumentStore): """Test metadata search with additional filters.""" docs = [ @@ -1039,7 +1031,6 @@ def test_metadata_search_with_filters(self, document_store: OpenSearchDocumentSt # Should only return documents with priority == 1 assert len(result) >= 1 - @pytest.mark.integration def test_metadata_search_empty_fields(self, document_store: OpenSearchDocumentStore): """Test metadata search with empty fields list returns empty result.""" docs = [ @@ -1057,7 +1048,6 @@ def test_metadata_search_empty_fields(self, document_store: OpenSearchDocumentSt assert isinstance(result, list) assert len(result) == 0 - @pytest.mark.integration def test_metadata_search_deduplication(self, document_store: OpenSearchDocumentStore): """Test that metadata search deduplicates results.""" docs = [ diff --git a/integrations/opensearch/tests/test_document_store_async.py b/integrations/opensearch/tests/test_document_store_async.py index c3df531a63..e413c06fa4 100644 --- a/integrations/opensearch/tests/test_document_store_async.py +++ b/integrations/opensearch/tests/test_document_store_async.py @@ -674,7 +674,6 @@ async def test_get_metadata_field_unique_values_async(self, document_store: Open ) assert set(unique_priorities_filtered) == {"1"} - @pytest.mark.integration @pytest.mark.asyncio async def test_metadata_search_async_fuzzy_mode(self, document_store: OpenSearchDocumentStore): """Test async metadata search in fuzzy mode.""" @@ -698,7 +697,6 @@ async def test_metadata_search_async_fuzzy_mode(self, document_store: OpenSearch assert all(isinstance(row, dict) for row in result) assert all("category" in row for row in result) - @pytest.mark.integration @pytest.mark.asyncio async def test_metadata_search_async_strict_mode(self, document_store: OpenSearchDocumentStore): """Test async metadata search in strict mode.""" @@ -721,7 +719,6 @@ async def test_metadata_search_async_strict_mode(self, document_store: OpenSearc assert all(isinstance(row, dict) for row in result) assert all("category" in row for row in result) - @pytest.mark.integration @pytest.mark.asyncio async def test_metadata_search_async_multiple_fields(self, document_store: OpenSearchDocumentStore): """Test async metadata search across multiple fields.""" @@ -746,7 +743,6 @@ async def test_metadata_search_async_multiple_fields(self, document_store: OpenS for row in result: assert all(key in ["category", "status"] for key in row.keys()) - @pytest.mark.integration @pytest.mark.asyncio async def test_metadata_search_async_top_k(self, document_store: OpenSearchDocumentStore): """Test async metadata search respects top_k parameter.""" @@ -764,7 +760,6 @@ async def test_metadata_search_async_top_k(self, document_store: OpenSearchDocum assert isinstance(result, list) assert len(result) <= 5 - @pytest.mark.integration @pytest.mark.asyncio async def test_metadata_search_async_comma_separated_query(self, document_store: OpenSearchDocumentStore): """Test async metadata search with comma-separated query parts.""" @@ -787,7 +782,6 @@ async def test_metadata_search_async_comma_separated_query(self, document_store: assert len(result) > 0 assert all(isinstance(row, dict) for row in result) - @pytest.mark.integration @pytest.mark.asyncio async def test_metadata_search_async_with_filters(self, document_store: OpenSearchDocumentStore): """Test async metadata search with additional filters.""" @@ -812,7 +806,6 @@ async def test_metadata_search_async_with_filters(self, document_store: OpenSear # Should only return documents with priority == 1 assert len(result) >= 1 - @pytest.mark.integration @pytest.mark.asyncio async def test_metadata_search_async_empty_fields(self, document_store: OpenSearchDocumentStore): """Test async metadata search with empty fields list returns empty result.""" @@ -831,7 +824,6 @@ async def test_metadata_search_async_empty_fields(self, document_store: OpenSear assert isinstance(result, list) assert len(result) == 0 - @pytest.mark.integration @pytest.mark.asyncio async def test_metadata_search_async_deduplication(self, document_store: OpenSearchDocumentStore): """Test that async metadata search deduplicates results.""" @@ -890,7 +882,6 @@ async def test_query_sql(self, document_store: OpenSearchDocumentStore): with pytest.raises(DocumentStoreError, match="Failed to execute SQL query"): await document_store._query_sql_async(invalid_query) - @pytest.mark.integration @pytest.mark.asyncio async def test_query_sql_async_with_fetch_size(self, document_store: OpenSearchDocumentStore): """Test async SQL query with fetch_size parameter""" @@ -917,7 +908,6 @@ async def test_query_sql_async_with_fetch_size(self, document_store: OpenSearchD assert len(result["datarows"]) <= 5 assert result.get("cursor") is not None - @pytest.mark.integration @pytest.mark.asyncio async def test_query_sql_async_pagination_flow(self, document_store: OpenSearchDocumentStore): """Test async pagination flow with fetch_size""" diff --git a/integrations/valkey/tests/test_embedding_retriever.py b/integrations/valkey/tests/test_embedding_retriever.py index 8d37e1132a..e44db40c10 100644 --- a/integrations/valkey/tests/test_embedding_retriever.py +++ b/integrations/valkey/tests/test_embedding_retriever.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 +from unittest.mock import AsyncMock, Mock + import pytest from haystack import Document from haystack.document_stores.types import FilterPolicy @@ -24,6 +26,18 @@ def document_store(): store.close() +@pytest.fixture +def mock_store(): + """A ValkeyDocumentStore that doesn't connect to any server, for unit tests.""" + return ValkeyDocumentStore( + nodes_list=[("localhost", 6379)], + index_name="test_retriever", + embedding_dim=3, + distance_metric="cosine", + metadata_fields={"category": str, "priority": int}, + ) + + @pytest.fixture def sample_documents(): return [ @@ -33,34 +47,30 @@ def sample_documents(): ] -class TestValkeyEmbeddingRetriever: - @pytest.mark.integration - def test_init(self, document_store): - retriever = ValkeyEmbeddingRetriever(document_store=document_store) - assert retriever.document_store == document_store +class TestValkeyEmbeddingRetrieverUnit: + def test_init(self, mock_store): + retriever = ValkeyEmbeddingRetriever(document_store=mock_store) + assert retriever.document_store == mock_store assert retriever.filters == {} assert retriever.top_k == 10 assert retriever.filter_policy == FilterPolicy.REPLACE - @pytest.mark.integration - def test_init_with_parameters(self, document_store): + def test_init_with_parameters(self, mock_store): filters = {"field": "meta.category", "operator": "==", "value": "A"} retriever = ValkeyEmbeddingRetriever( - document_store=document_store, filters=filters, top_k=5, filter_policy=FilterPolicy.MERGE + document_store=mock_store, filters=filters, top_k=5, filter_policy=FilterPolicy.MERGE ) assert retriever.filters == filters assert retriever.top_k == 5 assert retriever.filter_policy == FilterPolicy.MERGE - @pytest.mark.integration def test_init_invalid_document_store(self): with pytest.raises(ValueError, match="document_store must be an instance of ValkeyDocumentStore"): ValkeyEmbeddingRetriever(document_store="not_a_store") - @pytest.mark.integration - def test_to_dict(self, document_store): + def test_to_dict(self, mock_store): retriever = ValkeyEmbeddingRetriever( - document_store=document_store, + document_store=mock_store, filters={"field": "meta.category", "operator": "==", "value": "A"}, top_k=5, ) @@ -72,12 +82,11 @@ def test_to_dict(self, document_store): assert result["init_parameters"]["filters"] == {"field": "meta.category", "operator": "==", "value": "A"} assert result["init_parameters"]["top_k"] == 5 - @pytest.mark.integration - def test_from_dict(self, document_store): + def test_from_dict(self, mock_store): data = { "type": "haystack_integrations.components.retrievers.valkey.embedding_retriever.ValkeyEmbeddingRetriever", "init_parameters": { - "document_store": document_store.to_dict(), + "document_store": mock_store.to_dict(), "filters": {"field": "meta.category", "operator": "==", "value": "A"}, "top_k": 5, "filter_policy": "replace", @@ -88,6 +97,25 @@ def test_from_dict(self, document_store): assert retriever.filters == {"field": "meta.category", "operator": "==", "value": "A"} assert retriever.top_k == 5 + def test_run(self, mock_store): + mock_store._embedding_retrieval = Mock(return_value=[Document(content="result", score=0.9)]) + retriever = ValkeyEmbeddingRetriever(document_store=mock_store, top_k=2) + result = retriever.run(query_embedding=[0.1, 0.2, 0.3]) + assert "documents" in result + assert len(result["documents"]) == 1 + mock_store._embedding_retrieval.assert_called_once() + + @pytest.mark.asyncio + async def test_run_async(self, mock_store): + mock_store._embedding_retrieval_async = AsyncMock(return_value=[Document(content="result", score=0.9)]) + retriever = ValkeyEmbeddingRetriever(document_store=mock_store, top_k=2) + result = await retriever.run_async(query_embedding=[0.1, 0.2, 0.3]) + assert "documents" in result + assert len(result["documents"]) == 1 + mock_store._embedding_retrieval_async.assert_called_once() + + +class TestValkeyEmbeddingRetriever: @pytest.mark.integration def test_run(self, document_store, sample_documents): document_store.write_documents(sample_documents)