From 8a76082ca7455b7fed0fda5173c8a50d862b6a99 Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Wed, 20 May 2026 14:20:30 +0200 Subject: [PATCH 1/4] fix: Align start, end, and probability lists in ExtractiveReader --- haystack/components/readers/extractive.py | 11 ++-- ...start-end-prob-lists-883d318dacd44fc6.yaml | 8 +++ test/components/readers/test_extractive.py | 59 +++++++++++++++++++ 3 files changed, 74 insertions(+), 4 deletions(-) create mode 100644 releasenotes/notes/fix-extractivereader-unaligned-start-end-prob-lists-883d318dacd44fc6.yaml diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py index f11accb1aa..2afcc91e79 100644 --- a/haystack/components/readers/extractive.py +++ b/haystack/components/readers/extractive.py @@ -265,7 +265,7 @@ def _postprocess( attention_mask: "torch.Tensor", answers_per_seq: int, encodings: list["Encoding"], - ) -> tuple[list[list[int]], list[list[int]], "torch.Tensor"]: + ) -> tuple[list[list[int]], list[list[int]], list["torch.Tensor"]]: """ Turns start and end logits into probabilities for each answer span. @@ -302,10 +302,12 @@ def _postprocess( start_candidates_tokens_to_chars = [] end_candidates_tokens_to_chars = [] + valid_candidates_values: list[torch.Tensor] = [] for i, (s_candidates, e_candidates, encoding) in enumerate( zip(start_candidates, end_candidates, encodings, strict=True) ): - # Those with probabilities > 0 are valid + # Those with probabilities > 0 are valid. topk may include masked candidates + # when answers_per_seq exceeds the number of valid spans, so filter all three lists together. valid = candidates_values[i] > 0 s_char_spans = [] e_char_spans = [] @@ -318,8 +320,9 @@ def _postprocess( e_char_spans.append(encoding.token_to_chars(end_token)[1]) start_candidates_tokens_to_chars.append(s_char_spans) end_candidates_tokens_to_chars.append(e_char_spans) + valid_candidates_values.append(candidates_values[i][valid]) - return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, candidates_values + return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, valid_candidates_values def _add_answer_page_number(self, answer: ExtractedAnswer) -> ExtractedAnswer: if answer.meta is None: @@ -351,7 +354,7 @@ def _nest_answers( *, start: list[list[int]], end: list[list[int]], - probabilities: "torch.Tensor", + probabilities: list["torch.Tensor"], flattened_documents: list[Document], queries: list[str], answers_per_seq: int, diff --git a/releasenotes/notes/fix-extractivereader-unaligned-start-end-prob-lists-883d318dacd44fc6.yaml b/releasenotes/notes/fix-extractivereader-unaligned-start-end-prob-lists-883d318dacd44fc6.yaml new file mode 100644 index 0000000000..8bc58a7257 --- /dev/null +++ b/releasenotes/notes/fix-extractivereader-unaligned-start-end-prob-lists-883d318dacd44fc6.yaml @@ -0,0 +1,8 @@ +--- +fixes: + - | + Fixed ``ExtractiveReader`` raising ``ValueError``` when the number of valid answer spans for a sequence was smaller + than ``answers_per_seq`` (for example with short documents or when ``answers_per_seq`` exceeded the number of + upper-triangular, non-masked (start, end) token pairs). ``_postprocess`` now filters the per-sequence probabilities + by the same validity mask it already applied to the start/end token indices, so the three structures + always have matching lengths. diff --git a/test/components/readers/test_extractive.py b/test/components/readers/test_extractive.py index 7a8a311ab7..e741d6e77e 100644 --- a/test/components/readers/test_extractive.py +++ b/test/components/readers/test_extractive.py @@ -436,6 +436,65 @@ def test_nest_answers(mock_reader: ExtractiveReader): assert no_answer.score == pytest.approx(expected_no_answer) +def test_postprocess_filters_probs_when_answers_per_seq_exceeds_valid_spans(mock_reader: ExtractiveReader): + # Setup: seq_length=4, attention-masked positions 0-1, so only context positions 2-3 + # remain. The number of valid (start <= end) spans is therefore 3: (2,2), (2,3), (3,3). + # We ask for 5 answers per sequence — the 2 extra topk picks fall in the masked region + # and become probability 0 after sigmoid. + start = torch.zeros((2, 4)) + end = torch.zeros((2, 4)) + sequence_ids = torch.ones((2, 4)) + attention_mask = torch.ones((2, 4)) + attention_mask[:, :2] = 0 + encoding = Mock() + encoding.token_to_chars = lambda i: (int(i), int(i) + 1) + + start_candidates, end_candidates, probs = mock_reader._postprocess( + start=start, + end=end, + sequence_ids=sequence_ids, + attention_mask=attention_mask, + answers_per_seq=5, + encodings=[encoding, encoding], + ) + + # Per-sequence lengths must all agree so the downstream strict=True zips don't blow up. + assert len(start_candidates) == len(end_candidates) == len(probs) == 2 + for i in range(2): + assert len(start_candidates[i]) == len(end_candidates[i]) == len(probs[i]) == 3 + # All retained probabilities are valid (> 0); masked candidates were dropped. + assert torch.all(probs[i] > 0) + + +def test_nest_answers_accepts_variable_length_probability_rows(mock_reader: ExtractiveReader): + # Mirrors the new contract from _postprocess: `probabilities` is a list of 1D tensors + # whose lengths match the per-sequence start/end lists. + start = [[0, 1], [0]] + end = [[5, 6], [5]] + probabilities = [torch.tensor([0.8, 0.6]), torch.tensor([0.7])] + query_ids = [0, 0] + document_ids = [0, 1] + + nested_answers = mock_reader._nest_answers( + start=start, + end=end, + probabilities=probabilities, + flattened_documents=example_documents[0][:2], + queries=example_queries[:1], + answers_per_seq=2, + top_k=10, + score_threshold=None, + query_ids=query_ids, + document_ids=document_ids, + no_answer=False, + overlap_threshold=None, + ) + + assert len(nested_answers) == 1 + scores = sorted((a.score for a in nested_answers[0]), reverse=True) + assert scores == pytest.approx([0.8, 0.7, 0.6]) + + def test_add_answer_page_number_returns_same_answer(mock_reader: ExtractiveReader, caplog): # answer.document_offset is None document = Document(content="I thought a lot about this. The answer is 42.", meta={"page_number": 5}) From e517a7380911c2c228c7c4ad5d4c1d29eeef991e Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Wed, 20 May 2026 14:32:27 +0200 Subject: [PATCH 2/4] fix reno --- ...ereader-unaligned-start-end-prob-lists-883d318dacd44fc6.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/releasenotes/notes/fix-extractivereader-unaligned-start-end-prob-lists-883d318dacd44fc6.yaml b/releasenotes/notes/fix-extractivereader-unaligned-start-end-prob-lists-883d318dacd44fc6.yaml index 8bc58a7257..10d9242b09 100644 --- a/releasenotes/notes/fix-extractivereader-unaligned-start-end-prob-lists-883d318dacd44fc6.yaml +++ b/releasenotes/notes/fix-extractivereader-unaligned-start-end-prob-lists-883d318dacd44fc6.yaml @@ -1,7 +1,7 @@ --- fixes: - | - Fixed ``ExtractiveReader`` raising ``ValueError``` when the number of valid answer spans for a sequence was smaller + Fixed ``ExtractiveReader`` raising ``ValueError`` when the number of valid answer spans for a sequence was smaller than ``answers_per_seq`` (for example with short documents or when ``answers_per_seq`` exceeded the number of upper-triangular, non-masked (start, end) token pairs). ``_postprocess`` now filters the per-sequence probabilities by the same validity mask it already applied to the start/end token indices, so the three structures From 7477b44439747423911812c2bc452c2bf269e0bd Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Wed, 20 May 2026 16:13:22 +0200 Subject: [PATCH 3/4] Add comment --- haystack/components/readers/extractive.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py index 2afcc91e79..56345bdaeb 100644 --- a/haystack/components/readers/extractive.py +++ b/haystack/components/readers/extractive.py @@ -392,6 +392,10 @@ def _nest_answers( nested_answers = [] for query_id in range(query_ids[-1] + 1): current_answers = [] + # `i // answers_per_seq` assumes every sequence contributes exactly `answers_per_seq` + # answers. That's not guaranteed (see _postprocess: invalid candidates are + # filtered out per sequence), but is fine here because `run` always passes a single + # query, so every entry in `query_ids` is 0 and the index lookup is correct for any i. while i < len(answers_without_query) and query_ids[i // answers_per_seq] == query_id: current_answers.append(replace(answers_without_query[i], query=queries[query_id])) i += 1 From 3ba1a9d0d14d44390cbd86d2987ffc4f84cbe61a Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Wed, 20 May 2026 16:21:52 +0200 Subject: [PATCH 4/4] Update test/components/readers/test_extractive.py Co-authored-by: Stefano Fiorucci --- test/components/readers/test_extractive.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/test/components/readers/test_extractive.py b/test/components/readers/test_extractive.py index e741d6e77e..3afc7f439f 100644 --- a/test/components/readers/test_extractive.py +++ b/test/components/readers/test_extractive.py @@ -491,8 +491,16 @@ def test_nest_answers_accepts_variable_length_probability_rows(mock_reader: Extr ) assert len(nested_answers) == 1 - scores = sorted((a.score for a in nested_answers[0]), reverse=True) - assert scores == pytest.approx([0.8, 0.7, 0.6]) + sorted_answers = sorted(nested_answers[0], key=lambda a: a.score, reverse=True) + + docs = example_documents[0][:2] + expected = [ + (pytest.approx(0.8), docs[0].id, "Angel"), + (pytest.approx(0.7), docs[1].id, "Olaf "), + (pytest.approx(0.6), docs[0].id, "ngela"), + ] + actual = [(a.score, a.document.id, a.data) for a in sorted_answers] + assert actual == expected def test_add_answer_page_number_returns_same_answer(mock_reader: ExtractiveReader, caplog):