1111
1212logger = logging .getLogger (__name__ )
1313
14+ DEFAULT_REFERENCE_PATTERN = r"\[(\d+)\]"
15+ EXPANDED_REFERENCE_PATTERN = r"\[(\d+(?:[,-]\d+)*)\]"
16+
1417
1518@component
1619class AnswerBuilder :
@@ -74,6 +77,7 @@ def __init__(
7477 last_message_only : bool = False ,
7578 * ,
7679 return_only_referenced_documents : bool = True ,
80+ expand_reference_ranges : bool = False ,
7781 ) -> None :
7882 """
7983 Creates an instance of the AnswerBuilder component.
@@ -104,6 +108,10 @@ def __init__(
104108 If True (default value), only the documents that were actually referenced in `replies` are returned.
105109 If False, all documents are returned.
106110 If `reference_pattern` is not provided, this parameter has no effect, and all documents are returned.
111+ :param expand_reference_ranges:
112+ If True, reference ranges like `[6-10]` are expanded to documents 6 through 10.
113+ Defaults to False for backwards compatibility.
114+ When enabled with the default `reference_pattern`, a broader pattern is used automatically.
107115 """
108116 if pattern :
109117 AnswerBuilder ._check_num_groups_in_regex (pattern )
@@ -112,6 +120,7 @@ def __init__(
112120 self .reference_pattern = reference_pattern
113121 self .last_message_only = last_message_only
114122 self .return_only_referenced_documents = return_only_referenced_documents
123+ self .expand_reference_ranges = expand_reference_ranges
115124
116125 @component .output_types (answers = list [GeneratedAnswer ])
117126 def run (
@@ -122,6 +131,7 @@ def run(
122131 documents : list [Document ] | None = None ,
123132 pattern : str | None = None ,
124133 reference_pattern : str | None = None ,
134+ expand_reference_ranges : bool | None = None ,
125135 ) -> dict [str , Any ]:
126136 """
127137 Turns the output of a Generator into `GeneratedAnswer` objects using regular expressions.
@@ -158,6 +168,9 @@ def run(
158168 If not specified, no parsing is done, and all documents are returned.
159169 References need to be specified as indices of the input documents and start at [1].
160170 Example: `\\ [(\\ d+)\\ ]` finds "1" in a string "this is an answer[1]".
171+ :param expand_reference_ranges:
172+ If True, reference ranges like `[6-10]` are expanded to documents 6 through 10.
173+ If not specified, the value from the component initialization is used.
161174
162175 :returns: A dictionary with the following keys:
163176 - `answers`: The answers received from the output of the Generator.
@@ -172,6 +185,12 @@ def run(
172185
173186 pattern = pattern or self .pattern
174187 reference_pattern = reference_pattern or self .reference_pattern
188+ expand_reference_ranges = (
189+ self .expand_reference_ranges if expand_reference_ranges is None else expand_reference_ranges
190+ )
191+ reference_pattern = AnswerBuilder ._resolve_reference_pattern (
192+ reference_pattern = reference_pattern , expand_reference_ranges = expand_reference_ranges
193+ )
175194
176195 replies_to_iterate = replies [- 1 :] if self .last_message_only and replies else replies
177196 meta_to_iterate = meta [- 1 :] if self .last_message_only and meta else meta
@@ -188,7 +207,12 @@ def run(
188207 referenced_docs = []
189208 if documents :
190209 referenced_idxs = (
191- AnswerBuilder ._extract_reference_idxs (extracted_reply , reference_pattern )
210+ AnswerBuilder ._extract_reference_idxs (
211+ extracted_reply ,
212+ reference_pattern ,
213+ expand_ranges = expand_reference_ranges ,
214+ num_documents = len (documents ),
215+ )
192216 if reference_pattern
193217 else set ()
194218 )
@@ -245,9 +269,40 @@ def _extract_answer_string(reply: str, pattern: str | None = None) -> str:
245269 return ""
246270
247271 @staticmethod
248- def _extract_reference_idxs (reply : str , reference_pattern : str ) -> set [int ]:
249- document_idxs = re .findall (reference_pattern , reply )
250- return {int (idx ) - 1 for idx in document_idxs }
272+ def _resolve_reference_pattern (reference_pattern : str | None , expand_reference_ranges : bool ) -> str | None :
273+ if not reference_pattern or not expand_reference_ranges :
274+ return reference_pattern
275+ if reference_pattern == DEFAULT_REFERENCE_PATTERN :
276+ return EXPANDED_REFERENCE_PATTERN
277+ return reference_pattern
278+
279+ @staticmethod
280+ def _extract_reference_idxs (
281+ reply : str , reference_pattern : str , expand_ranges : bool = False , num_documents : int | None = None
282+ ) -> set [int ]:
283+ matches = re .findall (reference_pattern , reply )
284+ idxs : set [int ] = set ()
285+ for match in matches :
286+ if expand_ranges :
287+ for part in match .split ("," ):
288+ part = part .strip ()
289+ if not part :
290+ continue
291+ if "-" in part :
292+ start_str , end_str = part .split ("-" , 1 )
293+ start , end = int (start_str ), int (end_str )
294+ if start > end :
295+ continue
296+ # Clamp the range end to the number of documents to avoid materializing a huge
297+ # set from an out-of-range citation like `[1-999999999]` in the Generator output.
298+ if num_documents is not None :
299+ end = min (end , num_documents )
300+ idxs .update (range (start - 1 , end ))
301+ else :
302+ idxs .add (int (part ) - 1 )
303+ else :
304+ idxs .add (int (match ) - 1 )
305+ return idxs
251306
252307 @staticmethod
253308 def _check_num_groups_in_regex (pattern : str ) -> None :
0 commit comments