diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/open_search_hybrid_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/open_search_hybrid_retriever.py index 0c7eacc290..0a812446aa 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/open_search_hybrid_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/open_search_hybrid_retriever.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any -from haystack import Document, Pipeline, default_from_dict, default_to_dict, logging, super_component +from haystack import AsyncPipeline, Document, default_from_dict, default_to_dict, logging, super_component from haystack.components.embedders.types import TextEmbedder from haystack.components.joiners import DocumentJoiner from haystack.components.joiners.document_joiner import JoinMode @@ -258,7 +258,7 @@ def run( """Run the hybrid retrieval pipeline and return retrieved documents.""" ... - def _create_pipeline(self, data: dict[str, Any]) -> Pipeline: + def _create_pipeline(self, data: dict[str, Any]) -> AsyncPipeline: """ Create the pipeline for the OpenSearchHybridRetriever. """ @@ -266,7 +266,7 @@ def _create_pipeline(self, data: dict[str, Any]) -> Pipeline: bm25_retriever = OpenSearchBM25Retriever(**data["bm25_retriever"]) document_joiner = DocumentJoiner(**data["document_joiner"]) - hybrid_retrieval = Pipeline() + hybrid_retrieval = AsyncPipeline() hybrid_retrieval.add_component("text_embedder", self.embedder) hybrid_retrieval.add_component("embedding_retriever", embedding_retriever) hybrid_retrieval.add_component("bm25_retriever", bm25_retriever) diff --git a/integrations/opensearch/tests/test_open_search_hybrid_retriever.py b/integrations/opensearch/tests/test_open_search_hybrid_retriever.py index de5a9310fb..70600e5d61 100644 --- a/integrations/opensearch/tests/test_open_search_hybrid_retriever.py +++ b/integrations/opensearch/tests/test_open_search_hybrid_retriever.py @@ -7,7 +7,7 @@ from unittest.mock import Mock import pytest -from haystack import Document, Pipeline +from haystack import AsyncPipeline, Document, Pipeline from haystack.components.embedders import SentenceTransformersTextEmbedder from haystack.core.component import component @@ -21,6 +21,10 @@ class MockedTextEmbedder: def run(self, text: str, param_a: str = "default", param_b: str = "another_default") -> dict[str, Any]: return {"embedding": [0.1, 0.2, 0.3], "metadata": {"text": text, "param_a": param_a, "param_b": param_b}} + @component.output_types(embedding=list[float]) + async def run_async(self, text: str, param_a: str = "default", param_b: str = "another_default") -> dict[str, Any]: + return {"embedding": [0.1, 0.2, 0.3], "metadata": {"text": text, "param_a": param_a, "param_b": param_b}} + class TestOpenSearchHybridRetriever: serialised = { # noqa: RUF012 @@ -153,8 +157,8 @@ def test_from_dict_without_optional_keys(self): def test_run(self, mock_embedder): # mocked document store mock_store = Mock(spec=OpenSearchDocumentStore) - mock_store._bm25_retrieval.return_value = [Document(content="Test doc BM25")] - mock_store._embedding_retrieval.return_value = [Document(content="Test doc Embedding")] + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc BM25")] + mock_store._embedding_retrieval_async.return_value = [Document(content="Test doc Embedding")] # use the mocked embedder retriever = OpenSearchHybridRetriever(document_store=mock_store, embedder=mock_embedder) @@ -168,8 +172,8 @@ def test_run(self, mock_embedder): def test_run_with_extra_arg(self, mock_embedder): # mocked document store mock_store = Mock(spec=OpenSearchDocumentStore) - mock_store._bm25_retrieval.return_value = [Document(content="Test doc BM25")] - mock_store._embedding_retrieval.return_value = [Document(content="Test doc Embedding")] + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc BM25")] + mock_store._embedding_retrieval_async.return_value = [Document(content="Test doc Embedding")] # use the mocked embedder retriever = OpenSearchHybridRetriever( @@ -181,8 +185,8 @@ def test_run_with_extra_arg(self, mock_embedder): result = retriever.run(query="test query") # Verify the retrievers were called with the extra arguments - mock_store._bm25_retrieval.assert_called_once() - mock_store._embedding_retrieval.assert_called_once() + mock_store._bm25_retrieval_async.assert_called_once() + mock_store._embedding_retrieval_async.assert_called_once() # Verify the results assert len(result) == 1 @@ -193,8 +197,8 @@ def test_run_with_extra_arg(self, mock_embedder): def test_run_with_extra_arg_invalid_param(self, mock_embedder): # mocked document store mock_store = Mock(spec=OpenSearchDocumentStore) - mock_store._bm25_retrieval.return_value = [Document(content="Test doc BM25")] - mock_store._embedding_retrieval.return_value = [Document(content="Test doc Embedding")] + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc BM25")] + mock_store._embedding_retrieval_async.return_value = [Document(content="Test doc Embedding")] with pytest.raises( ValueError, match=r"valid extra args are only: 'bm25_retriever' and 'embedding_retriever'\." @@ -209,8 +213,8 @@ def test_run_with_extra_arg_invalid_param(self, mock_embedder): def test_run_with_extra_runtime_params(self, mock_embedder): # mocked document store mock_store = Mock(spec=OpenSearchDocumentStore) - mock_store._bm25_retrieval.return_value = [Document(content="Test doc BM25")] - mock_store._embedding_retrieval.return_value = [Document(content="Test doc Embedding")] + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc BM25")] + mock_store._embedding_retrieval_async.return_value = [Document(content="Test doc Embedding")] # use the mocked embedder retriever = OpenSearchHybridRetriever(document_store=mock_store, embedder=mock_embedder) @@ -222,7 +226,7 @@ def test_run_with_extra_runtime_params(self, mock_embedder): top_k_embedding=1, ) - mock_store._bm25_retrieval.assert_called_once_with( + mock_store._bm25_retrieval_async.assert_called_once_with( query="test query", filters={"key": "value"}, top_k=1, @@ -231,7 +235,7 @@ def test_run_with_extra_runtime_params(self, mock_embedder): scale_score=False, custom_query=None, ) - mock_store._embedding_retrieval.assert_called_once_with( + mock_store._embedding_retrieval_async.assert_called_once_with( query_embedding=[0.1, 0.2, 0.3], filters={"key": "value"}, top_k=1, @@ -244,8 +248,8 @@ def test_run_in_pipeline(self, mock_embedder): # mocked document store pipeline = Pipeline() mock_store = Mock(spec=OpenSearchDocumentStore) - mock_store._bm25_retrieval.return_value = [Document(content="Test doc BM25")] - mock_store._embedding_retrieval.return_value = [Document(content="Test doc Embedding")] + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc BM25")] + mock_store._embedding_retrieval_async.return_value = [Document(content="Test doc Embedding")] # use the mocked embedder retriever = OpenSearchHybridRetriever(document_store=mock_store, embedder=mock_embedder) @@ -255,7 +259,106 @@ def test_run_in_pipeline(self, mock_embedder): # Should not fail _ = pipeline.run(data={"retriever": {"query": "test query", "filters_bm25": {"param_a": "default"}}}) - mock_store._bm25_retrieval.assert_called_once_with( + mock_store._bm25_retrieval_async.assert_called_once_with( + query="test query", + filters={"param_a": "default"}, + top_k=10, + all_terms_must_match=False, + fuzziness=0, + scale_score=False, + custom_query=None, + ) + mock_store._embedding_retrieval_async.assert_called_once_with( + query_embedding=[0.1, 0.2, 0.3], + filters={}, + top_k=10, + custom_query=None, + efficient_filtering=False, + search_kwargs=None, + ) + + @pytest.mark.asyncio + async def test_run_async(self, mock_embedder): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc BM25")] + mock_store._embedding_retrieval_async.return_value = [Document(content="Test doc Embedding")] + + retriever = OpenSearchHybridRetriever(document_store=mock_store, embedder=mock_embedder) + result = await retriever.run_async(query="test query") + + assert len(result) == 1 + assert len(result["documents"]) == 2 + assert any(doc.content == "Test doc BM25" for doc in result["documents"]) + assert any(doc.content == "Test doc Embedding" for doc in result["documents"]) + + @pytest.mark.asyncio + async def test_run_async_with_extra_arg(self, mock_embedder): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc BM25")] + mock_store._embedding_retrieval_async.return_value = [Document(content="Test doc Embedding")] + + retriever = OpenSearchHybridRetriever( + document_store=mock_store, + embedder=mock_embedder, + bm25_retriever={"raise_on_failure": True}, + embedding_retriever={"raise_on_failure": False}, + ) + result = await retriever.run_async(query="test query") + + mock_store._bm25_retrieval_async.assert_called_once() + mock_store._embedding_retrieval_async.assert_called_once() + + assert len(result) == 1 + assert len(result["documents"]) == 2 + assert any(doc.content == "Test doc BM25" for doc in result["documents"]) + assert any(doc.content == "Test doc Embedding" for doc in result["documents"]) + + @pytest.mark.asyncio + async def test_run_async_with_extra_runtime_params(self, mock_embedder): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc BM25")] + mock_store._embedding_retrieval_async.return_value = [Document(content="Test doc Embedding")] + + retriever = OpenSearchHybridRetriever(document_store=mock_store, embedder=mock_embedder) + await retriever.run_async( + query="test query", + filters_bm25={"key": "value"}, + filters_embedding={"key": "value"}, + top_k_bm25=1, + top_k_embedding=1, + ) + + mock_store._bm25_retrieval_async.assert_called_once_with( + query="test query", + filters={"key": "value"}, + top_k=1, + all_terms_must_match=False, + fuzziness=0, + scale_score=False, + custom_query=None, + ) + mock_store._embedding_retrieval_async.assert_called_once_with( + query_embedding=[0.1, 0.2, 0.3], + filters={"key": "value"}, + top_k=1, + custom_query=None, + efficient_filtering=False, + search_kwargs=None, + ) + + @pytest.mark.asyncio + async def test_run_async_in_pipeline(self, mock_embedder): + pipeline = AsyncPipeline() + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc BM25")] + mock_store._embedding_retrieval_async.return_value = [Document(content="Test doc Embedding")] + + retriever = OpenSearchHybridRetriever(document_store=mock_store, embedder=mock_embedder) + pipeline.add_component("retriever", retriever) + + await pipeline.run_async(data={"retriever": {"query": "test query", "filters_bm25": {"param_a": "default"}}}) + + mock_store._bm25_retrieval_async.assert_called_once_with( query="test query", filters={"param_a": "default"}, top_k=10, @@ -264,7 +367,7 @@ def test_run_in_pipeline(self, mock_embedder): scale_score=False, custom_query=None, ) - mock_store._embedding_retrieval.assert_called_once_with( + mock_store._embedding_retrieval_async.assert_called_once_with( query_embedding=[0.1, 0.2, 0.3], filters={}, top_k=10,