55"""
66
77import time
8- from typing import Annotated , Any , cast
8+ from typing import Annotated , Any , Optional , cast
99
1010from fastapi import APIRouter , BackgroundTasks , Depends , HTTPException , Request
1111from llama_stack_api .openai_responses import OpenAIResponseObject
5050
5151# Default values when RH Identity auth is not configured
5252AUTH_DISABLED = "auth_disabled"
53+ # Keep this tuple centralized so infer_endpoint can catch all expected backend
54+ # failures in one place while preserving a single telemetry/error-mapping path.
55+ _INFER_HANDLED_EXCEPTIONS = (
56+ RuntimeError ,
57+ APIConnectionError ,
58+ RateLimitError ,
59+ APIStatusError ,
60+ OpenAIAPIStatusError ,
61+ )
5362
5463
5564def _get_rh_identity_context (request : Request ) -> tuple [str , str ]:
@@ -66,7 +75,7 @@ def _get_rh_identity_context(request: Request) -> tuple[str, str]:
6675 Tuple of (org_id, system_id). Returns ("auth_disabled", "auth_disabled")
6776 when RH Identity auth is not configured or data is unavailable.
6877 """
69- rh_identity : RHIdentityData | None = getattr (
78+ rh_identity : Optional [ RHIdentityData ] = getattr (
7079 request .state , "rh_identity_data" , None
7180 )
7281 if rh_identity is None :
@@ -168,8 +177,8 @@ def _get_default_model_id() -> str:
168177async def retrieve_simple_response (
169178 question : str ,
170179 instructions : str ,
171- tools : list [Any ] | None = None ,
172- model_id : str | None = None ,
180+ tools : Optional [ list [Any ]] = None ,
181+ model_id : Optional [ str ] = None ,
173182) -> str :
174183 """Retrieve a simple response from the LLM for a stateless query.
175184
@@ -203,7 +212,7 @@ async def retrieve_simple_response(
203212 store = False ,
204213 )
205214 response = cast (OpenAIResponseObject , response )
206- extract_token_usage (response .usage , model_id )
215+ extract_token_usage (response .usage , resolved_model_id )
207216
208217 return extract_text_from_response_items (response .output )
209218
@@ -290,15 +299,16 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
290299
291300def _map_inference_error_to_http_exception (
292301 error : Exception , model_id : str , request_id : str
293- ) -> HTTPException | None :
302+ ) -> Optional [ HTTPException ] :
294303 """Map known inference errors to HTTPException.
295304
296305 Returns None for RuntimeError values that are not context-length related,
297306 so callers can preserve existing re-raise behavior for unknown runtime
298307 errors.
299308 """
300309 if isinstance (error , RuntimeError ):
301- if "context_length" in str (error ).lower ():
310+ error_message = str (error ).lower ()
311+ if "context_length" in error_message or "context length" in error_message :
302312 logger .error ("Prompt too long for request %s: %s" , request_id , error )
303313 error_response = PromptTooLongResponse (model = model_id )
304314 return HTTPException (** error_response .model_dump ())
@@ -369,7 +379,7 @@ async def infer_endpoint( # pylint: disable=R0914
369379 instructions = _build_instructions (infer_request .context .systeminfo )
370380 model_id = _get_default_model_id ()
371381 provider , model = extract_provider_and_model_from_model_id (model_id )
372- mcp_tools = await get_mcp_tools (request_headers = request .headers )
382+ mcp_tools : list [ Any ] = await get_mcp_tools (request_headers = request .headers )
373383 logger .debug (
374384 "Request %s: Combined input source length: %d" , request_id , len (input_source )
375385 )
@@ -383,13 +393,7 @@ async def infer_endpoint( # pylint: disable=R0914
383393 model_id = model_id ,
384394 )
385395 inference_time = time .monotonic () - start_time
386- except (
387- RuntimeError ,
388- APIConnectionError ,
389- RateLimitError ,
390- APIStatusError ,
391- OpenAIAPIStatusError ,
392- ) as error :
396+ except _INFER_HANDLED_EXCEPTIONS as error :
393397 _record_inference_failure (
394398 background_tasks ,
395399 infer_request ,
0 commit comments