From adab3502a307aa62e5ab4b6cdd0e893a165746c8 Mon Sep 17 00:00:00 2001 From: Sriniketh J Date: Thu, 2 Apr 2026 18:04:05 +0530 Subject: [PATCH 1/7] feat: add support for multi filter retriever --- haystack/components/retrievers/__init__.py | 2 + .../retrievers/multi_filter_retriever.py | 140 ++++++++++++++++++ .../components/retrievers/types/__init__.py | 4 +- .../components/retrievers/types/protocol.py | 19 +++ .../retrievers/test_multi_filter_retriever.py | 114 ++++++++++++++ 5 files changed, 277 insertions(+), 2 deletions(-) create mode 100644 haystack/components/retrievers/multi_filter_retriever.py create mode 100644 test/components/retrievers/test_multi_filter_retriever.py diff --git a/haystack/components/retrievers/__init__.py b/haystack/components/retrievers/__init__.py index 0eb2227822..92404414e5 100644 --- a/haystack/components/retrievers/__init__.py +++ b/haystack/components/retrievers/__init__.py @@ -11,6 +11,7 @@ "auto_merging_retriever": ["AutoMergingRetriever"], "filter_retriever": ["FilterRetriever"], "in_memory": ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"], + "multi_filter_retriever": ["MultiFilterRetriever"], "multi_query_embedding_retriever": ["MultiQueryEmbeddingRetriever"], "multi_query_text_retriever": ["MultiQueryTextRetriever"], "sentence_window_retriever": ["SentenceWindowRetriever"], @@ -21,6 +22,7 @@ from .filter_retriever import FilterRetriever as FilterRetriever from .in_memory import InMemoryBM25Retriever as InMemoryBM25Retriever from .in_memory import InMemoryEmbeddingRetriever as InMemoryEmbeddingRetriever + from .multi_filter_retriever import MultiFilterRetriever as MultiFilterRetriever from .multi_query_embedding_retriever import MultiQueryEmbeddingRetriever as MultiQueryEmbeddingRetriever from .multi_query_text_retriever import MultiQueryTextRetriever as MultiQueryTextRetriever from .sentence_window_retriever import SentenceWindowRetriever as SentenceWindowRetriever diff --git a/haystack/components/retrievers/multi_filter_retriever.py b/haystack/components/retrievers/multi_filter_retriever.py new file mode 100644 index 0000000000..54ee7f0994 --- /dev/null +++ b/haystack/components/retrievers/multi_filter_retriever.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from concurrent.futures import ThreadPoolExecutor +from typing import Any + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.components.retrievers.types import FilterRetriever +from haystack.core.serialization import component_to_dict +from haystack.utils.misc import _deduplicate_documents + + +@component +class MultiFilterRetriever: + """ + A component that retrieves documents using multiple filters in parallel. + + This component takes a list of filter dictionaries and uses a filter-capable retriever to retrieve matching + documents for each filter set in parallel. The results are combined, de-duplicated, and sorted by score. + + ### Usage example + + ```python + from haystack import Document + from haystack.components.retrievers import FilterRetriever, MultiFilterRetriever + from haystack.document_stores.in_memory import InMemoryDocumentStore + from haystack.components.writers import DocumentWriter + from haystack.document_stores.types import DuplicatePolicy + + documents = [ + Document(content="Python is a popular programming language", meta={"lang": "en"}), + Document(content="python ist eine beliebte Programmiersprache", meta={"lang": "de"}), + ] + + document_store = InMemoryDocumentStore() + writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP) + writer.run(documents=documents) + + filter_retriever = FilterRetriever(document_store=document_store) + multi_filter_retriever = MultiFilterRetriever(retriever=filter_retriever) + + filters = [ + {"field": "meta.lang", "operator": "==", "value": "en"}, + {"field": "meta.lang", "operator": "==", "value": "de"}, + ] + + result = multi_filter_retriever.run(filters=filters) + for doc in result["documents"]: + print(doc.content) + ``` + """ + + def __init__(self, *, retriever: FilterRetriever, max_workers: int = 3) -> None: + """ + Initialize MultiFilterRetriever. + + :param retriever: The filter-capable retriever to use for document retrieval. + :param max_workers: Maximum number of worker threads for parallel processing. + """ + self.retriever = retriever + self.max_workers = max_workers + self._is_warmed_up = False + + def warm_up(self) -> None: + """ + Warm up the retriever if it has a warm_up method. + """ + if not self._is_warmed_up: + if hasattr(self.retriever, "warm_up") and callable(self.retriever.warm_up): + self.retriever.warm_up() + self._is_warmed_up = True + + @component.output_types(documents=list[Document]) + def run( + self, filters: list[dict[str, Any]], retriever_kwargs: dict[str, Any] | None = None + ) -> dict[str, list[Document]]: + """ + Retrieve documents using multiple filters in parallel. + + :param filters: List of filter dictionaries to process. + :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method. + :returns: + A dictionary containing: + - `documents`: List of retrieved documents sorted by relevance score. + """ + docs: list[Document] = [] + retriever_kwargs = retriever_kwargs or {} + + if not self._is_warmed_up: + self.warm_up() + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + filters_results = executor.map(lambda filt: self._run_on_thread(filt, retriever_kwargs), filters) + for result in filters_results: + if not result: + continue + docs.extend(result) + + docs = _deduplicate_documents(docs) + docs.sort(key=lambda x: x.score or 0.0, reverse=True) + return {"documents": docs} + + def _run_on_thread( + self, filters: dict[str, Any], retriever_kwargs: dict[str, Any] | None = None + ) -> list[Document] | None: + """ + Process a single filter set on a separate thread. + + :param filters: The filter dictionary to process. + :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method. + :returns: + List of retrieved documents or None if no results. + """ + result = self.retriever.run(filters=filters, **(retriever_kwargs or {})) + if result and "documents" in result: + return result["documents"] + return None + + def to_dict(self) -> dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + A dictionary representing the serialized component. + """ + return default_to_dict( + self, retriever=component_to_dict(obj=self.retriever, name="retriever"), max_workers=self.max_workers + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "MultiFilterRetriever": + """ + Deserializes the component from a dictionary. + + :param data: The dictionary to deserialize from. + :returns: + The deserialized component. + """ + return default_from_dict(cls, data) diff --git a/haystack/components/retrievers/types/__init__.py b/haystack/components/retrievers/types/__init__.py index 5afc417a7f..53ec088277 100644 --- a/haystack/components/retrievers/types/__init__.py +++ b/haystack/components/retrievers/types/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from .protocol import EmbeddingRetriever, TextRetriever +from .protocol import EmbeddingRetriever, FilterRetriever, TextRetriever -__all__ = ["TextRetriever", "EmbeddingRetriever"] +__all__ = ["TextRetriever", "EmbeddingRetriever", "FilterRetriever"] diff --git a/haystack/components/retrievers/types/protocol.py b/haystack/components/retrievers/types/protocol.py index fd2bb06664..ed709f266a 100644 --- a/haystack/components/retrievers/types/protocol.py +++ b/haystack/components/retrievers/types/protocol.py @@ -54,3 +54,22 @@ def run( `documents`: List of retrieved documents sorted by relevance score. """ ... + + +class FilterRetriever(Protocol): + """ + This protocol defines the minimal interface for retrievers that filter documents. + """ + + def run(self, filters: dict[str, Any] | None = None, **kwargs: Any) -> dict[str, Any]: + """ + Retrieve documents matching the provided filters. + + Implementing classes may accept additional optional parameters in their run method. + + :param filters: A dictionary of filters to apply when retrieving documents. + :returns: + A dictionary containing: + `documents`: List of retrieved documents. + """ + ... diff --git a/test/components/retrievers/test_multi_filter_retriever.py b/test/components/retrievers/test_multi_filter_retriever.py new file mode 100644 index 0000000000..c52c257152 --- /dev/null +++ b/test/components/retrievers/test_multi_filter_retriever.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import pytest + +from haystack import Document, component +from haystack.components.retrievers import FilterRetriever, MultiFilterRetriever +from haystack.components.writers import DocumentWriter +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack.document_stores.types import DuplicatePolicy + + +@pytest.fixture +def sample_documents() -> list[Document]: + return [ + Document(content="English text", id="doc1", meta={"lang": "en"}), + Document(content="German text", id="doc2", meta={"lang": "de"}), + ] + + +@pytest.fixture +def sample_document_store(sample_documents: list[Document]) -> InMemoryDocumentStore: + document_store = InMemoryDocumentStore() + DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP).run(documents=sample_documents) + return document_store + + +class TestMultiFilterRetriever: + def test_init_with_default_parameters(self, in_memory_doc_store: InMemoryDocumentStore) -> None: + filter_retriever = FilterRetriever(document_store=in_memory_doc_store) + multi_retriever = MultiFilterRetriever(retriever=filter_retriever) + + assert multi_retriever.retriever == filter_retriever + assert multi_retriever.max_workers == 3 + + def test_init_with_custom_parameters(self, in_memory_doc_store: InMemoryDocumentStore) -> None: + filter_retriever = FilterRetriever(document_store=in_memory_doc_store) + multi_retriever = MultiFilterRetriever(retriever=filter_retriever, max_workers=2) + + assert multi_retriever.retriever == filter_retriever + assert multi_retriever.max_workers == 2 + + def test_run_with_empty_filters(self, in_memory_doc_store: InMemoryDocumentStore) -> None: + multi_retriever = MultiFilterRetriever(retriever=FilterRetriever(document_store=in_memory_doc_store)) + result = multi_retriever.run(filters=[]) + + assert result == {"documents": []} + + @pytest.mark.parametrize( + ("filters", "expected_languages"), + [ + ( + [ + {"field": "meta.lang", "operator": "==", "value": "en"}, + {"field": "meta.lang", "operator": "==", "value": "de"}, + ], + {"en", "de"}, + ) + ], + ) + def test_run_with_multiple_filters( + self, sample_document_store: InMemoryDocumentStore, filters: list[dict[str, Any]], expected_languages: set[str] + ) -> None: + filter_retriever = FilterRetriever(document_store=sample_document_store) + multi_retriever = MultiFilterRetriever(retriever=filter_retriever) + + result = multi_retriever.run(filters=filters) + + assert "documents" in result + assert {doc.meta["lang"] for doc in result["documents"]} == expected_languages + + def test_deduplication_with_overlapping_results(self) -> None: + doc1 = Document(content="Solar energy is renewable", id="doc1", score=0.9) + doc2 = Document(content="Wind energy is clean", id="doc2", score=0.8) + doc3 = Document(content="Solar energy is renewable", id="doc1", score=0.7) + + call_count = 0 + + @component + class MockRetriever: + @component.output_types(documents=list[Document]) + def run(self, filters: dict[str, Any] | None = None, **kwargs: Any) -> dict[str, list[Document]]: + nonlocal call_count + call_count += 1 + if call_count == 1: + return {"documents": [doc1, doc2]} + return {"documents": [doc3, doc2]} + + multi_retriever = MultiFilterRetriever(retriever=MockRetriever(), max_workers=1) + + result = multi_retriever.run( + filters=[ + {"field": "meta.lang", "operator": "==", "value": "en"}, + {"field": "meta.lang", "operator": "==", "value": "de"}, + ] + ) + + assert "documents" in result + assert len(result["documents"]) == 2 + assert [doc.content for doc in result["documents"]].count("Solar energy is renewable") == 1 + assert [doc.content for doc in result["documents"]].count("Wind energy is clean") == 1 + + def test_from_dict_roundtrip(self, in_memory_doc_store: InMemoryDocumentStore) -> None: + filter_retriever = FilterRetriever(document_store=in_memory_doc_store) + multi_retriever = MultiFilterRetriever(retriever=filter_retriever, max_workers=2) + + serialized = multi_retriever.to_dict() + deserialized = MultiFilterRetriever.from_dict(serialized) + + assert isinstance(deserialized, MultiFilterRetriever) + assert deserialized.max_workers == 2 From 96caeea1da89d10c25bb341d59524a4c505359a0 Mon Sep 17 00:00:00 2001 From: Sriniketh J Date: Thu, 2 Apr 2026 18:08:38 +0530 Subject: [PATCH 2/7] docs: add rn section --- ...lti-filter-retriever-266283c61e693da3.yaml | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 releasenotes/notes/feat-multi-filter-retriever-266283c61e693da3.yaml diff --git a/releasenotes/notes/feat-multi-filter-retriever-266283c61e693da3.yaml b/releasenotes/notes/feat-multi-filter-retriever-266283c61e693da3.yaml new file mode 100644 index 0000000000..878155020b --- /dev/null +++ b/releasenotes/notes/feat-multi-filter-retriever-266283c61e693da3.yaml @@ -0,0 +1,20 @@ +--- +features: + - | + Add support for ``MultiFilterRetriever``, a new retriever component that executes multiple filter + queries against a document store **in parallel** and returns a single, de-duplicated list of + documents sorted by relevance score. + + .. code-block:: python + + filter_retriever = FilterRetriever(document_store=document_store) + multi_filter_retriever = MultiFilterRetriever(retriever=filter_retriever) + + filters = [ + {"field": "meta.lang", "operator": "==", "value": "en"}, + {"field": "meta.lang", "operator": "==", "value": "de"}, + ] + + result = multi_filter_retriever.run(filters=filters) + for doc in result["documents"]: + print(doc.content) From 26810270142607c19b573d587e7a6a877aa9a233 Mon Sep 17 00:00:00 2001 From: Sriniketh J Date: Thu, 2 Apr 2026 18:18:54 +0530 Subject: [PATCH 3/7] fix: update tests --- .../retrievers/test_multi_filter_retriever.py | 114 +++++++++--------- 1 file changed, 57 insertions(+), 57 deletions(-) diff --git a/test/components/retrievers/test_multi_filter_retriever.py b/test/components/retrievers/test_multi_filter_retriever.py index c52c257152..5b6f8b5ae9 100644 --- a/test/components/retrievers/test_multi_filter_retriever.py +++ b/test/components/retrievers/test_multi_filter_retriever.py @@ -1,5 +1,4 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# +# SPDX-FileCopyrightText: 2022-present deepset GmbH # SPDX-License-Identifier: Apache-2.0 from typing import Any @@ -28,86 +27,87 @@ def sample_document_store(sample_documents: list[Document]) -> InMemoryDocumentS return document_store +@pytest.fixture +def sample_filters() -> list[dict[str, Any]]: + return [ + {"field": "meta.lang", "operator": "==", "value": "en"}, + {"field": "meta.lang", "operator": "==", "value": "de"}, + ] + + class TestMultiFilterRetriever: - def test_init_with_default_parameters(self, in_memory_doc_store: InMemoryDocumentStore) -> None: - filter_retriever = FilterRetriever(document_store=in_memory_doc_store) - multi_retriever = MultiFilterRetriever(retriever=filter_retriever) + def test_init_default(self, in_memory_doc_store: InMemoryDocumentStore) -> None: + retriever = FilterRetriever(document_store=in_memory_doc_store) + multi = MultiFilterRetriever(retriever=retriever) + + assert multi.retriever == retriever + assert multi.max_workers == 3 - assert multi_retriever.retriever == filter_retriever - assert multi_retriever.max_workers == 3 + def test_init_with_parameters(self, in_memory_doc_store: InMemoryDocumentStore) -> None: + retriever = FilterRetriever(document_store=in_memory_doc_store) + multi = MultiFilterRetriever(retriever=retriever, max_workers=2) - def test_init_with_custom_parameters(self, in_memory_doc_store: InMemoryDocumentStore) -> None: - filter_retriever = FilterRetriever(document_store=in_memory_doc_store) - multi_retriever = MultiFilterRetriever(retriever=filter_retriever, max_workers=2) + assert multi.max_workers == 2 - assert multi_retriever.retriever == filter_retriever - assert multi_retriever.max_workers == 2 + def test_run_empty_filters(self, in_memory_doc_store: InMemoryDocumentStore) -> None: + multi = MultiFilterRetriever(retriever=FilterRetriever(document_store=in_memory_doc_store)) - def test_run_with_empty_filters(self, in_memory_doc_store: InMemoryDocumentStore) -> None: - multi_retriever = MultiFilterRetriever(retriever=FilterRetriever(document_store=in_memory_doc_store)) - result = multi_retriever.run(filters=[]) + result = multi.run(filters=[]) assert result == {"documents": []} - @pytest.mark.parametrize( - ("filters", "expected_languages"), - [ - ( - [ - {"field": "meta.lang", "operator": "==", "value": "en"}, - {"field": "meta.lang", "operator": "==", "value": "de"}, - ], - {"en", "de"}, - ) - ], - ) - def test_run_with_multiple_filters( - self, sample_document_store: InMemoryDocumentStore, filters: list[dict[str, Any]], expected_languages: set[str] + def test_run_multiple_filters( + self, sample_document_store: InMemoryDocumentStore, sample_filters: list[dict[str, Any]] ) -> None: - filter_retriever = FilterRetriever(document_store=sample_document_store) - multi_retriever = MultiFilterRetriever(retriever=filter_retriever) + multi = MultiFilterRetriever(retriever=FilterRetriever(document_store=sample_document_store)) - result = multi_retriever.run(filters=filters) + result = multi.run(filters=sample_filters) assert "documents" in result - assert {doc.meta["lang"] for doc in result["documents"]} == expected_languages + assert len(result["documents"]) == 2 + assert {doc.meta["lang"] for doc in result["documents"]} == {"en", "de"} - def test_deduplication_with_overlapping_results(self) -> None: - doc1 = Document(content="Solar energy is renewable", id="doc1", score=0.9) - doc2 = Document(content="Wind energy is clean", id="doc2", score=0.8) - doc3 = Document(content="Solar energy is renewable", id="doc1", score=0.7) + def test_run_single_filter(self, sample_document_store: InMemoryDocumentStore) -> None: + multi = MultiFilterRetriever(retriever=FilterRetriever(document_store=sample_document_store)) - call_count = 0 + result = multi.run(filters=[{"field": "meta.lang", "operator": "==", "value": "en"}]) + + assert "documents" in result + assert len(result["documents"]) == 1 + assert result["documents"][0].meta["lang"] == "en" + + def test_deduplication(self) -> None: + doc1 = Document(content="A", id="doc1", score=0.9) + doc2 = Document(content="B", id="doc2", score=0.8) + doc3 = Document(content="A", id="doc1", score=0.7) @component class MockRetriever: @component.output_types(documents=list[Document]) def run(self, filters: dict[str, Any] | None = None, **kwargs: Any) -> dict[str, list[Document]]: - nonlocal call_count - call_count += 1 - if call_count == 1: - return {"documents": [doc1, doc2]} - return {"documents": [doc3, doc2]} + return {"documents": [doc1, doc2, doc3]} - multi_retriever = MultiFilterRetriever(retriever=MockRetriever(), max_workers=1) + multi = MultiFilterRetriever(retriever=MockRetriever(), max_workers=1) - result = multi_retriever.run( - filters=[ - {"field": "meta.lang", "operator": "==", "value": "en"}, - {"field": "meta.lang", "operator": "==", "value": "de"}, - ] - ) + result = multi.run(filters=[{}, {}]) - assert "documents" in result assert len(result["documents"]) == 2 - assert [doc.content for doc in result["documents"]].count("Solar energy is renewable") == 1 - assert [doc.content for doc in result["documents"]].count("Wind energy is clean") == 1 + assert {doc.id for doc in result["documents"]} == {"doc1", "doc2"} + + def test_to_dict(self, in_memory_doc_store: InMemoryDocumentStore) -> None: + retriever = FilterRetriever(document_store=in_memory_doc_store) + multi = MultiFilterRetriever(retriever=retriever, max_workers=2) + + data = multi.to_dict() + + assert data["type"] == "haystack.components.retrievers.multi_filter_retriever.MultiFilterRetriever" + assert data["init_parameters"]["max_workers"] == 2 - def test_from_dict_roundtrip(self, in_memory_doc_store: InMemoryDocumentStore) -> None: - filter_retriever = FilterRetriever(document_store=in_memory_doc_store) - multi_retriever = MultiFilterRetriever(retriever=filter_retriever, max_workers=2) + def test_from_dict(self, in_memory_doc_store: InMemoryDocumentStore) -> None: + retriever = FilterRetriever(document_store=in_memory_doc_store) + multi = MultiFilterRetriever(retriever=retriever, max_workers=2) - serialized = multi_retriever.to_dict() + serialized = multi.to_dict() deserialized = MultiFilterRetriever.from_dict(serialized) assert isinstance(deserialized, MultiFilterRetriever) From ff230b987041a40e0a20188ef7d638fce86de1d5 Mon Sep 17 00:00:00 2001 From: Sriniketh J Date: Thu, 9 Apr 2026 13:17:52 +0530 Subject: [PATCH 4/7] fix: remove protocol usage --- .../components/retrievers/types/protocol.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/haystack/components/retrievers/types/protocol.py b/haystack/components/retrievers/types/protocol.py index ed709f266a..fd2bb06664 100644 --- a/haystack/components/retrievers/types/protocol.py +++ b/haystack/components/retrievers/types/protocol.py @@ -54,22 +54,3 @@ def run( `documents`: List of retrieved documents sorted by relevance score. """ ... - - -class FilterRetriever(Protocol): - """ - This protocol defines the minimal interface for retrievers that filter documents. - """ - - def run(self, filters: dict[str, Any] | None = None, **kwargs: Any) -> dict[str, Any]: - """ - Retrieve documents matching the provided filters. - - Implementing classes may accept additional optional parameters in their run method. - - :param filters: A dictionary of filters to apply when retrieving documents. - :returns: - A dictionary containing: - `documents`: List of retrieved documents. - """ - ... From 57998433075559d14fae42589ac4378ad1aa9d8c Mon Sep 17 00:00:00 2001 From: Sriniketh J Date: Thu, 9 Apr 2026 13:18:14 +0530 Subject: [PATCH 5/7] fix: remove warm_up code --- .../components/retrievers/multi_filter_retriever.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/haystack/components/retrievers/multi_filter_retriever.py b/haystack/components/retrievers/multi_filter_retriever.py index 54ee7f0994..0f5af0f6b5 100644 --- a/haystack/components/retrievers/multi_filter_retriever.py +++ b/haystack/components/retrievers/multi_filter_retriever.py @@ -60,16 +60,6 @@ def __init__(self, *, retriever: FilterRetriever, max_workers: int = 3) -> None: """ self.retriever = retriever self.max_workers = max_workers - self._is_warmed_up = False - - def warm_up(self) -> None: - """ - Warm up the retriever if it has a warm_up method. - """ - if not self._is_warmed_up: - if hasattr(self.retriever, "warm_up") and callable(self.retriever.warm_up): - self.retriever.warm_up() - self._is_warmed_up = True @component.output_types(documents=list[Document]) def run( @@ -87,9 +77,6 @@ def run( docs: list[Document] = [] retriever_kwargs = retriever_kwargs or {} - if not self._is_warmed_up: - self.warm_up() - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: filters_results = executor.map(lambda filt: self._run_on_thread(filt, retriever_kwargs), filters) for result in filters_results: From 734e1dd0b3d3cd8178395e2107c6aefad67ffc63 Mon Sep 17 00:00:00 2001 From: Sriniketh J Date: Thu, 9 Apr 2026 13:36:20 +0530 Subject: [PATCH 6/7] fix: remove kwargs, score --- .../retrievers/multi_filter_retriever.py | 46 ++++--------------- .../components/retrievers/types/__init__.py | 4 +- ...lti-filter-retriever-266283c61e693da3.yaml | 2 +- .../retrievers/test_multi_filter_retriever.py | 10 ++-- 4 files changed, 18 insertions(+), 44 deletions(-) diff --git a/haystack/components/retrievers/multi_filter_retriever.py b/haystack/components/retrievers/multi_filter_retriever.py index 0f5af0f6b5..167e28d900 100644 --- a/haystack/components/retrievers/multi_filter_retriever.py +++ b/haystack/components/retrievers/multi_filter_retriever.py @@ -5,9 +5,8 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any -from haystack import Document, component, default_from_dict, default_to_dict -from haystack.components.retrievers.types import FilterRetriever -from haystack.core.serialization import component_to_dict +from haystack import Document, component +from haystack.components.retrievers.filter_retriever import FilterRetriever from haystack.utils.misc import _deduplicate_documents @@ -17,7 +16,7 @@ class MultiFilterRetriever: A component that retrieves documents using multiple filters in parallel. This component takes a list of filter dictionaries and uses a filter-capable retriever to retrieve matching - documents for each filter set in parallel. The results are combined, de-duplicated, and sorted by score. + documents for each filter set in parallel. ### Usage example @@ -62,9 +61,7 @@ def __init__(self, *, retriever: FilterRetriever, max_workers: int = 3) -> None: self.max_workers = max_workers @component.output_types(documents=list[Document]) - def run( - self, filters: list[dict[str, Any]], retriever_kwargs: dict[str, Any] | None = None - ) -> dict[str, list[Document]]: + def run(self, filters: list[dict[str, Any]]) -> dict[str, list[Document]]: """ Retrieve documents using multiple filters in parallel. @@ -72,25 +69,22 @@ def run( :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method. :returns: A dictionary containing: - - `documents`: List of retrieved documents sorted by relevance score. + - `documents`: List of retrieved documents. """ docs: list[Document] = [] - retriever_kwargs = retriever_kwargs or {} with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - filters_results = executor.map(lambda filt: self._run_on_thread(filt, retriever_kwargs), filters) + filters_results = executor.map(self._run_on_thread, filters) for result in filters_results: if not result: continue docs.extend(result) docs = _deduplicate_documents(docs) - docs.sort(key=lambda x: x.score or 0.0, reverse=True) + return {"documents": docs} - def _run_on_thread( - self, filters: dict[str, Any], retriever_kwargs: dict[str, Any] | None = None - ) -> list[Document] | None: + def _run_on_thread(self, filters: dict[str, Any]) -> list[Document] | None: """ Process a single filter set on a separate thread. @@ -99,29 +93,7 @@ def _run_on_thread( :returns: List of retrieved documents or None if no results. """ - result = self.retriever.run(filters=filters, **(retriever_kwargs or {})) + result = self.retriever.run(filters=filters) if result and "documents" in result: return result["documents"] return None - - def to_dict(self) -> dict[str, Any]: - """ - Serializes the component to a dictionary. - - :returns: - A dictionary representing the serialized component. - """ - return default_to_dict( - self, retriever=component_to_dict(obj=self.retriever, name="retriever"), max_workers=self.max_workers - ) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "MultiFilterRetriever": - """ - Deserializes the component from a dictionary. - - :param data: The dictionary to deserialize from. - :returns: - The deserialized component. - """ - return default_from_dict(cls, data) diff --git a/haystack/components/retrievers/types/__init__.py b/haystack/components/retrievers/types/__init__.py index 53ec088277..5afc417a7f 100644 --- a/haystack/components/retrievers/types/__init__.py +++ b/haystack/components/retrievers/types/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from .protocol import EmbeddingRetriever, FilterRetriever, TextRetriever +from .protocol import EmbeddingRetriever, TextRetriever -__all__ = ["TextRetriever", "EmbeddingRetriever", "FilterRetriever"] +__all__ = ["TextRetriever", "EmbeddingRetriever"] diff --git a/releasenotes/notes/feat-multi-filter-retriever-266283c61e693da3.yaml b/releasenotes/notes/feat-multi-filter-retriever-266283c61e693da3.yaml index 878155020b..12ab0c2521 100644 --- a/releasenotes/notes/feat-multi-filter-retriever-266283c61e693da3.yaml +++ b/releasenotes/notes/feat-multi-filter-retriever-266283c61e693da3.yaml @@ -3,7 +3,7 @@ features: - | Add support for ``MultiFilterRetriever``, a new retriever component that executes multiple filter queries against a document store **in parallel** and returns a single, de-duplicated list of - documents sorted by relevance score. + documents. .. code-block:: python diff --git a/test/components/retrievers/test_multi_filter_retriever.py b/test/components/retrievers/test_multi_filter_retriever.py index 5b6f8b5ae9..b7d0e1e075 100644 --- a/test/components/retrievers/test_multi_filter_retriever.py +++ b/test/components/retrievers/test_multi_filter_retriever.py @@ -6,8 +6,10 @@ import pytest from haystack import Document, component -from haystack.components.retrievers import FilterRetriever, MultiFilterRetriever +from haystack.components.retrievers.filter_retriever import FilterRetriever +from haystack.components.retrievers.multi_filter_retriever import MultiFilterRetriever from haystack.components.writers import DocumentWriter +from haystack.core.serialization import component_from_dict, component_to_dict from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack.document_stores.types import DuplicatePolicy @@ -98,7 +100,7 @@ def test_to_dict(self, in_memory_doc_store: InMemoryDocumentStore) -> None: retriever = FilterRetriever(document_store=in_memory_doc_store) multi = MultiFilterRetriever(retriever=retriever, max_workers=2) - data = multi.to_dict() + data = component_to_dict(multi, "multi_filter") assert data["type"] == "haystack.components.retrievers.multi_filter_retriever.MultiFilterRetriever" assert data["init_parameters"]["max_workers"] == 2 @@ -107,8 +109,8 @@ def test_from_dict(self, in_memory_doc_store: InMemoryDocumentStore) -> None: retriever = FilterRetriever(document_store=in_memory_doc_store) multi = MultiFilterRetriever(retriever=retriever, max_workers=2) - serialized = multi.to_dict() - deserialized = MultiFilterRetriever.from_dict(serialized) + serialized = component_to_dict(multi, "multi_filter") + deserialized = component_from_dict(MultiFilterRetriever, serialized, "multi_filter") assert isinstance(deserialized, MultiFilterRetriever) assert deserialized.max_workers == 2 From f361b6ebd536d98c5fc4de956212149cd5b62e22 Mon Sep 17 00:00:00 2001 From: Sriniketh J Date: Thu, 9 Apr 2026 14:44:21 +0530 Subject: [PATCH 7/7] fix: header format --- test/components/retrievers/test_multi_filter_retriever.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/components/retrievers/test_multi_filter_retriever.py b/test/components/retrievers/test_multi_filter_retriever.py index b7d0e1e075..dc43290c2f 100644 --- a/test/components/retrievers/test_multi_filter_retriever.py +++ b/test/components/retrievers/test_multi_filter_retriever.py @@ -1,4 +1,5 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# # SPDX-License-Identifier: Apache-2.0 from typing import Any