Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions haystack/components/builders/answer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

logger = logging.getLogger(__name__)

# Matches a single bracketed numeric citation such as "[3]" and captures its digits.
DEFAULT_REFERENCE_PATTERN = r"\[(\d+)\]"
# Broader variant: captures digits joined by commas and/or dashes inside one pair of brackets,
# e.g. "[6-10]" or "[1-3,7-9]". The captured text still has to be split and parsed by the caller.
EXPANDED_REFERENCE_PATTERN = r"\[(\d+(?:[,-]\d+)*)\]"


@component
class AnswerBuilder:
Expand Down Expand Up @@ -74,6 +77,7 @@ def __init__(
last_message_only: bool = False,
*,
return_only_referenced_documents: bool = True,
expand_reference_ranges: bool = False,
) -> None:
"""
Creates an instance of the AnswerBuilder component.
Expand Down Expand Up @@ -104,6 +108,10 @@ def __init__(
If True (default value), only the documents that were actually referenced in `replies` are returned.
If False, all documents are returned.
If `reference_pattern` is not provided, this parameter has no effect, and all documents are returned.
:param expand_reference_ranges:
If True, reference ranges like `[6-10]` are expanded to documents 6 through 10.
Defaults to False for backwards compatibility.
When enabled with the default `reference_pattern`, a broader pattern is used automatically.
"""
if pattern:
AnswerBuilder._check_num_groups_in_regex(pattern)
Expand All @@ -112,6 +120,7 @@ def __init__(
self.reference_pattern = reference_pattern
self.last_message_only = last_message_only
self.return_only_referenced_documents = return_only_referenced_documents
self.expand_reference_ranges = expand_reference_ranges

@component.output_types(answers=list[GeneratedAnswer])
def run(
Expand All @@ -122,6 +131,7 @@ def run(
documents: list[Document] | None = None,
pattern: str | None = None,
reference_pattern: str | None = None,
expand_reference_ranges: bool | None = None,
) -> dict[str, Any]:
"""
Turns the output of a Generator into `GeneratedAnswer` objects using regular expressions.
Expand Down Expand Up @@ -158,6 +168,9 @@ def run(
If not specified, no parsing is done, and all documents are returned.
References need to be specified as indices of the input documents and start at [1].
Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]".
:param expand_reference_ranges:
If True, reference ranges like `[6-10]` are expanded to documents 6 through 10.
If not specified, the value from the component initialization is used.

:returns: A dictionary with the following keys:
- `answers`: The answers received from the output of the Generator.
Expand All @@ -172,6 +185,12 @@ def run(

pattern = pattern or self.pattern
reference_pattern = reference_pattern or self.reference_pattern
expand_reference_ranges = (
self.expand_reference_ranges if expand_reference_ranges is None else expand_reference_ranges
)
reference_pattern = AnswerBuilder._resolve_reference_pattern(
reference_pattern=reference_pattern, expand_reference_ranges=expand_reference_ranges
)

replies_to_iterate = replies[-1:] if self.last_message_only and replies else replies
meta_to_iterate = meta[-1:] if self.last_message_only and meta else meta
Expand All @@ -188,7 +207,12 @@ def run(
referenced_docs = []
if documents:
referenced_idxs = (
AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern)
AnswerBuilder._extract_reference_idxs(
extracted_reply,
reference_pattern,
expand_ranges=expand_reference_ranges,
num_documents=len(documents),
)
if reference_pattern
else set()
)
Expand Down Expand Up @@ -245,9 +269,40 @@ def _extract_answer_string(reply: str, pattern: str | None = None) -> str:
return ""

@staticmethod
def _extract_reference_idxs(reply: str, reference_pattern: str) -> set[int]:
    """
    Returns the zero-based document indices cited in `reply`.

    Every match of `reference_pattern` is read as a 1-based document number and shifted down by one.
    """
    return {int(raw_number) - 1 for raw_number in re.findall(reference_pattern, reply)}
def _resolve_reference_pattern(reference_pattern: str | None, expand_reference_ranges: bool) -> str | None:
    """
    Picks the reference pattern that should actually be used for parsing.

    The pattern is returned unchanged unless range expansion is enabled and the pattern is exactly
    `DEFAULT_REFERENCE_PATTERN`; in that one case `EXPANDED_REFERENCE_PATTERN` is returned instead.
    Custom patterns and missing/empty patterns are never replaced.
    """
    # Short-circuit order matters: a missing pattern or disabled expansion never reaches the comparison.
    use_expanded_pattern = (
        bool(reference_pattern) and expand_reference_ranges and reference_pattern == DEFAULT_REFERENCE_PATTERN
    )
    return EXPANDED_REFERENCE_PATTERN if use_expanded_pattern else reference_pattern

@staticmethod
def _extract_reference_idxs(
reply: str, reference_pattern: str, expand_ranges: bool = False, num_documents: int | None = None
) -> set[int]:
matches = re.findall(reference_pattern, reply)
idxs: set[int] = set()
for match in matches:
if expand_ranges:
for part in match.split(","):
part = part.strip()
if not part:
continue
if "-" in part:
start_str, end_str = part.split("-", 1)
start, end = int(start_str), int(end_str)
if start > end:
continue
# Clamp the range end to the number of documents to avoid materializing a huge
# set from an out-of-range citation like `[1-999999999]` in the Generator output.
if num_documents is not None:
end = min(end, num_documents)
idxs.update(range(start - 1, end))
else:
idxs.add(int(part) - 1)
else:
idxs.add(int(match) - 1)
return idxs

@staticmethod
def _check_num_groups_in_regex(pattern: str) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
enhancements:
- |
Added an opt-in ``expand_reference_ranges`` parameter to ``AnswerBuilder``.
When enabled, reference ranges like ``[6-10]`` and comma-separated ranges like ``[1-3,7-9]`` are expanded to the corresponding document indices in RAG answers.
The feature is disabled by default to preserve existing parsing behavior.
43 changes: 43 additions & 0 deletions test/components/builders/test_answer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,46 @@ def test_run_does_not_mutate_document_with_empty_meta(self):
component = AnswerBuilder()
component.run(query="Capital of France?", replies=["Paris."], documents=[doc])
assert doc.meta == {}

def test_run_expands_reference_ranges_when_enabled(self):
    """A dash range such as ``[6-10]`` marks every document in that span as referenced."""
    builder = AnswerBuilder(
        reference_pattern=r"\[(\d+)\]",
        expand_reference_ranges=True,
        return_only_referenced_documents=False,
    )
    corpus = [Document(content=f"doc {n}") for n in range(1, 11)]
    result = builder.run(query="test query", replies=["Answer citing sources [6-10]."], documents=corpus)
    first_answer = result["answers"][0]
    cited_source_indices = [
        document.meta["source_index"] for document in first_answer.documents if document.meta["referenced"]
    ]
    assert cited_source_indices == [6, 7, 8, 9, 10]

def test_run_expands_comma_separated_reference_ranges(self):
    """Several ranges inside one bracket pair, e.g. ``[1-3,7-9]``, are all expanded."""
    builder = AnswerBuilder(
        reference_pattern=r"\[(\d+)\]",
        expand_reference_ranges=True,
        return_only_referenced_documents=False,
    )
    corpus = [Document(content=f"doc {n}") for n in range(1, 11)]
    result = builder.run(query="test query", replies=["Answer citing sources [1-3,7-9]."], documents=corpus)
    first_answer = result["answers"][0]
    cited_source_indices = [
        document.meta["source_index"] for document in first_answer.documents if document.meta["referenced"]
    ]
    assert cited_source_indices == [1, 2, 3, 7, 8, 9]

def test_run_ignores_invalid_reference_ranges(self):
    """A reversed range such as ``[3-1]`` references no documents at all."""
    builder = AnswerBuilder(
        reference_pattern=r"\[(\d+)\]",
        expand_reference_ranges=True,
        return_only_referenced_documents=True,
    )
    corpus = [Document(content=f"doc {n}") for n in range(1, 5)]
    result = builder.run(query="test query", replies=["Answer citing sources [3-1]."], documents=corpus)
    first_answer = result["answers"][0]
    assert first_answer.documents == []

def test_run_clamps_reference_range_to_number_of_documents(self):
    """A range running past the last document, e.g. ``[1-100]``, only yields the documents that exist."""
    builder = AnswerBuilder(
        reference_pattern=r"\[(\d+)\]",
        expand_reference_ranges=True,
        return_only_referenced_documents=True,
    )
    corpus = [Document(content=f"doc {n}") for n in range(1, 4)]
    result = builder.run(query="test query", replies=["Answer citing sources [1-100]."], documents=corpus)
    returned_documents = result["answers"][0].documents
    returned_source_indices = [document.meta["source_index"] for document in returned_documents]
    assert returned_source_indices == [1, 2, 3]

def test_run_does_not_expand_reference_ranges_by_default(self):
    """Without opting in, ``[6-10]`` is not treated as a citation and no documents are returned."""
    builder = AnswerBuilder(reference_pattern=r"\[(\d+)\]", return_only_referenced_documents=True)
    corpus = [Document(content=f"doc {n}") for n in range(1, 11)]
    result = builder.run(query="test query", replies=["Answer citing sources [6-10]."], documents=corpus)
    first_answer = result["answers"][0]
    assert first_answer.documents == []
7 changes: 7 additions & 0 deletions test/core/pipeline/features/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ def pipeline_that_has_a_component_with_only_default_inputs(pipeline_class):
"meta": None,
"pattern": None,
"query": "What is the capital of France?",
"expand_reference_ranges": None,
"reference_pattern": None,
"replies": ["Paris"],
},
Expand Down Expand Up @@ -2382,6 +2383,7 @@ def that_has_an_answer_joiner_variadic_component(pipeline_class):
"meta": None,
"pattern": None,
"query": query,
"expand_reference_ranges": None,
"reference_pattern": None,
"replies": [reply1],
},
Expand All @@ -2390,6 +2392,7 @@ def that_has_an_answer_joiner_variadic_component(pipeline_class):
"meta": None,
"pattern": None,
"query": query,
"expand_reference_ranges": None,
"reference_pattern": None,
"replies": [reply2],
},
Expand Down Expand Up @@ -2634,6 +2637,7 @@ def run(self, prompt: str) -> dict[str, str]:
"meta": None,
"pattern": None,
"query": "What is the answer?",
"expand_reference_ranges": None,
"reference_pattern": None,
"replies": ["42"],
},
Expand Down Expand Up @@ -3028,6 +3032,7 @@ def run(self, query: str) -> dict[str, list[Document]]:
"meta": None,
"pattern": None,
"query": "Does this run reliably?",
"expand_reference_ranges": None,
"reference_pattern": None,
"replies": ["answer: here is my answer"],
},
Expand Down Expand Up @@ -3193,6 +3198,7 @@ def has_feedback_loop(pipeline_class):
"meta": None,
"pattern": None,
"query": "Generate code to generate christmas ascii-art",
"expand_reference_ranges": None,
"reference_pattern": None,
"replies": ["valid code"],
},
Expand Down Expand Up @@ -3335,6 +3341,7 @@ def has_non_standard_order_loop(pipeline_class):
"meta": None,
"pattern": None,
"query": "Generate code to generate christmas ascii-art",
"expand_reference_ranges": None,
"reference_pattern": None,
"replies": ["valid code"],
},
Expand Down
2 changes: 2 additions & 0 deletions test/core/super_component/test_super_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def test_auto_input_mapping(self, rag_pipeline):
input_sockets = wrapper.__haystack_input__._sockets_dict # type: ignore[attr-defined]
assert set(input_sockets.keys()) == {
"documents",
"expand_reference_ranges",
"filters",
"meta",
"pattern",
Expand All @@ -219,6 +220,7 @@ def test_auto_mapping_sockets(self, rag_pipeline):
input_sockets = wrapper.__haystack_input__._sockets_dict # type: ignore[attr-defined]
assert set(input_sockets.keys()) == {
"documents",
"expand_reference_ranges",
"filters",
"meta",
"pattern",
Expand Down
Loading