33"""Handler for REST API call to provide answer to query using Response API."""
44
55import datetime
6- from typing import Annotated , Any , Optional , cast
6+ from typing import Annotated , Any , cast
77
88from fastapi import APIRouter , Depends , HTTPException , Request
99from llama_stack_api .openai_responses import OpenAIResponseObject
5252)
5353from utils .quota import check_tokens_available , get_available_quotas
5454from utils .responses import (
55- build_tool_call_summary ,
56- extract_text_from_response_output_item ,
57- extract_token_usage ,
55+ build_turn_summary ,
56+ deduplicate_referenced_documents ,
5857 extract_vector_store_ids_from_tools ,
5958 get_topic_summary ,
60- parse_referenced_documents ,
6159 prepare_responses_params ,
6260)
6361from utils .shields import (
6664)
6765from utils .suid import normalize_conversation_id
6866from utils .types import (
67+ RAGChunk ,
6968 ResponsesApiParams ,
7069 TurnSummary ,
7170)
@@ -130,7 +129,9 @@ async def query_endpoint_handler(
130129 check_tokens_available (configuration .quota_limiters , user_id )
131130
132131 # Enforce RBAC: optionally disallow overriding model/provider in requests
133- validate_model_provider_override (query_request , request .state .authorized_actions )
132+ validate_model_provider_override (
133+ query_request .model , query_request .provider , request .state .authorized_actions
134+ )
134135
135136 # Validate attachments if provided
136137 if query_request .attachments :
@@ -153,7 +154,7 @@ async def query_endpoint_handler(
153154 client = AsyncLlamaStackClientHolder ().get_client ()
154155
155156 doc_ids_from_chunks : list [ReferencedDocument ] = []
156- pre_rag_chunks : list [Any ] = [] # use your RAGChunk type (or the upstream one)
157+ pre_rag_chunks : list [RAGChunk ] = []
157158
158159 _ , _ , doc_ids_from_chunks , pre_rag_chunks = await perform_vector_search (
159160 client , query_request , configuration
@@ -198,7 +199,7 @@ async def query_endpoint_handler(
198199 turn_summary .rag_chunks = pre_rag_chunks + (turn_summary .rag_chunks or [])
199200
200201 if doc_ids_from_chunks :
201- turn_summary .referenced_documents = parse_referenced_docs (
202+ turn_summary .referenced_documents = deduplicate_referenced_documents (
202203 doc_ids_from_chunks + (turn_summary .referenced_documents or [])
203204 )
204205
@@ -216,7 +217,6 @@ async def query_endpoint_handler(
216217 user_id = user_id ,
217218 model_id = responses_params .model ,
218219 token_usage = turn_summary .token_usage ,
219- configuration = configuration ,
220220 )
221221
222222 logger .info ("Getting available quotas" )
@@ -238,7 +238,6 @@ async def query_endpoint_handler(
238238 completed_at = completed_at ,
239239 summary = turn_summary ,
240240 query_request = query_request ,
241- configuration = configuration ,
242241 skip_userid_check = _skip_userid_check ,
243242 topic_summary = topic_summary ,
244243 )
@@ -258,26 +257,11 @@ async def query_endpoint_handler(
258257 )
259258
260259
261- def parse_referenced_docs (
262- docs : list [ReferencedDocument ],
263- ) -> list [ReferencedDocument ]:
264- """Remove duplicate referenced documents based on URL and title."""
265- seen : set [tuple [str | None , str | None ]] = set ()
266- out : list [ReferencedDocument ] = []
267- for d in docs :
268- key = (str (d .doc_url ) if d .doc_url else None , d .doc_title )
269- if key in seen :
270- continue
271- seen .add (key )
272- out .append (d )
273- return out
274-
275-
276260async def retrieve_response ( # pylint: disable=too-many-locals
277261 client : AsyncLlamaStackClient ,
278262 responses_params : ResponsesApiParams ,
279- vector_store_ids : Optional [ list [str ]] = None ,
280- rag_id_mapping : Optional [ dict [str , str ]] = None ,
263+ vector_store_ids : list [str ] | None = None ,
264+ rag_id_mapping : dict [str , str ] | None = None ,
281265) -> TurnSummary :
282266 """
283267 Retrieve response from LLMs and agents.
@@ -294,8 +278,6 @@ async def retrieve_response( # pylint: disable=too-many-locals
294278 Returns:
295279 TurnSummary: Summary of the LLM response content
296280 """
297- summary = TurnSummary ()
298-
299281 try :
300282 moderation_result = await run_shield_moderation (client , responses_params .input )
301283 if moderation_result .blocked :
@@ -307,8 +289,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
307289 responses_params .input ,
308290 violation_message ,
309291 )
310- summary .llm_response = violation_message
311- return summary
292+ return TurnSummary (llm_response = violation_message )
312293 response = await client .responses .create (** responses_params .model_dump ())
313294 response = cast (OpenAIResponseObject , response )
314295
@@ -327,30 +308,6 @@ async def retrieve_response( # pylint: disable=too-many-locals
327308 error_response = handle_known_apistatus_errors (e , responses_params .model )
328309 raise HTTPException (** error_response .model_dump ()) from e
329310
330- # Process OpenAI response format
331- for output_item in response .output :
332- message_text = extract_text_from_response_output_item (output_item )
333- if message_text :
334- summary .llm_response += message_text
335-
336- tool_call , tool_result = build_tool_call_summary (
337- output_item , summary .rag_chunks , vector_store_ids , rag_id_mapping
338- )
339- if tool_call :
340- summary .tool_calls .append (tool_call )
341- if tool_result :
342- summary .tool_results .append (tool_result )
343-
344- logger .info (
345- "Response processing complete - Tool calls: %d, Response length: %d chars" ,
346- len (summary .tool_calls ),
347- len (summary .llm_response ),
348- )
349-
350- # Extract referenced documents and token usage from Responses API response
351- summary .referenced_documents = parse_referenced_documents (
352- response , vector_store_ids , rag_id_mapping
311+ return build_turn_summary (
312+ response , responses_params .model , vector_store_ids , rag_id_mapping
353313 )
354- summary .token_usage = extract_token_usage (response , responses_params .model )
355-
356- return summary
0 commit comments