diff --git a/haystack/components/retrievers/__init__.py b/haystack/components/retrievers/__init__.py
index 0eb2227822..92404414e5 100644
--- a/haystack/components/retrievers/__init__.py
+++ b/haystack/components/retrievers/__init__.py
@@ -11,6 +11,7 @@
     "auto_merging_retriever": ["AutoMergingRetriever"],
     "filter_retriever": ["FilterRetriever"],
     "in_memory": ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"],
+    "multi_filter_retriever": ["MultiFilterRetriever"],
     "multi_query_embedding_retriever": ["MultiQueryEmbeddingRetriever"],
     "multi_query_text_retriever": ["MultiQueryTextRetriever"],
     "sentence_window_retriever": ["SentenceWindowRetriever"],
@@ -21,6 +22,7 @@
     from .filter_retriever import FilterRetriever as FilterRetriever
     from .in_memory import InMemoryBM25Retriever as InMemoryBM25Retriever
     from .in_memory import InMemoryEmbeddingRetriever as InMemoryEmbeddingRetriever
+    from .multi_filter_retriever import MultiFilterRetriever as MultiFilterRetriever
     from .multi_query_embedding_retriever import MultiQueryEmbeddingRetriever as MultiQueryEmbeddingRetriever
     from .multi_query_text_retriever import MultiQueryTextRetriever as MultiQueryTextRetriever
     from .sentence_window_retriever import SentenceWindowRetriever as SentenceWindowRetriever
diff --git a/haystack/components/retrievers/multi_filter_retriever.py b/haystack/components/retrievers/multi_filter_retriever.py
new file mode 100644
index 0000000000..167e28d900
--- /dev/null
+++ b/haystack/components/retrievers/multi_filter_retriever.py
@@ -0,0 +1,98 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any
+
+from haystack import Document, component
+from haystack.components.retrievers.filter_retriever import FilterRetriever
+from haystack.utils.misc import _deduplicate_documents
+
+
+@component
+class MultiFilterRetriever:
+    """
+    A component that retrieves documents using multiple filters in parallel.
+
+    This component takes a list of filter dictionaries and uses a filter-capable retriever to retrieve matching
+    documents for each filter set in parallel.
+
+    ### Usage example
+
+    ```python
+    from haystack import Document
+    from haystack.components.retrievers import FilterRetriever, MultiFilterRetriever
+    from haystack.document_stores.in_memory import InMemoryDocumentStore
+    from haystack.components.writers import DocumentWriter
+    from haystack.document_stores.types import DuplicatePolicy
+
+    documents = [
+        Document(content="Python is a popular programming language", meta={"lang": "en"}),
+        Document(content="python ist eine beliebte Programmiersprache", meta={"lang": "de"}),
+    ]
+
+    document_store = InMemoryDocumentStore()
+    writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
+    writer.run(documents=documents)
+
+    filter_retriever = FilterRetriever(document_store=document_store)
+    multi_filter_retriever = MultiFilterRetriever(retriever=filter_retriever)
+
+    filters = [
+        {"field": "meta.lang", "operator": "==", "value": "en"},
+        {"field": "meta.lang", "operator": "==", "value": "de"},
+    ]
+
+    result = multi_filter_retriever.run(filters=filters)
+    for doc in result["documents"]:
+        print(doc.content)
+    ```
+    """
+
+    def __init__(self, *, retriever: FilterRetriever, max_workers: int = 3) -> None:
+        """
+        Initialize MultiFilterRetriever.
+
+        :param retriever: The filter-capable retriever to use for document retrieval.
+        :param max_workers: Maximum number of worker threads for parallel processing.
+        """
+        self.retriever = retriever
+        self.max_workers = max_workers
+
+    @component.output_types(documents=list[Document])
+    def run(self, filters: list[dict[str, Any]]) -> dict[str, list[Document]]:
+        """
+        Retrieve documents using multiple filters in parallel.
+
+        :param filters: List of filter dictionaries to process. Each one is passed to the retriever separately.
+        :returns:
+            A dictionary containing:
+                - `documents`: De-duplicated list of the documents retrieved for all filters.
+        """
+        docs: list[Document] = []
+
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            # `executor.map` yields the per-filter results in the same order as `filters`.
+            filters_results = executor.map(self._run_on_thread, filters)
+            for result in filters_results:
+                if not result:
+                    continue
+                docs.extend(result)
+
+        docs = _deduplicate_documents(docs)
+
+        return {"documents": docs}
+
+    def _run_on_thread(self, filters: dict[str, Any]) -> list[Document] | None:
+        """
+        Process a single filter set on a separate thread.
+
+        :param filters: The filter dictionary to process.
+        :returns:
+            The retriever's `documents` output, or `None` if the retriever returned no `documents`.
+        """
+        result = self.retriever.run(filters=filters)
+        if result and "documents" in result:
+            return result["documents"]
+        return None
diff --git a/releasenotes/notes/feat-multi-filter-retriever-266283c61e693da3.yaml b/releasenotes/notes/feat-multi-filter-retriever-266283c61e693da3.yaml
new file mode 100644
index 0000000000..12ab0c2521
--- /dev/null
+++ b/releasenotes/notes/feat-multi-filter-retriever-266283c61e693da3.yaml
@@ -0,0 +1,20 @@
+---
+features:
+  - |
+    Add support for ``MultiFilterRetriever``, a new retriever component that executes multiple filter
+    queries against a document store **in parallel** and returns a single, de-duplicated list of
+    documents.
+
+    .. code-block:: python
+
+        filter_retriever = FilterRetriever(document_store=document_store)
+        multi_filter_retriever = MultiFilterRetriever(retriever=filter_retriever)
+
+        filters = [
+            {"field": "meta.lang", "operator": "==", "value": "en"},
+            {"field": "meta.lang", "operator": "==", "value": "de"},
+        ]
+
+        result = multi_filter_retriever.run(filters=filters)
+        for doc in result["documents"]:
+            print(doc.content)
diff --git a/test/components/retrievers/test_multi_filter_retriever.py b/test/components/retrievers/test_multi_filter_retriever.py
new file mode 100644
index 0000000000..dc43290c2f
--- /dev/null
+++ b/test/components/retrievers/test_multi_filter_retriever.py
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+import pytest
+
+from haystack import Document, component
+from haystack.components.retrievers.filter_retriever import FilterRetriever
+from haystack.components.retrievers.multi_filter_retriever import MultiFilterRetriever
+from haystack.components.writers import DocumentWriter
+from haystack.core.serialization import component_from_dict, component_to_dict
+from haystack.document_stores.in_memory import InMemoryDocumentStore
+from haystack.document_stores.types import DuplicatePolicy
+
+
+@pytest.fixture
+def sample_documents() -> list[Document]:
+    return [
+        Document(content="English text", id="doc1", meta={"lang": "en"}),
+        Document(content="German text", id="doc2", meta={"lang": "de"}),
+    ]
+
+
+@pytest.fixture
+def sample_document_store(sample_documents: list[Document]) -> InMemoryDocumentStore:
+    document_store = InMemoryDocumentStore()
+    DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP).run(documents=sample_documents)
+    return document_store
+
+
+@pytest.fixture
+def sample_filters() -> list[dict[str, Any]]:
+    return [
+        {"field": "meta.lang", "operator": "==", "value": "en"},
+        {"field": "meta.lang", "operator": "==", "value": "de"},
+    ]
+
+
+class TestMultiFilterRetriever:
+    def test_init_default(self, in_memory_doc_store: InMemoryDocumentStore) -> None:
+        retriever = FilterRetriever(document_store=in_memory_doc_store)
+        multi = MultiFilterRetriever(retriever=retriever)
+
+        assert multi.retriever == retriever
+        assert multi.max_workers == 3
+
+    def test_init_with_parameters(self, in_memory_doc_store: InMemoryDocumentStore) -> None:
+        retriever = FilterRetriever(document_store=in_memory_doc_store)
+        multi = MultiFilterRetriever(retriever=retriever, max_workers=2)
+
+        assert multi.max_workers == 2
+
+    def test_run_empty_filters(self, in_memory_doc_store: InMemoryDocumentStore) -> None:
+        multi = MultiFilterRetriever(retriever=FilterRetriever(document_store=in_memory_doc_store))
+
+        result = multi.run(filters=[])
+
+        assert result == {"documents": []}
+
+    def test_run_multiple_filters(
+        self, sample_document_store: InMemoryDocumentStore, sample_filters: list[dict[str, Any]]
+    ) -> None:
+        multi = MultiFilterRetriever(retriever=FilterRetriever(document_store=sample_document_store))
+
+        result = multi.run(filters=sample_filters)
+
+        assert "documents" in result
+        assert len(result["documents"]) == 2
+        assert {doc.meta["lang"] for doc in result["documents"]} == {"en", "de"}
+
+    def test_run_single_filter(self, sample_document_store: InMemoryDocumentStore) -> None:
+        multi = MultiFilterRetriever(retriever=FilterRetriever(document_store=sample_document_store))
+
+        result = multi.run(filters=[{"field": "meta.lang", "operator": "==", "value": "en"}])
+
+        assert "documents" in result
+        assert len(result["documents"]) == 1
+        assert result["documents"][0].meta["lang"] == "en"
+
+    def test_deduplication(self) -> None:
+        doc1 = Document(content="A", id="doc1", score=0.9)
+        doc2 = Document(content="B", id="doc2", score=0.8)
+        doc3 = Document(content="A", id="doc1", score=0.7)
+
+        @component
+        class MockRetriever:
+            @component.output_types(documents=list[Document])
+            def run(self, filters: dict[str, Any] | None = None, **kwargs: Any) -> dict[str, list[Document]]:
+                return {"documents": [doc1, doc2, doc3]}
+
+        multi = MultiFilterRetriever(retriever=MockRetriever(), max_workers=1)
+
+        result = multi.run(filters=[{}, {}])
+
+        assert len(result["documents"]) == 2
+        assert {doc.id for doc in result["documents"]} == {"doc1", "doc2"}
+
+    def test_to_dict(self, in_memory_doc_store: InMemoryDocumentStore) -> None:
+        retriever = FilterRetriever(document_store=in_memory_doc_store)
+        multi = MultiFilterRetriever(retriever=retriever, max_workers=2)
+
+        data = component_to_dict(multi, "multi_filter")
+
+        assert data["type"] == "haystack.components.retrievers.multi_filter_retriever.MultiFilterRetriever"
+        assert data["init_parameters"]["max_workers"] == 2
+
+    def test_from_dict(self, in_memory_doc_store: InMemoryDocumentStore) -> None:
+        retriever = FilterRetriever(document_store=in_memory_doc_store)
+        multi = MultiFilterRetriever(retriever=retriever, max_workers=2)
+
+        serialized = component_to_dict(multi, "multi_filter")
+        deserialized = component_from_dict(MultiFilterRetriever, serialized, "multi_filter")
+
+        assert isinstance(deserialized, MultiFilterRetriever)
+        assert deserialized.max_workers == 2