
Commit 744198c

feat: Add reciprocal rank fusion to MultiRetriever (#11220)
1 parent 8887691 commit 744198c

8 files changed

Lines changed: 187 additions & 95 deletions


docs-website/docs/pipeline-components/retrievers/multiretriever.mdx

Lines changed: 14 additions & 6 deletions

@@ -2,12 +2,12 @@
 title: "MultiRetriever"
 id: multiretriever
 slug: "/multiretriever"
-description: "Runs multiple text retrievers in parallel and combines their deduplicated results."
+description: "Runs multiple text retrievers in parallel and combines their results using reciprocal rank fusion or deduplication."
 ---

 # MultiRetriever

-Runs multiple text retrievers in parallel and combines their deduplicated results.
+Runs multiple text retrievers in parallel and combines their results using reciprocal rank fusion or deduplication.

 :::warning[Experimental]

@@ -21,8 +21,9 @@
 | --- | --- |
 | **Most common position in a pipeline** | After query input, before a [`ChatPromptBuilder`](../builders/chatpromptbuilder.mdx) in RAG pipelines |
 | **Mandatory init variables** | `retrievers`: A dictionary mapping names to text retrievers (implementing the `TextRetriever` protocol) |
+| **Optional init variables** | `join_mode`: `"reciprocal_rank_fusion"` (default) or `"concatenate"` |
 | **Mandatory run variables** | `query`: A query string |
-| **Output variables** | `documents`: A deduplicated list of retrieved documents |
+| **Output variables** | `documents`: A merged list of retrieved documents |
 | **API reference** | [Retrievers](/reference/retrievers-api) |
 | **GitHub link** | https://github.com/deepset-ai/haystack/blob/main/haystack/components/retrievers/multi_retriever.py |
 | **Package name** | `haystack-ai` |
@@ -31,20 +32,27 @@

 ## Overview

-`MultiRetriever` composes any number of text retrievers into a single component. All retrievers are queried in parallel using a thread pool, and their results are deduplicated before being returned.
+`MultiRetriever` composes any number of text retrievers into a single component. All retrievers are queried in parallel using a thread pool, and their results are merged before being returned.

 The component:
 - Queries all retrievers concurrently for better performance
-- Automatically deduplicates results across retrievers
+- Merges results across retrievers using the configured `join_mode`
 - Supports selectively enabling retrievers at runtime via `active_retrievers`

 All retrievers passed to `MultiRetriever` must implement the `TextRetriever` protocol — their `run` method must accept a text `query`, `filters`, and `top_k`. Use [`TextEmbeddingRetriever`](textembeddingretriever.mdx) to wrap an embedding-based retriever so it can be used with this component.

+### Join modes
+
+The `join_mode` parameter controls how results from multiple retrievers are merged:
+
+- **`reciprocal_rank_fusion`** (default): Assigns scores based on each document's rank across retrieval lists using the [Reciprocal Rank Fusion](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) algorithm. Documents appearing highly ranked in multiple lists receive higher scores. Results are deduplicated and returned in descending score order. This is the recommended mode when combining retrievers with incomparable scores, such as BM25 and embedding retrievers.
+- **`concatenate`**: Combines all results into a single list and deduplicates.
+
 ## Usage

 ### On its own

-This example sets up a `MultiRetriever` combining a BM25 retriever and an embedding-based retriever (wrapped with `TextEmbeddingRetriever`). Both are queried in parallel and the deduplicated results are returned.
+This example sets up a `MultiRetriever` combining a BM25 retriever and an embedding-based retriever (wrapped with `TextEmbeddingRetriever`). Both are queried in parallel and the results are merged using reciprocal rank fusion.

 ```python
 from haystack import Document

haystack/components/joiners/document_joiner.py

Lines changed: 4 additions & 27 deletions

@@ -11,6 +11,7 @@

 from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.core.component.types import Variadic
+from haystack.utils.misc import _reciprocal_rank_fusion

 logger = logging.getLogger(__name__)

@@ -117,7 +118,7 @@
         join_mode_functions = {
             JoinMode.CONCATENATE: DocumentJoiner._concatenate,
             JoinMode.MERGE: self._merge,
-            JoinMode.RECIPROCAL_RANK_FUSION: self._reciprocal_rank_fusion,
+            JoinMode.RECIPROCAL_RANK_FUSION: self._rrf,
             JoinMode.DISTRIBUTION_BASED_RANK_FUSION: DocumentJoiner._distribution_based_rank_fusion,
         }
         self.join_mode_function = join_mode_functions[join_mode]
@@ -193,35 +194,11 @@

         return [replace(doc, score=scores_map[doc.id]) for doc in documents_map.values()]

-    def _reciprocal_rank_fusion(self, document_lists: list[list[Document]]) -> list[Document]:
+    def _rrf(self, document_lists: list[list[Document]]) -> list[Document]:
         """
         Merge multiple lists of Documents and assign scores based on reciprocal rank fusion.
-
-        The constant k is set to 61 (60 was suggested by the original paper,
-        plus 1 as python lists are 0-based and the paper used 1-based ranking).
         """
-        # This check prevents a division by zero when no documents are passed
-        if not document_lists:
-            return []
-
-        k = 61
-
-        scores_map: dict = defaultdict(int)
-        documents_map = {}
-        weights = self.weights if self.weights else [1 / len(document_lists)] * len(document_lists)
-
-        # Calculate weighted reciprocal rank fusion score
-        for documents, weight in zip(document_lists, weights, strict=True):
-            for rank, doc in enumerate(documents):
-                scores_map[doc.id] += (weight * len(document_lists)) / (k + rank)
-                documents_map[doc.id] = doc
-
-        # Normalize scores. Note: len(results) / k is the maximum possible score,
-        # achieved by being ranked first in all doc lists with non-zero weight.
-        for _id in scores_map:
-            scores_map[_id] /= len(document_lists) / k
-
-        return [replace(doc, score=scores_map[doc.id]) for doc in documents_map.values()]
+        return _reciprocal_rank_fusion(document_lists, weights=self.weights)

     @staticmethod
     def _distribution_based_rank_fusion(document_lists: list[list[Document]]) -> list[Document]:

haystack/components/retrievers/multi_retriever.py

Lines changed: 28 additions & 12 deletions

@@ -4,14 +4,15 @@

 import asyncio
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Any
+from math import inf
+from typing import Any, Literal

 from haystack import component, default_from_dict, default_to_dict
 from haystack.components.retrievers.types.protocol import TextRetriever
 from haystack.core.serialization import component_from_dict, component_to_dict, import_class_by_name
 from haystack.dataclasses import Document
 from haystack.utils.experimental import _experimental
-from haystack.utils.misc import _deduplicate_documents
+from haystack.utils.misc import _deduplicate_documents, _reciprocal_rank_fusion


 @_experimental
@@ -82,6 +83,7 @@
         filters: dict[str, Any] | None = None,
         top_k: int = 10,
         max_workers: int = 4,
+        join_mode: Literal["concatenate", "reciprocal_rank_fusion"] = "reciprocal_rank_fusion",
     ) -> None:
         """
         Create the MultiRetriever component.
@@ -95,13 +97,30 @@
             The maximum number of documents to return per retriever.
         :param max_workers:
             The maximum number of threads to use for parallel retrieval.
+        :param join_mode:
+            How to merge results from multiple retrievers. Available modes:
+            - `concatenate`: Combines all results into a single list and deduplicates.
+            - `reciprocal_rank_fusion`: Deduplicates and assigns scores based on reciprocal rank fusion.
         """
         self.retrievers = retrievers
         self.filters = filters
         self.top_k = top_k
         self.max_workers = max_workers
+        self.join_mode = join_mode
         self._is_warmed_up = False

+    def _merge_results(self, document_lists: list[list[Document]]) -> list[Document]:
+        """
+        Merge per-retriever result lists according to `join_mode`.
+
+        In `concatenate` mode, all lists are flattened and deduplicated. In `reciprocal_rank_fusion` mode, results
+        are deduplicated and re-scored using RRF, then returned in descending score order.
+        """
+        if self.join_mode == "reciprocal_rank_fusion":
+            documents = _reciprocal_rank_fusion(document_lists)
+            return sorted(documents, key=lambda d: d.score if d.score is not None else -inf, reverse=True)
+        return _deduplicate_documents([doc for docs in document_lists for doc in docs])
+
     def _resolve_retrievers(self, active_retrievers: list[str] | None) -> dict[str, TextRetriever]:
         """
         Returns the subset of retrievers to run based on the active_retrievers list.
@@ -171,7 +190,7 @@

         retrievers_to_run = self._resolve_retrievers(active_retrievers)

-        all_documents: list[Document] = []
+        results_by_name: dict[str, list[Document]] = {}
         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
             future_to_name = {
                 executor.submit(retriever.run, query=query, filters=resolved_filters, top_k=resolved_top_k): name
@@ -180,11 +199,12 @@
             for future in as_completed(future_to_name):
                 name = future_to_name[future]
                 try:
-                    all_documents.extend(future.result().get("documents", []))
+                    results_by_name[name] = future.result().get("documents", [])
                 except Exception as e:
                     raise RuntimeError(f"Retriever '{name}' failed: {e}") from e

-        return {"documents": _deduplicate_documents(all_documents)}
+        document_lists = [results_by_name[name] for name in retrievers_to_run]
+        return {"documents": self._merge_results(document_lists)}

     @component.output_types(documents=list[Document])
     async def run_async(
@@ -238,13 +258,8 @@
             except Exception as e:
                 raise RuntimeError(f"Retriever '{name}' failed: {e}") from e

-        results = await asyncio.gather(*[_run_one(name, r) for name, r in retrievers_to_run.items()])
-
-        all_documents: list[Document] = []
-        for docs in results:
-            all_documents.extend(docs)
-
-        return {"documents": _deduplicate_documents(all_documents)}
+        document_lists = list(await asyncio.gather(*[_run_one(name, r) for name, r in retrievers_to_run.items()]))
+        return {"documents": self._merge_results(document_lists)}

     def to_dict(self) -> dict[str, Any]:
         """
@@ -259,6 +274,7 @@
             filters=self.filters,
             top_k=self.top_k,
             max_workers=self.max_workers,
+            join_mode=self.join_mode,
        )

     @classmethod
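The switch in `run` from `all_documents.extend(...)` to a name-keyed dict fixes an ordering subtlety: `as_completed` yields futures in completion order, not submission order. A standalone sketch of the pattern (the retriever names and `fake_retriever` stub are invented for illustration):

```python
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def fake_retriever(name: str) -> list[str]:
    # Simulate retrievers that finish in unpredictable order
    time.sleep(random.uniform(0, 0.05))
    return [f"{name}-doc1", f"{name}-doc2"]

names = ["bm25", "embedding", "keyword"]
results_by_name: dict[str, list[str]] = {}
with ThreadPoolExecutor(max_workers=4) as executor:
    future_to_name = {executor.submit(fake_retriever, n): n for n in names}
    for future in as_completed(future_to_name):
        # Futures arrive here in whatever order they happen to finish...
        results_by_name[future_to_name[future]] = future.result()

# ...so reassembling by the original name order keeps the merged
# output deterministic across runs
document_lists = [results_by_name[n] for n in names]
print(document_lists[0])  # always bm25's results, regardless of timing
```

This matters here because the merged list feeds RRF, where each input list's position is part of the fusion input; without the reassembly step, scores could vary between otherwise identical runs.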

haystack/utils/misc.py

Lines changed: 37 additions & 0 deletions

@@ -5,6 +5,8 @@
 import json
 import mimetypes
 import tempfile
+from collections import defaultdict
+from dataclasses import replace
 from math import inf
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal, overload
@@ -145,6 +147,41 @@
     return list(highest_scoring_docs.values())


+def _reciprocal_rank_fusion(
+    document_lists: list[list["Document"]], weights: list[float] | None = None
+) -> list["Document"]:
+    """
+    Merge multiple ranked lists of Documents using Reciprocal Rank Fusion, deduplicating across lists.
+
+    See the original paper: https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
+
+    The constant k is set to 61 (60 was suggested by the original paper, plus 1 as python lists are 0-based and the
+    paper used 1-based ranking).
+
+    :param document_lists: A list of ranked document lists to fuse.
+    :param weights: Optional per-list weights. Defaults to equal weights.
+    :returns:
+        Deduplicated list of documents with updated RRF scores.
+    """
+    if not document_lists:
+        return []
+
+    k = 61
+    scores_map: dict = defaultdict(int)
+    documents_map: dict = {}
+    resolved_weights = weights if weights else [1 / len(document_lists)] * len(document_lists)
+
+    for documents, weight in zip(document_lists, resolved_weights, strict=True):
+        for rank, doc in enumerate(documents):
+            scores_map[doc.id] += (weight * len(document_lists)) / (k + rank)
+            documents_map[doc.id] = doc
+
+    for _id in scores_map:
+        scores_map[_id] /= len(document_lists) / k
+
+    return [replace(doc, score=scores_map[doc.id]) for doc in documents_map.values()]
+
+
 @overload
 def _parse_dict_from_json(
     text: str, expected_keys: list[str] | None = ..., raise_on_failure: Literal[True] = ...
Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+---
+enhancements:
+  - |
+    Add ``join_mode`` parameter to the experimental ``MultiRetriever`` component, supporting
+    ``"reciprocal_rank_fusion"`` (default) and ``"concatenate"``. Reciprocal Rank Fusion merges
+    the ranked result lists from all retrievers into a single deduplicated list ordered by RRF score.
+    The underlying RRF logic is extracted into a shared utility ``_reciprocal_rank_fusion`` in
+    ``haystack.utils.misc``, which is now also used by ``DocumentJoiner``.

test/components/joiners/test_document_joiner.py

Lines changed: 0 additions & 40 deletions

@@ -270,43 +270,3 @@
         documents_2 = [Document(content="d", score=0.2)]
         output = joiner.run([documents_1, documents_2])
         assert output["documents"] == documents_1 + documents_2
-
-    def test_test_score_norm_with_rrf(self):
-        """
-        Verifies reciprocal rank fusion (RRF) of the DocumentJoiner component with various weight configurations.
-        It creates a set of documents, forms them into two lists, and then applies multiple DocumentJoiner
-        instances with distinct weights to these lists. The test checks if the resulting
-        joined documents are correctly sorted in descending order by score, ensuring the RRF ranking works as
-        expected under different weighting scenarios.
-        """
-        num_docs = 6
-        docs = []
-
-        for i in range(num_docs):
-            docs.append(Document(content=f"doc{i}"))
-
-        docs_2 = [docs[0], docs[4], docs[2], docs[5], docs[1]]
-        document_lists = [docs, docs_2]
-
-        joiner_1 = DocumentJoiner(join_mode="reciprocal_rank_fusion", weights=[0.5, 0.5])
-
-        joiner_2 = DocumentJoiner(join_mode="reciprocal_rank_fusion", weights=[7, 7])
-
-        joiner_3 = DocumentJoiner(join_mode="reciprocal_rank_fusion", weights=[0.7, 0.3])
-
-        joiner_4 = DocumentJoiner(join_mode="reciprocal_rank_fusion", weights=[0.6, 0.4])
-
-        joiner_5 = DocumentJoiner(join_mode="reciprocal_rank_fusion", weights=[1, 0])
-
-        joiners = [joiner_1, joiner_2, joiner_3, joiner_4, joiner_5]
-
-        for joiner in joiners:
-            join_results = joiner.run(documents=document_lists)
-            is_sorted = all(
-                join_results["documents"][i].score >= join_results["documents"][i + 1].score
-                for i in range(len(join_results["documents"]) - 1)
-            )
-
-            assert is_sorted, (
-                "Documents are not sorted in descending order by score, there is an issue with rff ranking"
-            )
