Skip to content

Commit b48d3e7

Browse files
vedjawcursoragentjulian-rischclaude
authored
feat: add opt-in reference range expansion to AnswerBuilder (#11623)
Signed-off-by: vedjaw <vedant.jawandhia@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Julian Risch <julian.risch@deepset.ai> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent b552024 commit b48d3e7

5 files changed

Lines changed: 117 additions & 4 deletions

File tree

haystack/components/builders/answer_builder.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
logger = logging.getLogger(__name__)
1313

14+
DEFAULT_REFERENCE_PATTERN = r"\[(\d+)\]"
15+
EXPANDED_REFERENCE_PATTERN = r"\[(\d+(?:[,-]\d+)*)\]"
16+
1417

1518
@component
1619
class AnswerBuilder:
@@ -74,6 +77,7 @@ def __init__(
7477
last_message_only: bool = False,
7578
*,
7679
return_only_referenced_documents: bool = True,
80+
expand_reference_ranges: bool = False,
7781
) -> None:
7882
"""
7983
Creates an instance of the AnswerBuilder component.
@@ -104,6 +108,10 @@ def __init__(
104108
If True (default value), only the documents that were actually referenced in `replies` are returned.
105109
If False, all documents are returned.
106110
If `reference_pattern` is not provided, this parameter has no effect, and all documents are returned.
111+
:param expand_reference_ranges:
112+
If True, reference ranges like `[6-10]` are expanded to documents 6 through 10.
113+
Defaults to False for backwards compatibility.
114+
When enabled with the default `reference_pattern`, a broader pattern is used automatically.
107115
"""
108116
if pattern:
109117
AnswerBuilder._check_num_groups_in_regex(pattern)
@@ -112,6 +120,7 @@ def __init__(
112120
self.reference_pattern = reference_pattern
113121
self.last_message_only = last_message_only
114122
self.return_only_referenced_documents = return_only_referenced_documents
123+
self.expand_reference_ranges = expand_reference_ranges
115124

116125
@component.output_types(answers=list[GeneratedAnswer])
117126
def run(
@@ -122,6 +131,7 @@ def run(
122131
documents: list[Document] | None = None,
123132
pattern: str | None = None,
124133
reference_pattern: str | None = None,
134+
expand_reference_ranges: bool | None = None,
125135
) -> dict[str, Any]:
126136
"""
127137
Turns the output of a Generator into `GeneratedAnswer` objects using regular expressions.
@@ -158,6 +168,9 @@ def run(
158168
If not specified, no parsing is done, and all documents are returned.
159169
References need to be specified as indices of the input documents and start at [1].
160170
Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]".
171+
:param expand_reference_ranges:
172+
If True, reference ranges like `[6-10]` are expanded to documents 6 through 10.
173+
If not specified, the value from the component initialization is used.
161174
162175
:returns: A dictionary with the following keys:
163176
- `answers`: The answers received from the output of the Generator.
@@ -172,6 +185,12 @@ def run(
172185

173186
pattern = pattern or self.pattern
174187
reference_pattern = reference_pattern or self.reference_pattern
188+
expand_reference_ranges = (
189+
self.expand_reference_ranges if expand_reference_ranges is None else expand_reference_ranges
190+
)
191+
reference_pattern = AnswerBuilder._resolve_reference_pattern(
192+
reference_pattern=reference_pattern, expand_reference_ranges=expand_reference_ranges
193+
)
175194

176195
replies_to_iterate = replies[-1:] if self.last_message_only and replies else replies
177196
meta_to_iterate = meta[-1:] if self.last_message_only and meta else meta
@@ -188,7 +207,12 @@ def run(
188207
referenced_docs = []
189208
if documents:
190209
referenced_idxs = (
191-
AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern)
210+
AnswerBuilder._extract_reference_idxs(
211+
extracted_reply,
212+
reference_pattern,
213+
expand_ranges=expand_reference_ranges,
214+
num_documents=len(documents),
215+
)
192216
if reference_pattern
193217
else set()
194218
)
@@ -245,9 +269,40 @@ def _extract_answer_string(reply: str, pattern: str | None = None) -> str:
245269
return ""
246270

247271
@staticmethod
248-
def _extract_reference_idxs(reply: str, reference_pattern: str) -> set[int]:
249-
document_idxs = re.findall(reference_pattern, reply)
250-
return {int(idx) - 1 for idx in document_idxs}
272+
def _resolve_reference_pattern(reference_pattern: str | None, expand_reference_ranges: bool) -> str | None:
    """
    Picks the reference pattern to use for the current run.

    :param reference_pattern:
        The configured reference pattern, or None/empty when no reference parsing is requested.
    :param expand_reference_ranges:
        Whether range expansion is enabled.
    :returns:
        `EXPANDED_REFERENCE_PATTERN` when range expansion is enabled and the configured pattern
        equals `DEFAULT_REFERENCE_PATTERN`; otherwise the configured pattern unchanged.
    """
    # Only the stock pattern is swapped out; a custom (or missing) pattern is always passed through.
    swap_for_expanded = expand_reference_ranges and reference_pattern == DEFAULT_REFERENCE_PATTERN
    return EXPANDED_REFERENCE_PATTERN if swap_for_expanded else reference_pattern
278+
279+
@staticmethod
280+
def _extract_reference_idxs(
281+
reply: str, reference_pattern: str, expand_ranges: bool = False, num_documents: int | None = None
282+
) -> set[int]:
283+
matches = re.findall(reference_pattern, reply)
284+
idxs: set[int] = set()
285+
for match in matches:
286+
if expand_ranges:
287+
for part in match.split(","):
288+
part = part.strip()
289+
if not part:
290+
continue
291+
if "-" in part:
292+
start_str, end_str = part.split("-", 1)
293+
start, end = int(start_str), int(end_str)
294+
if start > end:
295+
continue
296+
# Clamp the range end to the number of documents to avoid materializing a huge
297+
# set from an out-of-range citation like `[1-999999999]` in the Generator output.
298+
if num_documents is not None:
299+
end = min(end, num_documents)
300+
idxs.update(range(start - 1, end))
301+
else:
302+
idxs.add(int(part) - 1)
303+
else:
304+
idxs.add(int(match) - 1)
305+
return idxs
251306

252307
@staticmethod
253308
def _check_num_groups_in_regex(pattern: str) -> None:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
enhancements:
3+
- |
4+
Added an opt-in ``expand_reference_ranges`` parameter to ``AnswerBuilder``.
5+
When enabled, reference ranges like ``[6-10]`` and comma-separated ranges like ``[1-3,7-9]`` are expanded to the corresponding document indices in RAG answers.
6+
The feature is disabled by default to preserve existing parsing behavior.

test/components/builders/test_answer_builder.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,3 +436,46 @@ def test_run_does_not_mutate_document_with_empty_meta(self):
436436
component = AnswerBuilder()
437437
component.run(query="Capital of France?", replies=["Paris."], documents=[doc])
438438
assert doc.meta == {}
439+
440+
def test_run_expands_reference_ranges_when_enabled(self):
    """A `[6-10]` citation marks documents 6 through 10 as referenced when expansion is enabled."""
    documents = [Document(content=f"doc {number}") for number in range(1, 11)]
    builder = AnswerBuilder(
        reference_pattern=r"\[(\d+)\]",
        expand_reference_ranges=True,
        return_only_referenced_documents=False,
    )
    result = builder.run(query="test query", replies=["Answer citing sources [6-10]."], documents=documents)
    first_answer = result["answers"][0]
    # All documents are returned here, so filter on the `referenced` flag before comparing.
    cited_positions = [doc.meta["source_index"] for doc in first_answer.documents if doc.meta["referenced"]]
    assert cited_positions == [6, 7, 8, 9, 10]
449+
450+
def test_run_expands_comma_separated_reference_ranges(self):
    """A `[1-3,7-9]` citation marks documents 1-3 and 7-9 as referenced when expansion is enabled."""
    documents = [Document(content=f"doc {number}") for number in range(1, 11)]
    builder = AnswerBuilder(
        reference_pattern=r"\[(\d+)\]",
        expand_reference_ranges=True,
        return_only_referenced_documents=False,
    )
    result = builder.run(query="test query", replies=["Answer citing sources [1-3,7-9]."], documents=documents)
    first_answer = result["answers"][0]
    # All documents are returned here, so filter on the `referenced` flag before comparing.
    cited_positions = [doc.meta["source_index"] for doc in first_answer.documents if doc.meta["referenced"]]
    assert cited_positions == [1, 2, 3, 7, 8, 9]
459+
460+
def test_run_ignores_invalid_reference_ranges(self):
    """A reversed range such as `[3-1]` yields no referenced documents."""
    documents = [Document(content=f"doc {number}") for number in range(1, 5)]
    builder = AnswerBuilder(
        reference_pattern=r"\[(\d+)\]",
        expand_reference_ranges=True,
        return_only_referenced_documents=True,
    )
    result = builder.run(query="test query", replies=["Answer citing sources [3-1]."], documents=documents)
    first_answer = result["answers"][0]
    assert first_answer.documents == []
467+
468+
def test_run_clamps_reference_range_to_number_of_documents(self):
    """A `[1-100]` citation over three documents returns exactly those three documents."""
    documents = [Document(content=f"doc {number}") for number in range(1, 4)]
    builder = AnswerBuilder(
        reference_pattern=r"\[(\d+)\]",
        expand_reference_ranges=True,
        return_only_referenced_documents=True,
    )
    result = builder.run(query="test query", replies=["Answer citing sources [1-100]."], documents=documents)
    returned_positions = [doc.meta["source_index"] for doc in result["answers"][0].documents]
    assert returned_positions == [1, 2, 3]
476+
477+
def test_run_does_not_expand_reference_ranges_by_default(self):
    """Without opting in to range expansion, a `[6-10]` citation yields no referenced documents."""
    documents = [Document(content=f"doc {number}") for number in range(1, 11)]
    builder = AnswerBuilder(reference_pattern=r"\[(\d+)\]", return_only_referenced_documents=True)
    result = builder.run(query="test query", replies=["Answer citing sources [6-10]."], documents=documents)
    first_answer = result["answers"][0]
    assert first_answer.documents == []

test/core/pipeline/features/test_run.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,7 @@ def pipeline_that_has_a_component_with_only_default_inputs(pipeline_class):
897897
"meta": None,
898898
"pattern": None,
899899
"query": "What is the capital of France?",
900+
"expand_reference_ranges": None,
900901
"reference_pattern": None,
901902
"replies": ["Paris"],
902903
},
@@ -2382,6 +2383,7 @@ def that_has_an_answer_joiner_variadic_component(pipeline_class):
23822383
"meta": None,
23832384
"pattern": None,
23842385
"query": query,
2386+
"expand_reference_ranges": None,
23852387
"reference_pattern": None,
23862388
"replies": [reply1],
23872389
},
@@ -2390,6 +2392,7 @@ def that_has_an_answer_joiner_variadic_component(pipeline_class):
23902392
"meta": None,
23912393
"pattern": None,
23922394
"query": query,
2395+
"expand_reference_ranges": None,
23932396
"reference_pattern": None,
23942397
"replies": [reply2],
23952398
},
@@ -2634,6 +2637,7 @@ def run(self, prompt: str) -> dict[str, str]:
26342637
"meta": None,
26352638
"pattern": None,
26362639
"query": "What is the answer?",
2640+
"expand_reference_ranges": None,
26372641
"reference_pattern": None,
26382642
"replies": ["42"],
26392643
},
@@ -3028,6 +3032,7 @@ def run(self, query: str) -> dict[str, list[Document]]:
30283032
"meta": None,
30293033
"pattern": None,
30303034
"query": "Does this run reliably?",
3035+
"expand_reference_ranges": None,
30313036
"reference_pattern": None,
30323037
"replies": ["answer: here is my answer"],
30333038
},
@@ -3193,6 +3198,7 @@ def has_feedback_loop(pipeline_class):
31933198
"meta": None,
31943199
"pattern": None,
31953200
"query": "Generate code to generate christmas ascii-art",
3201+
"expand_reference_ranges": None,
31963202
"reference_pattern": None,
31973203
"replies": ["valid code"],
31983204
},
@@ -3335,6 +3341,7 @@ def has_non_standard_order_loop(pipeline_class):
33353341
"meta": None,
33363342
"pattern": None,
33373343
"query": "Generate code to generate christmas ascii-art",
3344+
"expand_reference_ranges": None,
33383345
"reference_pattern": None,
33393346
"replies": ["valid code"],
33403347
},

test/core/super_component/test_super_component.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def test_auto_input_mapping(self, rag_pipeline):
193193
input_sockets = wrapper.__haystack_input__._sockets_dict # type: ignore[attr-defined]
194194
assert set(input_sockets.keys()) == {
195195
"documents",
196+
"expand_reference_ranges",
196197
"filters",
197198
"meta",
198199
"pattern",
@@ -219,6 +220,7 @@ def test_auto_mapping_sockets(self, rag_pipeline):
219220
input_sockets = wrapper.__haystack_input__._sockets_dict # type: ignore[attr-defined]
220221
assert set(input_sockets.keys()) == {
221222
"documents",
223+
"expand_reference_ranges",
222224
"filters",
223225
"meta",
224226
"pattern",

0 commit comments

Comments
 (0)