@@ -40,6 +40,9 @@ class GroundednessRequirement(Requirement):
4040 grounded. If False (default), response is grounded iff all spans
4141 needing citations are FULLY supported. If True, response is grounded
4242 if spans are fully or partially supported.
43+ max_new_tokens: Maximum tokens for LLM judgment outputs. Increase this
44+ if LLM outputs are being truncated (particularly for complex
45+ responses with many spans). Default is 500.
4346 description: Custom description for the requirement. If None,
4447 generates a default description.
4548
@@ -63,10 +66,12 @@ def __init__(
6366 self ,
6467 documents : Iterable [Document ] | Iterable [str ] | None = None ,
6568 allow_partial_support : bool = False ,
69+ max_new_tokens : int = 500 ,
6670 description : str | None = None ,
6771 ):
6872 """Initialize grounded requirement."""
6973 self .allow_partial_support = allow_partial_support
74+ self .max_new_tokens = max_new_tokens
7075
7176 # Convert documents to Document objects if provided
7277 if documents is not None :
@@ -177,14 +182,11 @@ async def validate(
177182
178183 try :
179184 # Step 1: Citation Generation
180- # Call intrinsic directly for explicit control over model options
181- from ..components .intrinsic ._util import call_intrinsic
185+ # Import lazily to avoid circular dependency
186+ from ..components .intrinsic .rag import find_citations
182187
183- citation_context = context_before_response .add (
184- Message ("assistant" , response , documents = list (documents ))
185- )
186- citations : list [dict ] = call_intrinsic (
187- "citations" , citation_context , backend
188+ citations : list [dict ] = find_citations (
189+ response , list (documents ), context_before_response , backend
188190 )
189191 logger .debug (
190192 f"Step 1 - Citations generated: { len (citations )} citations found"
@@ -219,7 +221,12 @@ async def validate(
219221 # Step 3: Citation Support
220222 try :
221223 span_support = await self ._assess_citation_support (
222- response , citations , span_necessity , backend , context_before_response
224+ response ,
225+ citations ,
226+ span_necessity ,
227+ backend ,
228+ context_before_response ,
229+ documents ,
223230 )
224231 logger .debug (
225232 f"Step 3 - Citation support assessed: { len (span_support )} spans"
@@ -273,7 +280,10 @@ async def _identify_citation_necessity(
273280 result , _ = await backend .generate_from_context (
274281 action ,
275282 context ,
276- model_options = {"temperature" : 0.0 , "max_new_tokens" : 500 },
283+ model_options = {
284+ "temperature" : 0.0 ,
285+ "max_new_tokens" : self .max_new_tokens ,
286+ },
277287 )
278288 await result .avalue ()
279289 output_text = result .value
@@ -295,6 +305,7 @@ async def _assess_citation_support(
295305 span_necessity : dict [tuple [int , int ], bool ],
296306 backend : Backend ,
297307 context : ChatContext ,
308+ documents : list [Document ],
298309 ) -> dict [tuple [int , int ], str ]:
299310 """Assess level of support for spans that need citations.
300311
@@ -307,6 +318,7 @@ async def _assess_citation_support(
307318 span_necessity: Mapping of span (begin, end) to needs_citation flag
308319 backend: Backend for LLM judgment
309320 context: Chat context
321+ documents: List of source documents for context in LLM assessment
310322
311323 Returns:
312324 Dictionary mapping span (begin, end) to support level
@@ -358,7 +370,7 @@ async def _assess_citation_support(
358370 return span_support
359371
360372 # Single batch LLM call for all spans
361- prompt = self ._build_batch_support_prompt (response , spans_to_assess )
373+ prompt = self ._build_batch_support_prompt (response , spans_to_assess , documents )
362374 logger .debug (
363375 f"Batch support assessment prompt (spans={ len (spans_to_assess )} ):\n { prompt } \n "
364376 )
@@ -368,7 +380,10 @@ async def _assess_citation_support(
368380 result , _ = await backend .generate_from_context (
369381 action ,
370382 context ,
371- model_options = {"temperature" : 0.0 , "max_new_tokens" : 500 },
383+ model_options = {
384+ "temperature" : 0.0 ,
385+ "max_new_tokens" : self .max_new_tokens ,
386+ },
372387 )
373388 await result .avalue ()
374389 output_text = result .value
@@ -424,7 +439,7 @@ def _extract_response_spans(
424439 covered_ranges .sort ()
425440 merged_ranges : list [tuple [int , int ]] = []
426441 for begin , end in covered_ranges :
427- if merged_ranges and begin <= merged_ranges [- 1 ][1 ]:
442+ if merged_ranges and begin < merged_ranges [- 1 ][1 ]:
428443 merged_ranges [- 1 ] = (
429444 merged_ranges [- 1 ][0 ],
430445 max (merged_ranges [- 1 ][1 ], end ),
@@ -437,41 +452,38 @@ def _extract_response_spans(
437452 f"Response span extraction - coverage: { covered_chars } /{ len (response )} chars covered by citations"
438453 )
439454
440- # Check if a position is covered by any citation
441- def is_covered (pos : int ) -> bool :
442- for begin , end in merged_ranges :
443- if begin <= pos < end :
444- return True
445- if begin > pos :
446- break
447- return False
448-
449455 # Extract spans by finding boundaries between covered and uncovered regions
456+ # Iterate over merged_ranges boundaries rather than every character for efficiency
450457 spans : list [dict ] = []
451- current_span_start = 0
452- current_is_covered = is_covered (0 ) if response else False
453-
454- for i in range (1 , len (response ) + 1 ):
455- # Check if we're at a boundary (coverage changed or end of response)
456- at_end = i == len (response )
457- next_is_covered = False if at_end else is_covered (i )
458- at_boundary = at_end or next_is_covered != current_is_covered
459-
460- if at_boundary :
461- span_text = response [current_span_start :i ].strip ()
462- if span_text : # Only include non-empty spans
463- spans .append (
464- {
465- "begin" : current_span_start ,
466- "end" : i ,
467- "text" : span_text ,
468- "is_cited" : current_is_covered ,
469- }
470- )
471458
472- current_span_start = i
473- if not at_end :
474- current_is_covered = next_is_covered
459+ # Build boundary points from merged ranges
460+ boundaries = [0 ] # Start of response
461+ for begin , end in merged_ranges :
462+ boundaries .append (begin )
463+ boundaries .append (end )
464+ boundaries .append (len (response )) # End of response
465+ boundaries = sorted (set (boundaries )) # Remove duplicates and sort
466+
467+ # Process each span between boundaries
468+ for i in range (len (boundaries ) - 1 ):
469+ span_start = boundaries [i ]
470+ span_end = boundaries [i + 1 ]
471+
472+ # Determine if this span is covered by any merged range
473+ is_cited = any (
474+ begin <= span_start and span_end <= end for begin , end in merged_ranges
475+ )
476+
477+ span_text = response [span_start :span_end ].strip ()
478+ if span_text : # Only include non-empty spans
479+ spans .append (
480+ {
481+ "begin" : span_start ,
482+ "end" : span_end ,
483+ "text" : span_text ,
484+ "is_cited" : is_cited ,
485+ }
486+ )
475487
476488 logger .debug (f"Response span extraction - extracted { len (spans )} spans" )
477489 for span in spans :
@@ -518,7 +530,7 @@ def _build_necessity_prompt(self, response: str, spans: list[dict]) -> str:
518530 return prompt
519531
520532 def _build_batch_support_prompt (
521- self , response : str , spans_to_assess : list [dict ]
533+ self , response : str , spans_to_assess : list [dict ], documents : list [ Document ]
522534 ) -> str :
523535 """Build prompt to assess citation support level for multiple spans at once.
524536
@@ -530,6 +542,7 @@ def _build_batch_support_prompt(
530542 spans_to_assess: List of span dicts with keys:
531543 - text: span text
532544 - citations: list of citation records for this span
545+ documents: List of source documents for context
533546
534547 Returns:
535548 Formatted prompt for LLM expecting JSON array output
@@ -561,13 +574,26 @@ def _build_batch_support_prompt(
561574
562575 spans_formatted = ",\n " .join (span_assessments )
563576
577+ # Build source documents section for context
578+ documents_section = ""
579+ if documents :
580+ doc_lines = []
581+ for doc in documents :
582+ doc_id = doc .doc_id if hasattr (doc , "doc_id" ) else "unknown"
583+ doc_text = doc .text if hasattr (doc , "text" ) else str (doc )
584+ doc_lines .append (f"Document { doc_id } :\n { doc_text } " )
585+ documents_section = "Source Documents:\n " + "\n \n " .join (doc_lines ) + "\n \n "
586+
564587 prompt = (
565- "Assess the level of support for each response span based on provided citations.\n \n "
588+ "Assess the level of support for each response span based on provided citations "
589+ "and source documents.\n \n "
566590 "For each span, determine if the citations fully support, partially support, "
567- "or do not support the span.\n \n "
591+ "or do not support the span. Consider the full context from the source documents "
592+ "where the citations appear.\n \n "
568593 "Respond with a JSON array of the form:\n "
569594 '[{"span_id": ..., "support_level": ...}, ...]\n \n '
570595 "Support levels must be ONLY one of: FULLY_SUPPORTED, PARTIALLY_SUPPORTED, or NOT_SUPPORTED.\n \n "
596+ f"{ documents_section } "
571597 f"Response context:\n { response } \n \n "
572598 f"Spans to assess:\n [\n { spans_formatted } \n ]\n \n "
573599 "JSON Output:\n "