1+ # pylint: disable=too-many-locals,too-many-branches,too-many-nested-blocks
2+
3+ """Handler for REST API call to provide answer using Responses API (LCORE specification)."""
4+
5+ import json
6+ import logging
7+ from datetime import UTC , datetime
8+ import sqlite3
9+ from typing import Annotated , Any , AsyncIterator , Optional , Union , cast
10+
11+ from fastapi import APIRouter , Depends , HTTPException , Request
12+ from fastapi .responses import StreamingResponse
13+ from llama_stack_api .openai_responses import (
14+ OpenAIResponseObject ,
15+ OpenAIResponseObjectStream ,
16+ )
17+ from llama_stack_client import (
18+ APIConnectionError ,
19+ APIStatusError as LLSApiStatusError ,
20+ )
21+ from openai ._exceptions import (
22+ APIStatusError as OpenAIAPIStatusError ,
23+ )
24+ import psycopg2
25+ from sqlalchemy .exc import SQLAlchemyError
26+
27+ from authentication import get_auth_dependency
28+ from authentication .interface import AuthTuple
29+ from authorization .azure_token_manager import AzureEntraIDManager
30+ from authorization .middleware import authorize
31+ from cache .cache_error import CacheError
32+ from client import AsyncLlamaStackClientHolder
33+ from configuration import configuration
34+ from models .cache_entry import CacheEntry
35+ from models .config import Action
36+ from models .requests import ResponsesRequest
37+ from models .responses import (
38+ ForbiddenResponse ,
39+ InternalServerErrorResponse ,
40+ NotFoundResponse ,
41+ PromptTooLongResponse ,
42+ ResponsesResponse ,
43+ QuotaExceededResponse ,
44+ ServiceUnavailableResponse ,
45+ UnauthorizedResponse ,
46+ UnprocessableEntityResponse ,
47+ )
48+ from utils .endpoints import (
49+ check_configuration_loaded ,
50+ validate_and_retrieve_conversation ,
51+ )
52+ from utils .mcp_headers import mcp_headers_dependency
53+ from utils .query import (
54+ consume_query_tokens ,
55+ handle_known_apistatus_errors ,
56+ persist_user_conversation_details ,
57+ store_conversation_into_cache ,
58+ store_query_results ,
59+ update_azure_token ,
60+ )
61+ from utils .quota import check_tokens_available , get_available_quotas
62+ from utils .responses import (
63+ extract_text_from_input ,
64+ extract_token_usage ,
65+ get_topic_summary ,
66+ select_model_for_responses ,
67+ validate_model_override_permissions ,
68+ )
69+ from utils .shields import (
70+ append_turn_to_conversation ,
71+ run_shield_moderation ,
72+ )
73+ from utils .suid import normalize_conversation_id , to_llama_stack_conversation_id
74+
75+ logger = logging .getLogger ("app.endpoints.handlers" )
76+ router = APIRouter (tags = ["responses" ])
77+
78+ responses_response : dict [int | str , dict [str , Any ]] = {
79+ 200 : ResponsesResponse .openapi_response (),
80+ 401 : UnauthorizedResponse .openapi_response (
81+ examples = ["missing header" , "missing token" ]
82+ ),
83+ 403 : ForbiddenResponse .openapi_response (
84+ examples = ["endpoint" , "conversation read" , "model override" ]
85+ ),
86+ 404 : NotFoundResponse .openapi_response (
87+ examples = ["model" , "conversation" , "provider" ]
88+ ),
89+ 413 : PromptTooLongResponse .openapi_response (),
90+ 422 : UnprocessableEntityResponse .openapi_response (),
91+ 429 : QuotaExceededResponse .openapi_response (),
92+ 500 : InternalServerErrorResponse .openapi_response (examples = ["configuration" ]),
93+ 503 : ServiceUnavailableResponse .openapi_response (),
94+ }
95+
96+
97+ @router .post (
98+ "/responses" ,
99+ responses = responses_response ,
100+ summary = "Responses Endpoint Handler" ,
101+ )
102+ @authorize (Action .QUERY )
103+ async def responses_endpoint_handler (
104+ request : Request ,
105+ responses_request : ResponsesRequest ,
106+ auth : Annotated [AuthTuple , Depends (get_auth_dependency ())],
107+ mcp_headers : dict [str , dict [str , str ]] = Depends (mcp_headers_dependency ),
108+ ) -> Union [ResponsesResponse , StreamingResponse ]:
109+ """
110+ Handle request to the /responses endpoint using Responses API (LCORE specification).
111+
112+ Processes a POST request to the responses endpoint, forwarding the
113+ user's request to a selected Llama Stack LLM and returning the generated response
114+ following the LCORE OpenAPI specification.
115+
116+ Returns:
117+ ResponsesResponse: Contains the response following LCORE specification (non-streaming).
118+ StreamingResponse: SSE-formatted streaming response with enriched events (streaming).
119+ - response.created event includes conversation attribute
120+ - response.completed event includes available_quotas attribute
121+
122+ Raises:
123+ HTTPException:
124+ - 401: Unauthorized - Missing or invalid credentials
125+ - 403: Forbidden - Insufficient permissions or model override not allowed
126+ - 404: Not Found - Conversation, model, or provider not found
127+ - 413: Prompt too long - Prompt exceeded model's context window size
128+ - 422: Unprocessable Entity - Request validation failed
129+ - 429: Quota limit exceeded - The token quota for model or user has been exceeded
130+ - 500: Internal Server Error - Configuration not loaded or other server errors
131+ - 503: Service Unavailable - Unable to connect to Llama Stack backend
132+ """
133+ check_configuration_loaded (configuration )
134+
135+ started_at = datetime .now (UTC ).strftime ("%Y-%m-%dT%H:%M:%SZ" )
136+ user_id , _ , _skip_userid_check , token = auth
137+
138+ # Check token availability
139+ check_tokens_available (configuration .quota_limiters , user_id )
140+
141+ # Enforce RBAC: optionally disallow overriding model/provider in requests
142+ if responses_request .model :
143+ validate_model_override_permissions (
144+ responses_request .model ,
145+ request .state .authorized_actions ,
146+ )
147+
148+ user_conversation = None
149+ if responses_request .conversation :
150+ logger .debug (
151+ "Conversation ID specified in request: %s" , responses_request .conversation
152+ )
153+ user_conversation = validate_and_retrieve_conversation (
154+ normalized_conv_id = normalize_conversation_id (responses_request .conversation ),
155+ user_id = user_id ,
156+ others_allowed = Action .READ_OTHERS_CONVERSATIONS
157+ in request .state .authorized_actions ,
158+ )
159+ # Convert to llama-stack format if needed
160+ responses_request .conversation = to_llama_stack_conversation_id (user_conversation .id )
161+
162+ client = AsyncLlamaStackClientHolder ().get_client ()
163+
164+ # LCORE-specific: Automatically select model if not provided in request
165+ # This extends the base LLS API which requires model to be specified.
166+ if not responses_request .model :
167+ responses_request .model = await select_model_for_responses (
168+ client , user_conversation
169+ )
170+
171+ # Prepare API request parameters
172+ api_params = responses_request .model_dump (
173+ exclude_none = True , exclude = {"generate_topic_summary" }
174+ )
175+
176+ # Handle Azure token refresh if needed
177+ if (
178+ api_params ["model" ].startswith ("azure" )
179+ and AzureEntraIDManager ().is_entra_id_configured
180+ and AzureEntraIDManager ().is_token_expired
181+ and AzureEntraIDManager ().refresh_token ()
182+ ):
183+ client = await update_azure_token (client )
184+
185+ # Retrieve response using Responses API
186+ try :
187+ # Extract text from input for shield moderation (input can be string or complex object)
188+ input_text_for_moderation = extract_text_from_input (responses_request .input )
189+ moderation_result = await run_shield_moderation (
190+ client , input_text_for_moderation
191+ )
192+ if moderation_result .blocked :
193+ violation_message = moderation_result .message or ""
194+ if responses_request .conversation :
195+ await append_turn_to_conversation (
196+ client ,
197+ responses_request .conversation ,
198+ input_text_for_moderation ,
199+ violation_message ,
200+ )
201+ return ResponsesResponse .model_construct (
202+ status = "blocked" ,
203+ text = violation_message ,
204+ error = {"message" : violation_message },
205+ conversation = responses_request .conversation ,
206+ )
207+
208+ response = await client .responses .create (** api_params )
209+
210+ # Handle streaming response
211+ if responses_request .stream :
212+ stream_iterator = cast (AsyncIterator [OpenAIResponseObjectStream ], response )
213+ return StreamingResponse (
214+ _stream_responses (
215+ stream_iterator ,
216+ responses_request .conversation ,
217+ user_id ,
218+ api_params .get ("model" , "" ),
219+ ),
220+ media_type = "text/event-stream" ,
221+ )
222+
223+ response = cast (OpenAIResponseObject , response )
224+
225+ except RuntimeError as e : # library mode wraps 413 into runtime error
226+ if "context_length" in str (e ).lower ():
227+ error_response = PromptTooLongResponse (model = api_params .get ("model" , "" ))
228+ raise HTTPException (** error_response .model_dump ()) from e
229+ raise e
230+ except APIConnectionError as e :
231+ error_response = ServiceUnavailableResponse (
232+ backend_name = "Llama Stack" ,
233+ cause = str (e ),
234+ )
235+ raise HTTPException (** error_response .model_dump ()) from e
236+ except (LLSApiStatusError , OpenAIAPIStatusError ) as e :
237+ error_response = handle_known_apistatus_errors (e , api_params .get ("model" , "" ))
238+ raise HTTPException (** error_response .model_dump ()) from e
239+
240+ # Extract token usage
241+ token_usage = extract_token_usage (response , api_params ["model" ])
242+
243+ # Consume tokens
244+ logger .info ("Consuming tokens" )
245+ consume_query_tokens (
246+ user_id = user_id ,
247+ model_id = api_params ["model" ],
248+ token_usage = token_usage ,
249+ configuration = configuration ,
250+ )
251+
252+ # Get available quotas
253+ logger .info ("Getting available quotas" )
254+ available_quotas = get_available_quotas (
255+ quota_limiters = configuration .quota_limiters , user_id = user_id
256+ )
257+
258+ # Get topic summary for new conversation
259+ if not user_conversation and responses_request .generate_topic_summary :
260+ logger .debug ("Generating topic summary for new conversation" )
261+ topic_summary = await get_topic_summary (
262+ extract_text_from_input (responses_request .input ), client , api_params ["model" ]
263+ )
264+ else :
265+ topic_summary = None
266+
267+ try :
268+ logger .info ("Persisting conversation details" )
269+ # Extract provider_id from model_id (format: "provider/model")
270+ persist_user_conversation_details (
271+ user_id = user_id ,
272+ conversation_id = responses_request .conversation or "" ,
273+ model_id = api_params ["model" ],
274+ provider_id = api_params ["model" ].split ("/" )[0 ], # type: ignore
275+ topic_summary = topic_summary ,
276+ )
277+ except SQLAlchemyError as e :
278+ logger .exception ("Error persisting conversation details." )
279+ response = InternalServerErrorResponse .database_error ()
280+ raise HTTPException (** response .model_dump ()) from e
281+
282+ # Store conversation in cache
283+ try :
284+ completed_at = datetime .now (UTC ).strftime ("%Y-%m-%dT%H:%M:%SZ" )
285+ cache_entry = CacheEntry (
286+ query = extract_text_from_input (responses_request .input ),
287+ response = "" ,
288+ provider = api_params ["model" ].split ("/" )[0 ], # type: ignore
289+ model = api_params ["model" ],
290+ started_at = started_at ,
291+ completed_at = completed_at ,
292+ referenced_documents = None ,
293+ tool_calls = None ,
294+ tool_results = None ,
295+ )
296+
297+ logger .info ("Storing conversation in cache" )
298+ store_conversation_into_cache (
299+ config = configuration ,
300+ user_id = user_id ,
301+ conversation_id = responses_request .conversation or "" ,
302+ cache_entry = cache_entry ,
303+ _skip_userid_check = _skip_userid_check ,
304+ topic_summary = topic_summary ,
305+ )
306+ except (CacheError , ValueError , psycopg2 .Error , sqlite3 .Error ) as e :
307+ logger .exception ("Error storing conversation in cache: %s" , e )
308+ response = InternalServerErrorResponse .database_error ()
309+ raise HTTPException (** response .model_dump ()) from e
310+
311+ # Extract response fields using model_dump, excluding output/text which are handled separately
312+ response_dict = cast (OpenAIResponseObject , response ).model_dump ()
313+
314+ logger .info ("Building final response" )
315+ return ResponsesResponse (
316+ ** response_dict ,
317+ conversation = responses_request .conversation ,
318+ available_quotas = available_quotas ,
319+ )
320+
321+
322+ async def _stream_responses (
323+ stream : AsyncIterator [OpenAIResponseObjectStream ],
324+ conversation_id : Optional [str ],
325+ user_id : str ,
326+ model_id : str ,
327+ ) -> AsyncIterator [str ]:
328+ """Generate SSE-formatted streaming response with LCORE-enriched events.
329+
330+ Processes streaming chunks from Llama Stack and converts them to
331+ Server-Sent Events (SSE) format, enriching response.created with conversation
332+ and response.completed with available_quotas. All other events are forwarded
333+ exactly as received from the stream.
334+
335+ Args:
336+ stream: The streaming response from Llama Stack
337+ conversation_id: The conversation ID to include in response.created
338+ user_id: User ID for quota retrieval
339+ model_id: Model ID for token usage extraction
340+
341+ Yields:
342+ SSE-formatted strings for streaming events.
343+ """
344+ normalized_conv_id = normalize_conversation_id (conversation_id ) if conversation_id else None
345+ latest_response_object : Optional [OpenAIResponseObject ] = None
346+
347+ async for chunk in stream :
348+ event_type = getattr (chunk , "type" , None )
349+ logger .debug ("Processing streaming chunk, type: %s" , event_type )
350+
351+ # Get the original chunk data as dict (exact same structure as original)
352+ chunk_dict = chunk .model_dump () if hasattr (chunk , "model_dump" ) else {}
353+
354+ # Enrich response.created event with conversation attribute
355+ if event_type == "response.created" :
356+ response_obj = getattr (chunk , "response" , None )
357+ if response_obj :
358+ latest_response_object = cast (OpenAIResponseObject , response_obj )
359+
360+ # Add conversation attribute to the original chunk data
361+ if normalized_conv_id :
362+ chunk_dict ["conversation" ] = normalized_conv_id
363+
364+ # Enrich response.completed event with available_quotas attribute
365+ elif event_type == "response.completed" :
366+ response_obj = getattr (chunk , "response" , None )
367+ if response_obj :
368+ latest_response_object = cast (OpenAIResponseObject , response_obj )
369+
370+ # Extract token usage
371+ token_usage_obj = None
372+ if latest_response_object :
373+ token_usage_obj = extract_token_usage (latest_response_object , model_id )
374+
375+ # Get available quotas
376+ available_quotas = get_available_quotas (
377+ quota_limiters = configuration .quota_limiters , user_id = user_id
378+ )
379+
380+ # Consume tokens
381+ if token_usage_obj and latest_response_object :
382+ consume_query_tokens (
383+ user_id = user_id ,
384+ model_id = model_id ,
385+ token_usage = token_usage_obj ,
386+ configuration = configuration ,
387+ )
388+
389+ # Add available_quotas attribute to the original chunk data
390+ if available_quotas :
391+ chunk_dict ["available_quotas" ] = available_quotas
392+
393+ yield json .dumps (chunk_dict )
0 commit comments