Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ In contrast to other rankers, `LostInTheMiddleRanker` assumes that the input doc

If you specify the `word_count_threshold` when running the component, the Ranker includes all documents up until the point where adding another document would exceed the given threshold. The last document that exceeds the threshold will be included in the resulting list of Documents, but all following documents will be discarded.

By default, `word_count_threshold` counts whitespace-separated words. You can set `count_mode` to `word`, `char`, or `token` to count the threshold in words, characters, or tokens. Token counting uses the `tokenizer_encoding` setting and defaults to the `o200k_base` encoding.

You can also specify the `top_k` parameter to set the maximum number of documents to return.

## Usage
Expand Down Expand Up @@ -88,7 +90,7 @@ document_store = InMemoryDocumentStore()
document_store.write_documents(docs)

retriever = InMemoryBM25Retriever(document_store=document_store)
ranker = LostInTheMiddleRanker(word_count_threshold=1024)
ranker = LostInTheMiddleRanker(word_count_threshold=1024, count_mode="token")
prompt_builder = ChatPromptBuilder(
template=prompt_template,
required_variables={"query", "documents"},
Expand Down
84 changes: 68 additions & 16 deletions haystack/components/rankers/lost_in_the_middle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
# SPDX-License-Identifier: Apache-2.0


from typing import Literal

from haystack import Document, component
from haystack.lazy_imports import LazyImport
from haystack.utils.misc import _deduplicate_documents

with LazyImport("Run 'pip install tiktoken'") as tiktoken_imports:
import tiktoken


@component
class LostInTheMiddleRanker:
Expand Down Expand Up @@ -37,7 +43,14 @@ class LostInTheMiddleRanker:
```
"""

def __init__(self, word_count_threshold: int | None = None, top_k: int | None = None) -> None:
def __init__(
self,
word_count_threshold: int | None = None,
top_k: int | None = None,
*,
count_mode: Literal["word", "char", "token"] = "word",
tokenizer_encoding: str = "o200k_base",
) -> None:
"""
Initialize the LostInTheMiddleRanker.

Expand All @@ -46,18 +59,37 @@ def __init__(self, word_count_threshold: int | None = None, top_k: int | None =
be breached will be included in the resulting list of documents, but all subsequent documents will be
discarded.

:param word_count_threshold: The maximum total number of words across all documents selected by the ranker.
:param word_count_threshold: The maximum total count across all documents selected by the ranker. The count is
measured in the unit configured by `count_mode`.
:param top_k: The maximum number of documents to return.
:param count_mode: The unit used for threshold counting. It can be either "word", "char", or "token".
If "token" is selected, the text is counted using the tiktoken tokenizer.
:param tokenizer_encoding: The tiktoken encoding to use when `count_mode` is "token".
"""
if isinstance(word_count_threshold, int) and word_count_threshold <= 0:
raise ValueError(
f"Invalid value for word_count_threshold: {word_count_threshold}. word_count_threshold must be > 0."
)
if isinstance(top_k, int) and top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
if count_mode not in ["word", "char", "token"]:
raise ValueError(
f"Invalid value for count_mode: {count_mode}. count_mode must be one of: 'word', 'char', 'token'."
)

self.word_count_threshold = word_count_threshold
self.top_k = top_k
self.count_mode = count_mode
self.tokenizer_encoding = tokenizer_encoding
self.tiktoken_tokenizer: "tiktoken.Encoding" | None = None

def warm_up(self) -> None:
"""
Initialize the tokenizer when `count_mode` is "token".
"""
if self.count_mode == "token" and self.tiktoken_tokenizer is None:
tiktoken_imports.check()
self.tiktoken_tokenizer = tiktoken.get_encoding(self.tokenizer_encoding)

@component.output_types(documents=list[Document])
def run(
Expand All @@ -71,7 +103,8 @@ def run(

:param documents: List of Documents to reorder.
:param top_k: The maximum number of documents to return.
:param word_count_threshold: The maximum total number of words across all documents selected by the ranker.
:param word_count_threshold: The maximum total count across all documents selected by the ranker. The count is
measured in the unit configured by `count_mode`.
:returns:
A dictionary with the following keys:
- `documents`: Reranked list of Documents
Expand All @@ -90,7 +123,7 @@ def run(
return {"documents": []}

top_k = top_k or self.top_k
word_count_threshold = word_count_threshold or self.word_count_threshold
word_count_threshold = self.word_count_threshold if word_count_threshold is None else word_count_threshold

deduplicated_documents = _deduplicate_documents(documents)
documents_to_reorder = deduplicated_documents[:top_k] if top_k else deduplicated_documents
Expand All @@ -103,17 +136,18 @@ def run(
if any(not doc.content_type == "text" for doc in documents_to_reorder):
raise ValueError("Some provided documents are not textual; LostInTheMiddleRanker can process only text.")

# Initialize word count and indices for the "lost in the middle" order
word_count = 0
# Initialize threshold count and indices for the "lost in the middle" order
count = 0
document_index = list(range(len(documents_to_reorder)))
lost_in_the_middle_indices = [0]

# If word count threshold is set and the first document has content, calculate word count for the first document
if word_count_threshold and documents_to_reorder[0].content:
word_count = len(documents_to_reorder[0].content.split())
# If threshold is set and the first document has content, calculate count for the first document.
first_document_content = documents_to_reorder[0].content
if word_count_threshold and first_document_content:
count = self._count_text_units(first_document_content)

# If the first document already meets the word count threshold, return it
if word_count >= word_count_threshold:
# If the first document already meets the threshold, return it.
if count >= word_count_threshold:
return {"documents": [documents_to_reorder[0]]}

# Start from the second document and create "lost in the middle" order
Expand All @@ -124,14 +158,32 @@ def run(
# Insert the document index at the calculated position
lost_in_the_middle_indices.insert(insertion_index, doc_idx)

# If word count threshold is set and the document has content, calculate the total word count
if word_count_threshold and documents_to_reorder[doc_idx].content:
word_count += len(documents_to_reorder[doc_idx].content.split()) # type: ignore[union-attr]
# If threshold is set and the document has content, calculate the total count.
document_content = documents_to_reorder[doc_idx].content
if word_count_threshold and document_content:
count += self._count_text_units(document_content)

# If the total word count meets the threshold, stop processing further documents
if word_count >= word_count_threshold:
# If the total count meets the threshold, stop processing further documents.
if count >= word_count_threshold:
break

# Documents in the "lost in the middle" order
ranked_docs = [documents_to_reorder[idx] for idx in lost_in_the_middle_indices]
return {"documents": ranked_docs}

def _count_text_units(self, text: str) -> int:
"""
Count text according to the configured count mode.
"""
if self.count_mode == "word":
return len(text.split())
if self.count_mode == "char":
return len(text)

tokenizer = self.tiktoken_tokenizer
if tokenizer is None:
self.warm_up()
tokenizer = self.tiktoken_tokenizer
if tokenizer is None:
raise RuntimeError("Tokenizer was not initialized.")
return len(tokenizer.encode(text))
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added ``count_mode`` and ``tokenizer_encoding`` parameters to ``LostInTheMiddleRanker``.
The existing ``word_count_threshold`` behavior remains word-based by default, while ``count_mode="char"``
counts characters and ``count_mode="token"`` counts tokens using the configured tiktoken encoding.
58 changes: 56 additions & 2 deletions test/components/rankers/test_lost_in_the_middle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import pytest

from haystack import Document
from haystack import Document, Pipeline
from haystack.components.rankers import lost_in_the_middle as lost_in_the_middle_module
from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker


class FakeEncoding:
def encode(self, text: str) -> list[str]:
return text.split("|")


class TestLostInTheMiddleRanker:
def test_lost_in_the_middle_order_odd(self):
# tests that lost_in_the_middle order works with an odd number of documents
Expand Down Expand Up @@ -39,10 +47,17 @@ def test_lost_in_the_middle_init(self):
# tests that LostInTheMiddleRanker initializes with default values
ranker = LostInTheMiddleRanker()
assert ranker.word_count_threshold is None
assert ranker.count_mode == "word"
assert ranker.tokenizer_encoding == "o200k_base"

ranker = LostInTheMiddleRanker(word_count_threshold=10)
assert ranker.word_count_threshold == 10

def test_lost_in_the_middle_init_invalid_count_mode(self):
invalid_count_mode: Any = "byte"
with pytest.raises(ValueError, match="Invalid value for count_mode"):
LostInTheMiddleRanker(count_mode=invalid_count_mode)

def test_lost_in_the_middle_init_invalid_word_count_threshold(self):
# tests that LostInTheMiddleRanker raises an error when word_count_threshold is <= 0
with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
Expand Down Expand Up @@ -74,6 +89,45 @@ def test_word_count_threshold_greater_than_total_number_of_words_returns_all_doc
expected_order = ["word1", "word3", "word5", "word7", "word9", "word8", "word6", "word4", "word2"]
assert all(doc.content == expected_order[idx] for idx, doc in enumerate(ordered_docs["documents"]))

def test_lost_in_the_middle_with_char_count_mode(self):
ranker = LostInTheMiddleRanker(count_mode="char")
docs = [Document(content="aa"), Document(content="bbb"), Document(content="cccc"), Document(content="ddddd")]

result = ranker.run(documents=docs, word_count_threshold=5)

assert [doc.content for doc in result["documents"]] == ["aa", "bbb"]

def test_lost_in_the_middle_with_token_count_mode(self, monkeypatch):
encoding_names = []

def fake_get_encoding(name: str) -> FakeEncoding:
encoding_names.append(name)
return FakeEncoding()

monkeypatch.setattr(lost_in_the_middle_module.tiktoken, "get_encoding", fake_get_encoding)
ranker = LostInTheMiddleRanker(word_count_threshold=4, count_mode="token", tokenizer_encoding="test-encoding")
docs = [Document(content="a|b|c"), Document(content="d"), Document(content="e"), Document(content="f")]

result = ranker.run(documents=docs)

assert [doc.content for doc in result["documents"]] == ["a|b|c", "d"]
assert encoding_names == ["test-encoding"]

def test_pipeline_serialization_with_count_mode(self) -> None:
pipeline = Pipeline()
pipeline.add_component(
"ranker",
LostInTheMiddleRanker(word_count_threshold=4, count_mode="token", tokenizer_encoding="test-encoding"),
)

restored_pipeline = Pipeline.loads(pipeline.dumps())
restored_ranker = restored_pipeline.get_component("ranker")

assert isinstance(restored_ranker, LostInTheMiddleRanker)
assert restored_ranker.word_count_threshold == 4
assert restored_ranker.count_mode == "token"
assert restored_ranker.tokenizer_encoding == "test-encoding"

def test_empty_documents_returns_empty_list(self):
ranker = LostInTheMiddleRanker()
result = ranker.run(documents=[])
Expand All @@ -98,7 +152,7 @@ def test_run_deduplicates_documents(self):
assert result["documents"][1].content == "unique"

@pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8, 12, 20])
def test_lost_in_the_middle_order_with_top_k(self, top_k: int):
def test_lost_in_the_middle_order_with_top_k(self, top_k: int) -> None:
# tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter
docs = [Document(content=str(i)) for i in range(1, 10)]
ranker = LostInTheMiddleRanker()
Expand Down