Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions haystack/components/retrievers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"auto_merging_retriever": ["AutoMergingRetriever"],
"filter_retriever": ["FilterRetriever"],
"in_memory": ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"],
"multi_filter_retriever": ["MultiFilterRetriever"],
"multi_query_embedding_retriever": ["MultiQueryEmbeddingRetriever"],
"multi_query_text_retriever": ["MultiQueryTextRetriever"],
"sentence_window_retriever": ["SentenceWindowRetriever"],
Expand All @@ -21,6 +22,7 @@
from .filter_retriever import FilterRetriever as FilterRetriever
from .in_memory import InMemoryBM25Retriever as InMemoryBM25Retriever
from .in_memory import InMemoryEmbeddingRetriever as InMemoryEmbeddingRetriever
from .multi_filter_retriever import MultiFilterRetriever as MultiFilterRetriever
from .multi_query_embedding_retriever import MultiQueryEmbeddingRetriever as MultiQueryEmbeddingRetriever
from .multi_query_text_retriever import MultiQueryTextRetriever as MultiQueryTextRetriever
from .sentence_window_retriever import SentenceWindowRetriever as SentenceWindowRetriever
Expand Down
99 changes: 99 additions & 0 deletions haystack/components/retrievers/multi_filter_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from concurrent.futures import ThreadPoolExecutor
from typing import Any

from haystack import Document, component
from haystack.components.retrievers.filter_retriever import FilterRetriever
from haystack.utils.misc import _deduplicate_documents


@component
class MultiFilterRetriever:
"""
A component that retrieves documents using multiple filters in parallel.

This component takes a list of filter dictionaries and uses a filter-capable retriever to retrieve matching
documents for each filter set in parallel.

### Usage example

```python
from haystack import Document
from haystack.components.retrievers import FilterRetriever, MultiFilterRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.writers import DocumentWriter
from haystack.document_stores.types import DuplicatePolicy

documents = [
Document(content="Python is a popular programming language", meta={"lang": "en"}),
Document(content="python ist eine beliebte Programmiersprache", meta={"lang": "de"}),
]

document_store = InMemoryDocumentStore()
writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
writer.run(documents=documents)

filter_retriever = FilterRetriever(document_store=document_store)
multi_filter_retriever = MultiFilterRetriever(retriever=filter_retriever)

filters = [
{"field": "meta.lang", "operator": "==", "value": "en"},
{"field": "meta.lang", "operator": "==", "value": "de"},
]

result = multi_filter_retriever.run(filters=filters)
for doc in result["documents"]:
print(doc.content)
```
"""

def __init__(self, *, retriever: FilterRetriever, max_workers: int = 3) -> None:
"""
Initialize MultiFilterRetriever.

:param retriever: The filter-capable retriever to use for document retrieval.
:param max_workers: Maximum number of worker threads for parallel processing.
"""
self.retriever = retriever
self.max_workers = max_workers

@component.output_types(documents=list[Document])
def run(self, filters: list[dict[str, Any]]) -> dict[str, list[Document]]:
"""
Retrieve documents using multiple filters in parallel.

:param filters: List of filter dictionaries to process.
:param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
:returns:
A dictionary containing:
- `documents`: List of retrieved documents.
"""
docs: list[Document] = []

with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
filters_results = executor.map(self._run_on_thread, filters)
for result in filters_results:
if not result:
continue
docs.extend(result)

docs = _deduplicate_documents(docs)

return {"documents": docs}

def _run_on_thread(self, filters: dict[str, Any]) -> list[Document] | None:
"""
Process a single filter set on a separate thread.

:param filters: The filter dictionary to process.
:param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
:returns:
List of retrieved documents or None if no results.
"""
result = self.retriever.run(filters=filters)
if result and "documents" in result:
return result["documents"]
return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
features:
- |
Add support for ``MultiFilterRetriever``, a new retriever component that executes multiple filter
queries against a document store **in parallel** and returns a single, de-duplicated list of
documents.

.. code-block:: python

filter_retriever = FilterRetriever(document_store=document_store)
multi_filter_retriever = MultiFilterRetriever(retriever=filter_retriever)

filters = [
{"field": "meta.lang", "operator": "==", "value": "en"},
{"field": "meta.lang", "operator": "==", "value": "de"},
]

result = multi_filter_retriever.run(filters=filters)
for doc in result["documents"]:
print(doc.content)
117 changes: 117 additions & 0 deletions test/components/retrievers/test_multi_filter_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import pytest

from haystack import Document, component
from haystack.components.retrievers.filter_retriever import FilterRetriever
from haystack.components.retrievers.multi_filter_retriever import MultiFilterRetriever
from haystack.components.writers import DocumentWriter
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy


@pytest.fixture
def sample_documents() -> list[Document]:
return [
Document(content="English text", id="doc1", meta={"lang": "en"}),
Document(content="German text", id="doc2", meta={"lang": "de"}),
]


@pytest.fixture
def sample_document_store(sample_documents: list[Document]) -> InMemoryDocumentStore:
document_store = InMemoryDocumentStore()
DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP).run(documents=sample_documents)
return document_store


@pytest.fixture
def sample_filters() -> list[dict[str, Any]]:
return [
{"field": "meta.lang", "operator": "==", "value": "en"},
{"field": "meta.lang", "operator": "==", "value": "de"},
]


class TestMultiFilterRetriever:
def test_init_default(self, in_memory_doc_store: InMemoryDocumentStore) -> None:
retriever = FilterRetriever(document_store=in_memory_doc_store)
multi = MultiFilterRetriever(retriever=retriever)

assert multi.retriever == retriever
assert multi.max_workers == 3

def test_init_with_parameters(self, in_memory_doc_store: InMemoryDocumentStore) -> None:
retriever = FilterRetriever(document_store=in_memory_doc_store)
multi = MultiFilterRetriever(retriever=retriever, max_workers=2)

assert multi.max_workers == 2

def test_run_empty_filters(self, in_memory_doc_store: InMemoryDocumentStore) -> None:
multi = MultiFilterRetriever(retriever=FilterRetriever(document_store=in_memory_doc_store))

result = multi.run(filters=[])

assert result == {"documents": []}

def test_run_multiple_filters(
self, sample_document_store: InMemoryDocumentStore, sample_filters: list[dict[str, Any]]
) -> None:
multi = MultiFilterRetriever(retriever=FilterRetriever(document_store=sample_document_store))

result = multi.run(filters=sample_filters)

assert "documents" in result
assert len(result["documents"]) == 2
assert {doc.meta["lang"] for doc in result["documents"]} == {"en", "de"}

def test_run_single_filter(self, sample_document_store: InMemoryDocumentStore) -> None:
multi = MultiFilterRetriever(retriever=FilterRetriever(document_store=sample_document_store))

result = multi.run(filters=[{"field": "meta.lang", "operator": "==", "value": "en"}])

assert "documents" in result
assert len(result["documents"]) == 1
assert result["documents"][0].meta["lang"] == "en"

def test_deduplication(self) -> None:
doc1 = Document(content="A", id="doc1", score=0.9)
doc2 = Document(content="B", id="doc2", score=0.8)
doc3 = Document(content="A", id="doc1", score=0.7)

@component
class MockRetriever:
@component.output_types(documents=list[Document])
def run(self, filters: dict[str, Any] | None = None, **kwargs: Any) -> dict[str, list[Document]]:
return {"documents": [doc1, doc2, doc3]}

multi = MultiFilterRetriever(retriever=MockRetriever(), max_workers=1)

result = multi.run(filters=[{}, {}])

assert len(result["documents"]) == 2
assert {doc.id for doc in result["documents"]} == {"doc1", "doc2"}

def test_to_dict(self, in_memory_doc_store: InMemoryDocumentStore) -> None:
retriever = FilterRetriever(document_store=in_memory_doc_store)
multi = MultiFilterRetriever(retriever=retriever, max_workers=2)

data = component_to_dict(multi, "multi_filter")

assert data["type"] == "haystack.components.retrievers.multi_filter_retriever.MultiFilterRetriever"
assert data["init_parameters"]["max_workers"] == 2

def test_from_dict(self, in_memory_doc_store: InMemoryDocumentStore) -> None:
retriever = FilterRetriever(document_store=in_memory_doc_store)
multi = MultiFilterRetriever(retriever=retriever, max_workers=2)

serialized = component_to_dict(multi, "multi_filter")
deserialized = component_from_dict(MultiFilterRetriever, serialized, "multi_filter")

assert isinstance(deserialized, MultiFilterRetriever)
assert deserialized.max_workers == 2