11"""Streaming query handler using Responses API (v2)."""
22
33import logging
4+ import traceback
45from typing import Annotated , Any , AsyncIterator , Optional , cast
6+ from urllib .parse import urljoin
57
68from fastapi import APIRouter , Depends , Request
79from fastapi .responses import StreamingResponse
1416 OpenAIResponseObjectStreamResponseOutputTextDelta ,
1517 OpenAIResponseObjectStreamResponseOutputTextDone ,
1618)
17- from llama_stack_client import AsyncLlamaStackClient
19+ from llama_stack_client import APIConnectionError , APIStatusError , AsyncLlamaStackClient
1820
1921from app .endpoints .query import (
2022 is_transcripts_enabled ,
5153 InternalServerErrorResponse ,
5254 NotFoundResponse ,
5355 QuotaExceededResponse ,
56+ ReferencedDocument ,
5457 ServiceUnavailableResponse ,
5558 StreamingQueryResponse ,
5659 UnauthorizedResponse ,
97100
98101def create_responses_response_generator ( # pylint: disable=too-many-locals,too-many-statements
99102 context : ResponseGeneratorContext ,
103+ doc_ids_from_chunks : Optional [list [ReferencedDocument ]] = None ,
100104) -> Any :
101105 """
102106 Create a response generator function for Responses API streaming.
@@ -106,6 +110,7 @@ def create_responses_response_generator( # pylint: disable=too-many-locals,too-
106110
107111 Args:
108112 context: Context object containing all necessary parameters for response generation
113+ doc_ids_from_chunks: Referenced documents extracted from vector DB chunks
109114
110115 Returns:
111116 An async generator function that yields SSE-formatted strings
@@ -294,17 +299,21 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
294299 model_id = context .model_id ,
295300 provider_id = context .provider_id ,
296301 )
297- referenced_documents = parse_referenced_documents_from_responses_api (
302+ response_referenced_documents = parse_referenced_documents_from_responses_api (
298303 cast (OpenAIResponseObject , latest_response_object )
299304 )
305+ # Combine doc_ids_from_chunks with response_referenced_documents
306+ all_referenced_documents = (
307+ doc_ids_from_chunks or []
308+ ) + response_referenced_documents
300309 available_quotas = get_available_quotas (
301310 configuration .quota_limiters , context .user_id
302311 )
303312 yield stream_end_event (
304313 context .metadata_map ,
305314 token_usage ,
306315 available_quotas ,
307- referenced_documents ,
316+ all_referenced_documents ,
308317 media_type ,
309318 )
310319
@@ -382,7 +391,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
382391 query_request : QueryRequest ,
383392 token : str ,
384393 mcp_headers : Optional [dict [str , dict [str , str ]]] = None ,
385- ) -> tuple [AsyncIterator [OpenAIResponseObjectStream ], str ]:
394+ ) -> tuple [AsyncIterator [OpenAIResponseObjectStream ], str , list [ ReferencedDocument ] ]:
386395 """
387396 Retrieve response from LLMs and agents.
388397
@@ -403,8 +412,8 @@ async def retrieve_response( # pylint: disable=too-many-locals
403412 Multi-cluster proxy headers for tool integrations.
404413
405414 Returns:
406- tuple: A tuple containing the streaming response object
407- and the conversation ID.
415+ tuple: A tuple containing the streaming response object,
416+ the conversation ID, and the list of referenced documents from vector DB chunks .
408417 """
409418 # use system prompt from request or default one
410419 system_prompt = get_system_prompt (query_request , configuration )
@@ -415,11 +424,180 @@ async def retrieve_response( # pylint: disable=too-many-locals
415424 if query_request .attachments :
416425 validate_attachments_metadata (query_request .attachments )
417426
418- # Prepare tools for responses API
427+ # Prepare tools for responses API - skip RAG tools since we're doing direct vector query
419428 toolgroups = await prepare_tools_for_responses_api (
420- client , query_request , token , configuration , mcp_headers = mcp_headers
429+ client ,
430+ query_request ,
431+ token ,
432+ configuration ,
433+ mcp_headers = mcp_headers ,
434+ skip_rag_tools = True ,
421435 )
422436
437+ # Extract RAG chunks from vector DB query response BEFORE calling responses API
438+ rag_chunks = []
439+ doc_ids_from_chunks = []
440+ retrieved_chunks = []
441+ retrieved_scores = []
442+
443+ # When offline is False, use reference_url for chunk source
444+ # When offline is True, use parent_id for chunk source
445+ # TODO: move this setting to a higher level configuration
446+ offline = True
447+
448+ try :
449+ # Get vector stores for direct querying
450+ if query_request .vector_store_ids :
451+ vector_store_ids = query_request .vector_store_ids
452+ logger .info (
453+ "Using specified vector_store_ids for direct query: %s" ,
454+ vector_store_ids ,
455+ )
456+ else :
457+ vector_store_ids = [
458+ vector_store .id
459+ for vector_store in (await client .vector_stores .list ()).data
460+ ]
461+ logger .info (
462+ "Using all available vector_store_ids for direct query: %s" ,
463+ vector_store_ids ,
464+ )
465+
466+ if vector_store_ids :
467+ vector_store_id = vector_store_ids [0 ] # Use first available vector store
468+
469+ params = {"k" : 5 , "score_threshold" : 0.0 , "mode" : "hybrid" }
470+ logger .info ("Initial params: %s" , params )
471+ logger .info ("query_request.solr: %s" , query_request .solr )
472+ if query_request .solr :
473+ # Pass the entire solr dict under the 'solr' key
474+ params ["solr" ] = query_request .solr
475+ logger .info ("Final params with solr filters: %s" , params )
476+ else :
477+ logger .info ("No solr filters provided" )
478+ logger .info ("Final params being sent to vector_io.query: %s" , params )
479+
480+ query_response = await client .vector_io .query (
481+ vector_store_id = vector_store_id ,
482+ query = query_request .query ,
483+ params = params ,
484+ )
485+
486+ logger .info ("The query response total payload: %s" , query_response )
487+
488+ if query_response .chunks :
489+ retrieved_chunks = query_response .chunks
490+ retrieved_scores = (
491+ query_response .scores if hasattr (query_response , "scores" ) else []
492+ )
493+
494+ # Extract doc_ids from chunks for referenced_documents
495+ metadata_doc_ids = set ()
496+
497+ for chunk in query_response .chunks :
498+ logger .info ("Extract doc ids from chunk: %s" , chunk )
499+
500+ # 1) dict metadata
501+ md = getattr (chunk , "metadata" , None ) or {}
502+ doc_id = md .get ("doc_id" ) or md .get ("document_id" )
503+ title = md .get ("title" )
504+
505+ # 2) typed chunk_metadata
506+ if not doc_id :
507+ cm = getattr (chunk , "chunk_metadata" , None )
508+ if cm is not None :
509+ # cm might be a pydantic model or a dict depending on caller
510+ if isinstance (cm , dict ):
511+ doc_id = cm .get ("doc_id" ) or cm .get ("document_id" )
512+ title = title or cm .get ("title" )
513+ reference_url = cm .get ("reference_url" )
514+ else :
515+ doc_id = getattr (cm , "doc_id" , None ) or getattr (
516+ cm , "document_id" , None
517+ )
518+ title = title or getattr (cm , "title" , None )
519+ reference_url = getattr (cm , "reference_url" , None )
520+ else :
521+ reference_url = None
522+ else :
523+ reference_url = md .get ("reference_url" )
524+
525+ if not doc_id and not reference_url :
526+ continue
527+
528+ # Build URL based on offline flag
529+ if offline :
530+ # Use parent/doc path
531+ reference_doc = doc_id
532+ doc_url = "https://mimir.corp.redhat.com" + reference_doc
533+ else :
534+ # Use reference_url if online
535+ reference_doc = reference_url or doc_id
536+ doc_url = (
537+ reference_doc
538+ if reference_doc .startswith ("http" )
539+ else ("https://mimir.corp.redhat.com" + reference_doc )
540+ )
541+
542+ if reference_doc and reference_doc not in metadata_doc_ids :
543+ metadata_doc_ids .add (reference_doc )
544+ doc_ids_from_chunks .append (
545+ ReferencedDocument (
546+ doc_title = title ,
547+ doc_url = doc_url ,
548+ )
549+ )
550+
551+ logger .info (
552+ "Extracted %d unique document IDs from chunks" , len (doc_ids_from_chunks )
553+ )
554+
555+ except (
556+ APIConnectionError ,
557+ APIStatusError ,
558+ AttributeError ,
559+ KeyError ,
560+ ValueError ,
561+ ) as e :
562+ logger .warning ("Failed to query vector database for chunks: %s" , e )
563+ logger .debug ("Vector DB query error details: %s" , traceback .format_exc ())
564+ # Continue without RAG chunks
565+
566+ # Convert retrieved chunks to RAGChunk format
567+ for i , chunk in enumerate (retrieved_chunks ):
568+ # Extract source from chunk metadata based on offline flag
569+ source = None
570+ if chunk .metadata :
571+ if offline :
572+ parent_id = chunk .metadata .get ("parent_id" )
573+ if parent_id :
574+ source = urljoin ("https://mimir.corp.redhat.com" , parent_id )
575+ else :
576+ source = chunk .metadata .get ("reference_url" )
577+
578+ # Get score from retrieved_scores list if available
579+ score = retrieved_scores [i ] if i < len (retrieved_scores ) else None
580+
581+ rag_chunks .append (
582+ RAGChunk (
583+ content = chunk .content ,
584+ source = source ,
585+ score = score ,
586+ )
587+ )
588+
589+ logger .info ("Retrieved %d chunks from vector DB" , len (rag_chunks ))
590+
591+ # Format RAG context for injection into user message
592+ rag_context = ""
593+ if rag_chunks :
594+ context_chunks = []
595+ for chunk in rag_chunks [:5 ]: # Limit to top 5 chunks
596+ chunk_text = f"Source: { chunk .source or 'Unknown' } \n { chunk .content } "
597+ context_chunks .append (chunk_text )
598+ rag_context = "\n \n Relevant documentation:\n " + "\n \n " .join (context_chunks )
599+ logger .info ("Injecting %d RAG chunks into user message" , len (context_chunks ))
600+
423601 # Prepare input for Responses API
424602 # Convert attachments to text and concatenate with query
425603 input_text = query_request .query
@@ -430,6 +608,9 @@ async def retrieve_response( # pylint: disable=too-many-locals
430608 f"{ attachment .content } "
431609 )
432610
611+ # Add RAG context to input text
612+ input_text += rag_context
613+
433614 # Handle conversation ID for Responses API
434615 # Create conversation upfront if not provided
435616 conversation_id = query_request .conversation_id
@@ -475,4 +656,8 @@ async def retrieve_response( # pylint: disable=too-many-locals
475656 response = await client .responses .create (** create_params )
476657 response_stream = cast (AsyncIterator [OpenAIResponseObjectStream ], response )
477658
478- return response_stream , normalize_conversation_id (conversation_id )
659+ return (
660+ response_stream ,
661+ normalize_conversation_id (conversation_id ),
662+ doc_ids_from_chunks ,
663+ )
0 commit comments