diff --git a/haystack/components/builders/answer_builder.py b/haystack/components/builders/answer_builder.py index 99bb4de9be..78d3ed08d2 100644 --- a/haystack/components/builders/answer_builder.py +++ b/haystack/components/builders/answer_builder.py @@ -11,6 +11,9 @@ logger = logging.getLogger(__name__) +DEFAULT_REFERENCE_PATTERN = r"\[(\d+)\]" +EXPANDED_REFERENCE_PATTERN = r"\[(\d+(?:[,-]\d+)*)\]" + @component class AnswerBuilder: @@ -74,6 +77,7 @@ def __init__( last_message_only: bool = False, *, return_only_referenced_documents: bool = True, + expand_reference_ranges: bool = False, ) -> None: """ Creates an instance of the AnswerBuilder component. @@ -104,6 +108,10 @@ def __init__( If True (default value), only the documents that were actually referenced in `replies` are returned. If False, all documents are returned. If `reference_pattern` is not provided, this parameter has no effect, and all documents are returned. + :param expand_reference_ranges: + If True, reference ranges like `[6-10]` are expanded to documents 6 through 10. + Defaults to False for backwards compatibility. + When enabled with the default `reference_pattern`, a broader pattern is used automatically. """ if pattern: AnswerBuilder._check_num_groups_in_regex(pattern) @@ -112,6 +120,7 @@ def __init__( self.reference_pattern = reference_pattern self.last_message_only = last_message_only self.return_only_referenced_documents = return_only_referenced_documents + self.expand_reference_ranges = expand_reference_ranges @component.output_types(answers=list[GeneratedAnswer]) def run( @@ -122,6 +131,7 @@ def run( documents: list[Document] | None = None, pattern: str | None = None, reference_pattern: str | None = None, + expand_reference_ranges: bool | None = None, ) -> dict[str, Any]: """ Turns the output of a Generator into `GeneratedAnswer` objects using regular expressions. @@ -158,6 +168,9 @@ def run( If not specified, no parsing is done, and all documents are returned. 
References need to be specified as indices of the input documents and start at [1]. Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]". + :param expand_reference_ranges: + If True, reference ranges like `[6-10]` are expanded to documents 6 through 10. + If not specified, the value from the component initialization is used. :returns: A dictionary with the following keys: - `answers`: The answers received from the output of the Generator. @@ -172,6 +185,12 @@ def run( pattern = pattern or self.pattern reference_pattern = reference_pattern or self.reference_pattern + expand_reference_ranges = ( + self.expand_reference_ranges if expand_reference_ranges is None else expand_reference_ranges + ) + reference_pattern = AnswerBuilder._resolve_reference_pattern( + reference_pattern=reference_pattern, expand_reference_ranges=expand_reference_ranges + ) replies_to_iterate = replies[-1:] if self.last_message_only and replies else replies meta_to_iterate = meta[-1:] if self.last_message_only and meta else meta @@ -188,7 +207,12 @@ def run( referenced_docs = [] if documents: referenced_idxs = ( - AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern) + AnswerBuilder._extract_reference_idxs( + extracted_reply, + reference_pattern, + expand_ranges=expand_reference_ranges, + num_documents=len(documents), + ) if reference_pattern else set() ) @@ -245,9 +269,40 @@ def _extract_answer_string(reply: str, pattern: str | None = None) -> str: return "" @staticmethod - def _extract_reference_idxs(reply: str, reference_pattern: str) -> set[int]: - document_idxs = re.findall(reference_pattern, reply) - return {int(idx) - 1 for idx in document_idxs} + def _resolve_reference_pattern(reference_pattern: str | None, expand_reference_ranges: bool) -> str | None: + if not reference_pattern or not expand_reference_ranges: + return reference_pattern + if reference_pattern == DEFAULT_REFERENCE_PATTERN: + return EXPANDED_REFERENCE_PATTERN + return reference_pattern + 
+ @staticmethod + def _extract_reference_idxs( + reply: str, reference_pattern: str, expand_ranges: bool = False, num_documents: int | None = None + ) -> set[int]: + matches = re.findall(reference_pattern, reply) + idxs: set[int] = set() + for match in matches: + if expand_ranges: + for part in match.split(","): + part = part.strip() + if not part: + continue + if "-" in part: + start_str, end_str = part.split("-", 1) + start, end = int(start_str), int(end_str) + if start > end: + continue + # Clamp the range end to the number of documents to avoid materializing a huge + # set from an out-of-range citation like `[1-999999999]` in the Generator output. + if num_documents is not None: + end = min(end, num_documents) + idxs.update(range(start - 1, end)) + else: + idxs.add(int(part) - 1) + else: + idxs.add(int(match) - 1) + return idxs @staticmethod def _check_num_groups_in_regex(pattern: str) -> None: diff --git a/releasenotes/notes/answer-builder-reference-ranges-a1b2c3d4e5f60718.yaml b/releasenotes/notes/answer-builder-reference-ranges-a1b2c3d4e5f60718.yaml new file mode 100644 index 0000000000..eac7a7fe3c --- /dev/null +++ b/releasenotes/notes/answer-builder-reference-ranges-a1b2c3d4e5f60718.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Added an opt-in ``expand_reference_ranges`` parameter to ``AnswerBuilder``. + When enabled, reference ranges like ``[6-10]`` and comma-separated ranges like ``[1-3,7-9]`` are expanded to the corresponding document indices in RAG answers. + The feature is disabled by default to preserve existing parsing behavior. 
def test_run_expands_reference_ranges_when_enabled(self):
    """With expansion enabled, `[6-10]` flags the documents with source index 6 through 10 as referenced."""
    documents = [Document(content=f"doc {number}") for number in range(1, 11)]
    builder = AnswerBuilder(
        reference_pattern=r"\[(\d+)\]", expand_reference_ranges=True, return_only_referenced_documents=False
    )
    result = builder.run(query="test query", replies=["Answer citing sources [6-10]."], documents=documents)
    first_answer = result["answers"][0]
    cited_indices = [doc.meta["source_index"] for doc in first_answer.documents if doc.meta["referenced"]]
    assert cited_indices == [6, 7, 8, 9, 10]


def test_run_expands_comma_separated_reference_ranges(self):
    """With expansion enabled, `[1-3,7-9]` flags source indices 1-3 and 7-9 as referenced."""
    documents = [Document(content=f"doc {number}") for number in range(1, 11)]
    builder = AnswerBuilder(
        reference_pattern=r"\[(\d+)\]", expand_reference_ranges=True, return_only_referenced_documents=False
    )
    result = builder.run(query="test query", replies=["Answer citing sources [1-3,7-9]."], documents=documents)
    first_answer = result["answers"][0]
    cited_indices = [doc.meta["source_index"] for doc in first_answer.documents if doc.meta["referenced"]]
    assert cited_indices == [1, 2, 3, 7, 8, 9]


def test_run_ignores_invalid_reference_ranges(self):
    """A reversed range such as `[3-1]` results in no returned documents."""
    documents = [Document(content=f"doc {number}") for number in range(1, 5)]
    builder = AnswerBuilder(
        reference_pattern=r"\[(\d+)\]", expand_reference_ranges=True, return_only_referenced_documents=True
    )
    result = builder.run(query="test query", replies=["Answer citing sources [3-1]."], documents=documents)
    first_answer = result["answers"][0]
    assert first_answer.documents == []
def test_run_clamps_reference_range_to_number_of_documents(self): + docs = [Document(content=f"doc {i}") for i in range(1, 4)] + component = AnswerBuilder( + reference_pattern=r"\[(\d+)\]", expand_reference_ranges=True, return_only_referenced_documents=True + ) + output = component.run(query="test query", replies=["Answer citing sources [1-100]."], documents=docs) + referenced_docs = output["answers"][0].documents + assert [doc.meta["source_index"] for doc in referenced_docs] == [1, 2, 3] + + def test_run_does_not_expand_reference_ranges_by_default(self): + docs = [Document(content=f"doc {i}") for i in range(1, 11)] + component = AnswerBuilder(reference_pattern=r"\[(\d+)\]", return_only_referenced_documents=True) + output = component.run(query="test query", replies=["Answer citing sources [6-10]."], documents=docs) + assert output["answers"][0].documents == [] diff --git a/test/core/pipeline/features/test_run.py b/test/core/pipeline/features/test_run.py index 57a0881128..e48c5e0c19 100644 --- a/test/core/pipeline/features/test_run.py +++ b/test/core/pipeline/features/test_run.py @@ -897,6 +897,7 @@ def pipeline_that_has_a_component_with_only_default_inputs(pipeline_class): "meta": None, "pattern": None, "query": "What is the capital of France?", + "expand_reference_ranges": None, "reference_pattern": None, "replies": ["Paris"], }, @@ -2382,6 +2383,7 @@ def that_has_an_answer_joiner_variadic_component(pipeline_class): "meta": None, "pattern": None, "query": query, + "expand_reference_ranges": None, "reference_pattern": None, "replies": [reply1], }, @@ -2390,6 +2392,7 @@ def that_has_an_answer_joiner_variadic_component(pipeline_class): "meta": None, "pattern": None, "query": query, + "expand_reference_ranges": None, "reference_pattern": None, "replies": [reply2], }, @@ -2634,6 +2637,7 @@ def run(self, prompt: str) -> dict[str, str]: "meta": None, "pattern": None, "query": "What is the answer?", + "expand_reference_ranges": None, "reference_pattern": None, "replies": 
["42"], }, @@ -3028,6 +3032,7 @@ def run(self, query: str) -> dict[str, list[Document]]: "meta": None, "pattern": None, "query": "Does this run reliably?", + "expand_reference_ranges": None, "reference_pattern": None, "replies": ["answer: here is my answer"], }, @@ -3193,6 +3198,7 @@ def has_feedback_loop(pipeline_class): "meta": None, "pattern": None, "query": "Generate code to generate christmas ascii-art", + "expand_reference_ranges": None, "reference_pattern": None, "replies": ["valid code"], }, @@ -3335,6 +3341,7 @@ def has_non_standard_order_loop(pipeline_class): "meta": None, "pattern": None, "query": "Generate code to generate christmas ascii-art", + "expand_reference_ranges": None, "reference_pattern": None, "replies": ["valid code"], }, diff --git a/test/core/super_component/test_super_component.py b/test/core/super_component/test_super_component.py index cd477e3a27..82dcdafa09 100644 --- a/test/core/super_component/test_super_component.py +++ b/test/core/super_component/test_super_component.py @@ -193,6 +193,7 @@ def test_auto_input_mapping(self, rag_pipeline): input_sockets = wrapper.__haystack_input__._sockets_dict # type: ignore[attr-defined] assert set(input_sockets.keys()) == { "documents", + "expand_reference_ranges", "filters", "meta", "pattern", @@ -219,6 +220,7 @@ def test_auto_mapping_sockets(self, rag_pipeline): input_sockets = wrapper.__haystack_input__._sockets_dict # type: ignore[attr-defined] assert set(input_sockets.keys()) == { "documents", + "expand_reference_ranges", "filters", "meta", "pattern",