Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions haystack/components/readers/extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _postprocess(
attention_mask: "torch.Tensor",
answers_per_seq: int,
encodings: list["Encoding"],
) -> tuple[list[list[int]], list[list[int]], "torch.Tensor"]:
) -> tuple[list[list[int]], list[list[int]], list["torch.Tensor"]]:
"""
Turns start and end logits into probabilities for each answer span.

Expand Down Expand Up @@ -302,10 +302,12 @@ def _postprocess(

start_candidates_tokens_to_chars = []
end_candidates_tokens_to_chars = []
valid_candidates_values: list[torch.Tensor] = []
for i, (s_candidates, e_candidates, encoding) in enumerate(
zip(start_candidates, end_candidates, encodings, strict=True)
):
# Those with probabilities > 0 are valid
# Those with probabilities > 0 are valid. topk may include masked candidates
# when answers_per_seq exceeds the number of valid spans, so filter all three lists together.
valid = candidates_values[i] > 0
s_char_spans = []
e_char_spans = []
Expand All @@ -318,8 +320,9 @@ def _postprocess(
e_char_spans.append(encoding.token_to_chars(end_token)[1])
start_candidates_tokens_to_chars.append(s_char_spans)
end_candidates_tokens_to_chars.append(e_char_spans)
valid_candidates_values.append(candidates_values[i][valid])

return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, candidates_values
return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, valid_candidates_values

def _add_answer_page_number(self, answer: ExtractedAnswer) -> ExtractedAnswer:
if answer.meta is None:
Expand Down Expand Up @@ -351,7 +354,7 @@ def _nest_answers(
*,
start: list[list[int]],
end: list[list[int]],
probabilities: "torch.Tensor",
probabilities: list["torch.Tensor"],
flattened_documents: list[Document],
queries: list[str],
answers_per_seq: int,
Expand Down Expand Up @@ -389,6 +392,10 @@ def _nest_answers(
nested_answers = []
for query_id in range(query_ids[-1] + 1):
current_answers = []
# `i // answers_per_seq` assumes every sequence contributes exactly `answers_per_seq`
# answers. That's not guaranteed (see _postprocess: invalid candidates are
# filtered out per sequence), but is fine here because `run` always passes a single
# query, so every entry in `query_ids` is 0 and the index lookup is correct for any i.
while i < len(answers_without_query) and query_ids[i // answers_per_seq] == query_id:
current_answers.append(replace(answers_without_query[i], query=queries[query_id]))
i += 1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
fixes:
- |
Fixed ``ExtractiveReader`` raising ``ValueError`` when the number of valid answer spans for a sequence was smaller
than ``answers_per_seq`` (for example with short documents or when ``answers_per_seq`` exceeded the number of
upper-triangular, non-masked (start, end) token pairs). ``_postprocess`` now filters the per-sequence probabilities
by the same validity mask it already applied to the start/end token indices, so the three structures
always have matching lengths.
67 changes: 67 additions & 0 deletions test/components/readers/test_extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,73 @@ def test_nest_answers(mock_reader: ExtractiveReader):
assert no_answer.score == pytest.approx(expected_no_answer)


def test_postprocess_filters_probs_when_answers_per_seq_exceeds_valid_spans(mock_reader: ExtractiveReader):
# Setup: seq_length=4, attention-masked positions 0-1, so only context positions 2-3
# remain. The number of valid (start <= end) spans is therefore 3: (2,2), (2,3), (3,3).
# We ask for 5 answers per sequence — the 2 extra topk picks fall in the masked region
# and become probability 0 after sigmoid.
start = torch.zeros((2, 4))
end = torch.zeros((2, 4))
sequence_ids = torch.ones((2, 4))
attention_mask = torch.ones((2, 4))
attention_mask[:, :2] = 0
encoding = Mock()
encoding.token_to_chars = lambda i: (int(i), int(i) + 1)

start_candidates, end_candidates, probs = mock_reader._postprocess(
start=start,
end=end,
sequence_ids=sequence_ids,
attention_mask=attention_mask,
answers_per_seq=5,
encodings=[encoding, encoding],
)

# Per-sequence lengths must all agree so the downstream strict=True zips don't blow up.
assert len(start_candidates) == len(end_candidates) == len(probs) == 2
for i in range(2):
assert len(start_candidates[i]) == len(end_candidates[i]) == len(probs[i]) == 3
# All retained probabilities are valid (> 0); masked candidates were dropped.
assert torch.all(probs[i] > 0)


def test_nest_answers_accepts_variable_length_probability_rows(mock_reader: ExtractiveReader):
# Mirrors the new contract from _postprocess: `probabilities` is a list of 1D tensors
# whose lengths match the per-sequence start/end lists.
start = [[0, 1], [0]]
end = [[5, 6], [5]]
probabilities = [torch.tensor([0.8, 0.6]), torch.tensor([0.7])]
query_ids = [0, 0]
document_ids = [0, 1]

nested_answers = mock_reader._nest_answers(
start=start,
end=end,
probabilities=probabilities,
flattened_documents=example_documents[0][:2],
queries=example_queries[:1],
answers_per_seq=2,
top_k=10,
score_threshold=None,
query_ids=query_ids,
document_ids=document_ids,
no_answer=False,
overlap_threshold=None,
)

assert len(nested_answers) == 1
sorted_answers = sorted(nested_answers[0], key=lambda a: a.score, reverse=True)

docs = example_documents[0][:2]
expected = [
(pytest.approx(0.8), docs[0].id, "Angel"),
(pytest.approx(0.7), docs[1].id, "Olaf "),
(pytest.approx(0.6), docs[0].id, "ngela"),
]
actual = [(a.score, a.document.id, a.data) for a in sorted_answers]
assert actual == expected


def test_add_answer_page_number_returns_same_answer(mock_reader: ExtractiveReader, caplog):
# answer.document_offset is None
document = Document(content="I thought a lot about this. The answer is 42.", meta={"page_number": 5})
Expand Down
Loading