Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions haystack/components/builders/answer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,26 @@ def _extract_answer_string(reply: str, pattern: str | None = None) -> str:
return match.group(1)
return ""

@staticmethod
def _expand_range_references(reply: str) -> list[int]:
"""
Expand range references like [1-3] into individual 1-based indices [1, 2, 3].

:param reply: The Generator output string.
:returns: List of 1-based indices expanded from range references.
"""
expanded = []
for match in re.finditer(r"\[(\d+)-(\d+)\]", reply):
start, end = int(match.group(1)), int(match.group(2))
if start <= end:
expanded.extend(range(start, end + 1))
return expanded

@staticmethod
def _extract_reference_idxs(reply: str, reference_pattern: str) -> set[int]:
document_idxs = re.findall(reference_pattern, reply)
return {int(idx) - 1 for idx in document_idxs}
individual = {int(idx) - 1 for idx in re.findall(reference_pattern, reply)}
range_refs = {idx - 1 for idx in AnswerBuilder._expand_range_references(reply)}
return individual | range_refs

@staticmethod
def _check_num_groups_in_regex(pattern: str) -> None:
Expand Down
5 changes: 5 additions & 0 deletions releasenotes/notes/answer-builder-range-references.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
``AnswerBuilder`` now supports range references (e.g. ``[1-3]``) in Generator
outputs, automatically expanding them into individual document references.
20 changes: 20 additions & 0 deletions test/components/builders/test_answer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,23 @@ def test_conversation_history_with_last_message_only_true(self):
assert "all_messages" in answers[0].meta
assert answers[0].meta["all_messages"] == replies
assert len(answers[0].meta["all_messages"]) == 1

def test_run_with_range_reference_pattern(self):
builder = AnswerBuilder(reference_pattern=r"\[(\d+)\]")
replies = ["Paris [1-2] and Rome [3]."]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also add a test with a more complicated reference like [2-5, 7, 10, 14-16] to see if it works as expected

documents = [
Document(content="Paris is the capital of France."),
Document(content="Berlin is the capital of Germany."),
Document(content="Rome is the capital of Italy."),
]
result = builder.run(query="Capitals?", replies=replies, documents=documents)
referenced = [d for d in result["answers"][0].documents if d.meta.get("referenced")]
assert {d.meta["source_index"] for d in referenced} == {1, 2, 3}

def test_run_with_invalid_range_ignored(self):
builder = AnswerBuilder(reference_pattern=r"\[(\d+)\]")
replies = ["Answer [3-1] and [1]."]
documents = [Document(content="Doc 1"), Document(content="Doc 2"), Document(content="Doc 3")]
result = builder.run(query="Q?", replies=replies, documents=documents)
referenced = [d for d in result["answers"][0].documents if d.meta.get("referenced")]
assert {d.meta["source_index"] for d in referenced} == {1}
Loading