Skip to content

Commit 435daf3

Browse files
fix: Align start, end, and probability lists in ExtractiveReader (#11347)
Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
1 parent 4a5727b commit 435daf3

3 files changed

Lines changed: 86 additions & 4 deletions

File tree

haystack/components/readers/extractive.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def _postprocess(
265265
attention_mask: "torch.Tensor",
266266
answers_per_seq: int,
267267
encodings: list["Encoding"],
268-
) -> tuple[list[list[int]], list[list[int]], "torch.Tensor"]:
268+
) -> tuple[list[list[int]], list[list[int]], list["torch.Tensor"]]:
269269
"""
270270
Turns start and end logits into probabilities for each answer span.
271271
@@ -302,10 +302,12 @@ def _postprocess(
302302

303303
start_candidates_tokens_to_chars = []
304304
end_candidates_tokens_to_chars = []
305+
valid_candidates_values: list[torch.Tensor] = []
305306
for i, (s_candidates, e_candidates, encoding) in enumerate(
306307
zip(start_candidates, end_candidates, encodings, strict=True)
307308
):
308-
# Those with probabilities > 0 are valid
309+
# Those with probabilities > 0 are valid. topk may include masked candidates
310+
# when answers_per_seq exceeds the number of valid spans, so filter all three lists together.
309311
valid = candidates_values[i] > 0
310312
s_char_spans = []
311313
e_char_spans = []
@@ -318,8 +320,9 @@ def _postprocess(
318320
e_char_spans.append(encoding.token_to_chars(end_token)[1])
319321
start_candidates_tokens_to_chars.append(s_char_spans)
320322
end_candidates_tokens_to_chars.append(e_char_spans)
323+
valid_candidates_values.append(candidates_values[i][valid])
321324

322-
return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, candidates_values
325+
return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, valid_candidates_values
323326

324327
def _add_answer_page_number(self, answer: ExtractedAnswer) -> ExtractedAnswer:
325328
if answer.meta is None:
@@ -351,7 +354,7 @@ def _nest_answers(
351354
*,
352355
start: list[list[int]],
353356
end: list[list[int]],
354-
probabilities: "torch.Tensor",
357+
probabilities: list["torch.Tensor"],
355358
flattened_documents: list[Document],
356359
queries: list[str],
357360
answers_per_seq: int,
@@ -389,6 +392,10 @@ def _nest_answers(
389392
nested_answers = []
390393
for query_id in range(query_ids[-1] + 1):
391394
current_answers = []
395+
# `i // answers_per_seq` assumes every sequence contributes exactly `answers_per_seq`
396+
# answers. That's not guaranteed (see _postprocess: invalid candidates are
397+
# filtered out per sequence), but is fine here because `run` always passes a single
398+
# query, so every entry in `query_ids` is 0 and the index lookup is correct for any i.
392399
while i < len(answers_without_query) and query_ids[i // answers_per_seq] == query_id:
393400
current_answers.append(replace(answers_without_query[i], query=queries[query_id]))
394401
i += 1
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed ``ExtractiveReader`` raising ``ValueError`` when the number of valid answer spans for a sequence was smaller
5+
than ``answers_per_seq`` (for example with short documents or when ``answers_per_seq`` exceeded the number of
6+
upper-triangular, non-masked (start, end) token pairs). ``_postprocess`` now filters the per-sequence probabilities
7+
by the same validity mask it already applied to the start/end token indices, so the three structures
8+
always have matching lengths.

test/components/readers/test_extractive.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,73 @@ def test_nest_answers(mock_reader: ExtractiveReader):
436436
assert no_answer.score == pytest.approx(expected_no_answer)
437437

438438

439+
def test_postprocess_filters_probs_when_answers_per_seq_exceeds_valid_spans(mock_reader: ExtractiveReader):
    # Scenario: two sequences of four tokens each, with tokens 0-1 attention-masked, so only
    # tokens 2-3 can be part of an answer. That leaves three spans with start <= end:
    # (2,2), (2,3) and (3,3). Requesting five answers per sequence makes topk return two
    # extra candidates from the masked region, whose probability is 0 after the sigmoid.
    n_sequences, n_tokens, n_valid_spans = 2, 4, 3
    shape = (n_sequences, n_tokens)

    mask = torch.ones(shape)
    mask[:, :2] = 0

    def token_to_chars(token):
        # Each token covers exactly one character: token k -> chars [k, k + 1).
        position = int(token)
        return (position, position + 1)

    fake_encoding = Mock()
    fake_encoding.token_to_chars = token_to_chars

    starts, ends, probabilities = mock_reader._postprocess(
        start=torch.zeros(shape),
        end=torch.zeros(shape),
        sequence_ids=torch.ones(shape),
        attention_mask=mask,
        answers_per_seq=5,
        encodings=[fake_encoding] * n_sequences,
    )

    # The three structures must line up sequence by sequence, otherwise the
    # strict=True zips downstream raise.
    assert len(starts) == len(ends) == len(probabilities) == n_sequences
    for seq_starts, seq_ends, seq_probs in zip(starts, ends, probabilities, strict=True):
        assert len(seq_starts) == len(seq_ends) == len(seq_probs) == n_valid_spans
        # Only valid (> 0) probabilities survive; the masked candidates are gone.
        assert torch.all(seq_probs > 0)
467+
468+
469+
def test_nest_answers_accepts_variable_length_probability_rows(mock_reader: ExtractiveReader):
    # `_postprocess` hands over `probabilities` as a list of 1D tensors, one per sequence,
    # each exactly as long as that sequence's start/end lists. Here the first document
    # carries two candidate answers and the second document only one.
    documents = example_documents[0][:2]

    nested_answers = mock_reader._nest_answers(
        start=[[0, 1], [0]],
        end=[[5, 6], [5]],
        probabilities=[torch.tensor([0.8, 0.6]), torch.tensor([0.7])],
        flattened_documents=documents,
        queries=example_queries[:1],
        answers_per_seq=2,
        top_k=10,
        score_threshold=None,
        query_ids=[0, 0],
        document_ids=[0, 1],
        no_answer=False,
        overlap_threshold=None,
    )

    # A single query was passed, so there is a single group of answers.
    assert len(nested_answers) == 1

    # Compare in descending score order so the check does not depend on the returned order.
    ranked = sorted(nested_answers[0], key=lambda answer: answer.score, reverse=True)
    actual = [(answer.score, answer.document.id, answer.data) for answer in ranked]
    expected = [
        (pytest.approx(0.8), documents[0].id, "Angel"),
        (pytest.approx(0.7), documents[1].id, "Olaf "),
        (pytest.approx(0.6), documents[0].id, "ngela"),
    ]
    assert actual == expected
505+
439506
def test_add_answer_page_number_returns_same_answer(mock_reader: ExtractiveReader, caplog):
440507
# answer.document_offset is None
441508
document = Document(content="I thought a lot about this. The answer is 42.", meta={"page_number": 5})

0 commit comments

Comments
 (0)