Skip to content

Commit ae6f3bc

Browse files
fix: fix inconsistent top_k validation in SentenceTransformersDiversityRanker (#9698)
* Fix inconsistent top_k validation in SentenceTransformersDiversityRanker - Change elif to if in the run() method to ensure top_k validation always runs regardless of whether top_k comes from init or runtime - Both scenarios now consistently raise ValueError with descriptive message format: 'top_k must be between 1 and X, but got Y' - Fixes inconsistency where init top_k gave a confusing MMR error while runtime top_k gave a clear validation error * improvements --------- Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
1 parent 9ce48f9 commit ae6f3bc

3 files changed

Lines changed: 11 additions & 7 deletions

File tree

haystack/components/rankers/sentence_transformers_diversity.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def _maximum_margin_relevance(
353353

354354
texts_to_embed = self._prepare_texts_to_embed(documents)
355355
doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)
356-
top_k = top_k if top_k else len(documents)
356+
top_k = min(top_k, len(documents))
357357

358358
selected: list[int] = []
359359
query_similarities_as_tensor = query_embedding @ doc_embeddings.T
@@ -375,9 +375,8 @@ def _maximum_margin_relevance(
375375
if mmr_score > best_score:
376376
best_score = mmr_score
377377
best_idx = idx
378-
if best_idx is None:
379-
raise ValueError("No best document found, check if the documents list contains any documents.")
380-
selected.append(best_idx)
378+
# loop condition ensures unselected docs exist with valid scores
379+
selected.append(best_idx) # type: ignore[arg-type]
381380

382381
return [documents[i] for i in selected]
383382

@@ -421,8 +420,8 @@ def run(
421420

422421
if top_k is None:
423422
top_k = self.top_k
424-
elif not 0 < top_k <= len(documents):
425-
raise ValueError(f"top_k must be between 1 and {len(documents)}, but got {top_k}")
423+
if top_k <= 0:
424+
raise ValueError(f"top_k must be > 0, but got {top_k}")
426425

427426
if self.strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE:
428427
if lambda_threshold is None:
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
fixes:
3+
- |
4+
Ensure consistent behavior in `SentenceTransformersDiversityRanker`. Like other rankers, it now returns
5+
all documents instead of raising an error when `top_k` exceeds the number of available documents.

test/components/rankers/test_sentence_transformers_diversity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def test_run_negative_top_k(self, similarity):
381381
query = "test"
382382
documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
383383

384-
with pytest.raises(ValueError, match="top_k must be between"):
384+
with pytest.raises(ValueError, match="top_k must be > 0"):
385385
ranker.run(query=query, documents=documents, top_k=-5)
386386

387387
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])

0 commit comments

Comments
 (0)