diff --git a/docs-website/docs/pipeline-components/rankers/lostinthemiddleranker.mdx b/docs-website/docs/pipeline-components/rankers/lostinthemiddleranker.mdx index 4b810fbd48..a7ee69665d 100644 --- a/docs-website/docs/pipeline-components/rankers/lostinthemiddleranker.mdx +++ b/docs-website/docs/pipeline-components/rankers/lostinthemiddleranker.mdx @@ -32,6 +32,8 @@ In contrast to other rankers, `LostInTheMiddleRanker` assumes that the input doc If you specify the `word_count_threshold` when running the component, the Ranker includes all documents up until the point where adding another document would exceed the given threshold. The last document that exceeds the threshold will be included in the resulting list of Documents, but all following documents will be discarded. +By default, `word_count_threshold` counts whitespace-separated words. You can set `count_mode` to `word`, `char`, or `token` to measure the threshold in words, characters, or tokens. Token counting requires the `tiktoken` package and uses the encoding named by `tokenizer_encoding` (`o200k_base` by default). + You can also specify the `top_k` parameter to set the maximum number of documents to return. 
## Usage @@ -88,7 +90,7 @@ document_store = InMemoryDocumentStore() document_store.write_documents(docs) retriever = InMemoryBM25Retriever(document_store=document_store) -ranker = LostInTheMiddleRanker(word_count_threshold=1024) +ranker = LostInTheMiddleRanker(word_count_threshold=1024, count_mode="token") prompt_builder = ChatPromptBuilder( template=prompt_template, required_variables={"query", "documents"}, diff --git a/haystack/components/rankers/lost_in_the_middle.py b/haystack/components/rankers/lost_in_the_middle.py index 930643cc1b..c27f714984 100644 --- a/haystack/components/rankers/lost_in_the_middle.py +++ b/haystack/components/rankers/lost_in_the_middle.py @@ -3,9 +3,15 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Literal + from haystack import Document, component +from haystack.lazy_imports import LazyImport from haystack.utils.misc import _deduplicate_documents +with LazyImport("Run 'pip install tiktoken'") as tiktoken_imports: + import tiktoken + @component class LostInTheMiddleRanker: @@ -37,7 +43,14 @@ class LostInTheMiddleRanker: ``` """ - def __init__(self, word_count_threshold: int | None = None, top_k: int | None = None) -> None: + def __init__( + self, + word_count_threshold: int | None = None, + top_k: int | None = None, + *, + count_mode: Literal["word", "char", "token"] = "word", + tokenizer_encoding: str = "o200k_base", + ) -> None: """ Initialize the LostInTheMiddleRanker. @@ -46,8 +59,12 @@ def __init__(self, word_count_threshold: int | None = None, top_k: int | None = be breached will be included in the resulting list of documents, but all subsequent documents will be discarded. - :param word_count_threshold: The maximum total number of words across all documents selected by the ranker. + :param word_count_threshold: The maximum total count across all documents selected by the ranker. The count is + measured in the unit configured by `count_mode`. :param top_k: The maximum number of documents to return. 
+ :param count_mode: The unit used for threshold counting. It can be one of "word", "char", or "token". + If "token" is selected, the text is counted using the tiktoken tokenizer. + :param tokenizer_encoding: The tiktoken encoding to use when `count_mode` is "token". """ if isinstance(word_count_threshold, int) and word_count_threshold <= 0: raise ValueError( @@ -55,9 +72,24 @@ def __init__(self, word_count_threshold: int | None = None, top_k: int | None = ) if isinstance(top_k, int) and top_k <= 0: raise ValueError(f"top_k must be > 0, but got {top_k}") + if count_mode not in ["word", "char", "token"]: + raise ValueError( + f"Invalid value for count_mode: {count_mode}. count_mode must be one of: 'word', 'char', 'token'." + ) self.word_count_threshold = word_count_threshold self.top_k = top_k + self.count_mode = count_mode + self.tokenizer_encoding = tokenizer_encoding + self.tiktoken_tokenizer: "tiktoken.Encoding" | None = None + + def warm_up(self) -> None: + """ + Initialize the tokenizer when `count_mode` is "token". + """ + if self.count_mode == "token" and self.tiktoken_tokenizer is None: + tiktoken_imports.check() + self.tiktoken_tokenizer = tiktoken.get_encoding(self.tokenizer_encoding) @component.output_types(documents=list[Document]) def run( @@ -71,7 +103,8 @@ def run( :param documents: List of Documents to reorder. :param top_k: The maximum number of documents to return. - :param word_count_threshold: The maximum total number of words across all documents selected by the ranker. + :param word_count_threshold: The maximum total count across all documents selected by the ranker. The count is + measured in the unit configured by `count_mode`. 
:returns: A dictionary with the following keys: - `documents`: Reranked list of Documents @@ -90,7 +123,7 @@ def run( return {"documents": []} top_k = top_k or self.top_k - word_count_threshold = word_count_threshold or self.word_count_threshold + word_count_threshold = self.word_count_threshold if word_count_threshold is None else word_count_threshold deduplicated_documents = _deduplicate_documents(documents) documents_to_reorder = deduplicated_documents[:top_k] if top_k else deduplicated_documents @@ -103,17 +136,18 @@ def run( if any(not doc.content_type == "text" for doc in documents_to_reorder): raise ValueError("Some provided documents are not textual; LostInTheMiddleRanker can process only text.") - # Initialize word count and indices for the "lost in the middle" order - word_count = 0 + # Initialize threshold count and indices for the "lost in the middle" order + count = 0 document_index = list(range(len(documents_to_reorder))) lost_in_the_middle_indices = [0] - # If word count threshold is set and the first document has content, calculate word count for the first document - if word_count_threshold and documents_to_reorder[0].content: - word_count = len(documents_to_reorder[0].content.split()) + # If threshold is set and the first document has content, calculate count for the first document. + first_document_content = documents_to_reorder[0].content + if word_count_threshold and first_document_content: + count = self._count_text_units(first_document_content) - # If the first document already meets the word count threshold, return it - if word_count >= word_count_threshold: + # If the first document already meets the threshold, return it. 
+ if count >= word_count_threshold: return {"documents": [documents_to_reorder[0]]} # Start from the second document and create "lost in the middle" order @@ -124,14 +158,39 @@ def run( # Insert the document index at the calculated position lost_in_the_middle_indices.insert(insertion_index, doc_idx) - # If word count threshold is set and the document has content, calculate the total word count - if word_count_threshold and documents_to_reorder[doc_idx].content: - word_count += len(documents_to_reorder[doc_idx].content.split()) # type: ignore[union-attr] + # If threshold is set and the document has content, calculate the total count. + document_content = documents_to_reorder[doc_idx].content + if word_count_threshold and document_content: + count += self._count_text_units(document_content) - # If the total word count meets the threshold, stop processing further documents - if word_count >= word_count_threshold: + # If the total count meets the threshold, stop processing further documents. + if count >= word_count_threshold: break # Documents in the "lost in the middle" order ranked_docs = [documents_to_reorder[idx] for idx in lost_in_the_middle_indices] return {"documents": ranked_docs} + + def _count_text_units(self, text: str) -> int: + """ + Count text according to the configured count mode. + + In "token" mode the tokenizer is created on first use if `warm_up` has not been called, and special-token + markers such as `<|endoftext|>` are counted as ordinary text instead of raising an error. + + :param text: The text to measure. + :returns: The number of words, characters, or tokens in `text`, depending on `count_mode`. 
+ """ + if self.count_mode == "word": + return len(text.split()) + if self.count_mode == "char": + return len(text) + + tokenizer = self.tiktoken_tokenizer + if tokenizer is None: + self.warm_up() + tokenizer = self.tiktoken_tokenizer + if tokenizer is None: + raise RuntimeError("Tokenizer was not initialized.") + # Document text is arbitrary, so encode special-token markers as plain text rather than raising. + return len(tokenizer.encode(text, disallowed_special=())) diff --git a/releasenotes/notes/lost-in-the-middle-count-mode-22a91cf4a16cff99.yaml b/releasenotes/notes/lost-in-the-middle-count-mode-22a91cf4a16cff99.yaml new file mode 100644 index 0000000000..b9ea12c4f8 --- /dev/null +++ b/releasenotes/notes/lost-in-the-middle-count-mode-22a91cf4a16cff99.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added ``count_mode`` and ``tokenizer_encoding`` parameters to ``LostInTheMiddleRanker``. + The existing ``word_count_threshold`` behavior remains word-based by default, while ``count_mode="char"`` + counts characters and ``count_mode="token"`` counts tokens using the configured tiktoken encoding. 
diff --git a/test/components/rankers/test_lost_in_the_middle.py b/test/components/rankers/test_lost_in_the_middle.py index 69618f6bde..86e30d21e1 100644 --- a/test/components/rankers/test_lost_in_the_middle.py +++ b/test/components/rankers/test_lost_in_the_middle.py @@ -2,12 +2,20 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Any + import pytest -from haystack import Document +from haystack import Document, Pipeline +from haystack.components.rankers import lost_in_the_middle as lost_in_the_middle_module from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker +class FakeEncoding: + def encode(self, text: str, **kwargs: Any) -> list[str]: + return text.split("|") + + class TestLostInTheMiddleRanker: def test_lost_in_the_middle_order_odd(self): # tests that lost_in_the_middle order works with an odd number of documents @@ -39,10 +47,17 @@ def test_lost_in_the_middle_init(self): # tests that LostInTheMiddleRanker initializes with default values ranker = 
LostInTheMiddleRanker() assert ranker.word_count_threshold is None + assert ranker.count_mode == "word" + assert ranker.tokenizer_encoding == "o200k_base" ranker = LostInTheMiddleRanker(word_count_threshold=10) assert ranker.word_count_threshold == 10 + def test_lost_in_the_middle_init_invalid_count_mode(self): + invalid_count_mode: Any = "byte" + with pytest.raises(ValueError, match="Invalid value for count_mode"): + LostInTheMiddleRanker(count_mode=invalid_count_mode) + def test_lost_in_the_middle_init_invalid_word_count_threshold(self): # tests that LostInTheMiddleRanker raises an error when word_count_threshold is <= 0 with pytest.raises(ValueError, match="Invalid value for word_count_threshold"): @@ -74,6 +89,45 @@ def test_word_count_threshold_greater_than_total_number_of_words_returns_all_doc expected_order = ["word1", "word3", "word5", "word7", "word9", "word8", "word6", "word4", "word2"] assert all(doc.content == expected_order[idx] for idx, doc in enumerate(ordered_docs["documents"])) + def test_lost_in_the_middle_with_char_count_mode(self): + ranker = LostInTheMiddleRanker(count_mode="char") + docs = [Document(content="aa"), Document(content="bbb"), Document(content="cccc"), Document(content="ddddd")] + + result = ranker.run(documents=docs, word_count_threshold=5) + + assert [doc.content for doc in result["documents"]] == ["aa", "bbb"] + + def test_lost_in_the_middle_with_token_count_mode(self, monkeypatch): + encoding_names = [] + + def fake_get_encoding(name: str) -> FakeEncoding: + encoding_names.append(name) + return FakeEncoding() + + monkeypatch.setattr(lost_in_the_middle_module.tiktoken, "get_encoding", fake_get_encoding) + ranker = LostInTheMiddleRanker(word_count_threshold=4, count_mode="token", tokenizer_encoding="test-encoding") + docs = [Document(content="a|b|c"), Document(content="d"), Document(content="e"), Document(content="f")] + + result = ranker.run(documents=docs) + + assert [doc.content for doc in result["documents"]] == ["a|b|c", 
"d"] + assert encoding_names == ["test-encoding"] + + def test_pipeline_serialization_with_count_mode(self) -> None: + pipeline = Pipeline() + pipeline.add_component( + "ranker", + LostInTheMiddleRanker(word_count_threshold=4, count_mode="token", tokenizer_encoding="test-encoding"), + ) + + restored_pipeline = Pipeline.loads(pipeline.dumps()) + restored_ranker = restored_pipeline.get_component("ranker") + + assert isinstance(restored_ranker, LostInTheMiddleRanker) + assert restored_ranker.word_count_threshold == 4 + assert restored_ranker.count_mode == "token" + assert restored_ranker.tokenizer_encoding == "test-encoding" + def test_empty_documents_returns_empty_list(self): ranker = LostInTheMiddleRanker() result = ranker.run(documents=[]) @@ -98,7 +152,7 @@ def test_run_deduplicates_documents(self): assert result["documents"][1].content == "unique" @pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8, 12, 20]) - def test_lost_in_the_middle_order_with_top_k(self, top_k: int): + def test_lost_in_the_middle_order_with_top_k(self, top_k: int) -> None: # tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter docs = [Document(content=str(i)) for i in range(1, 10)] ranker = LostInTheMiddleRanker()