@@ -103,13 +103,7 @@ def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
103103 Returns:
104104 Instructions string for the LLM, with system context if available.
105105 """
106- if (
107- configuration .customization is not None
108- and configuration .customization .system_prompt is not None
109- ):
110- base_prompt = configuration .customization .system_prompt
111- else :
112- base_prompt = constants .DEFAULT_SYSTEM_PROMPT
106+ base_prompt = _get_base_prompt ()
113107
114108 context_parts = []
115109 if systeminfo .os :
@@ -126,6 +120,16 @@ def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
126120 return f"{ base_prompt } \n \n User's system: { system_context } "
127121
128122
123+ def _get_base_prompt () -> str :
124+ """Get the base system prompt with configuration fallback."""
125+ if (
126+ configuration .customization is not None
127+ and configuration .customization .system_prompt is not None
128+ ):
129+ return configuration .customization .system_prompt
130+ return constants .DEFAULT_SYSTEM_PROMPT
131+
132+
129133def _get_default_model_id () -> str :
130134 """Get the default model ID from configuration.
131135
@@ -162,7 +166,10 @@ def _get_default_model_id() -> str:
162166
163167
164168async def retrieve_simple_response (
165- question : str , instructions : str , tools : list | None = None
169+ question : str ,
170+ instructions : str ,
171+ tools : list [Any ] | None = None ,
172+ model_id : str | None = None ,
166173) -> str :
167174 """Retrieve a simple response from the LLM for a stateless query.
168175
@@ -173,22 +180,23 @@ async def retrieve_simple_response(
173180 question: The combined user input (question + context).
174181 instructions: System instructions for the LLM.
175182 tools: Optional list of MCP tool definitions for the LLM.
183+ model_id: Fully qualified model identifier in provider/model format.
184+ When omitted, the configured default model is used.
176185
177186 Returns:
178187 The LLM-generated response text.
179188
180189 Raises:
181190 APIConnectionError: If the Llama Stack service is unreachable.
182- HTTPException: 503 if no model is configured.
191+ HTTPException: 503 if no default model is configured.
183192 """
184193 client = AsyncLlamaStackClientHolder ().get_client ()
185- model_id = _get_default_model_id ()
186-
187- logger .debug ("Using model %s for rlsapi v1 inference" , model_id )
194+ resolved_model_id = model_id or _get_default_model_id ()
195+ logger .debug ("Using model %s for rlsapi v1 inference" , resolved_model_id )
188196
189197 response = await client .responses .create (
190198 input = question ,
191- model = model_id ,
199+ model = resolved_model_id ,
192200 instructions = instructions ,
193201 tools = tools or [],
194202 stream = False ,
@@ -205,6 +213,13 @@ def _get_cla_version(request: Request) -> str:
205213 return request .headers .get ("User-Agent" , "" )
206214
207215
216+ def _get_configured_default_model_name () -> str :
217+ """Get configured default model name for telemetry payloads."""
218+ if configuration .inference is None :
219+ return ""
220+ return configuration .inference .default_model or ""
221+
222+
208223def _queue_splunk_event ( # pylint: disable=too-many-arguments,too-many-positional-arguments
209224 background_tasks : BackgroundTasks ,
210225 infer_request : RlsapiV1InferRequest ,
@@ -222,11 +237,7 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position
222237 question = infer_request .question ,
223238 response = response_text ,
224239 inference_time = inference_time ,
225- model = (
226- (configuration .inference .default_model or "" )
227- if configuration .inference
228- else ""
229- ),
240+ model = _get_configured_default_model_name (),
230241 org_id = org_id ,
231242 system_id = system_id ,
232243 request_id = request_id ,
@@ -277,6 +288,49 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
277288 return inference_time
278289
279290
291+ def _map_inference_error_to_http_exception (
292+ error : Exception , model_id : str , request_id : str
293+ ) -> HTTPException | None :
294+ """Map known inference errors to HTTPException.
295+
296+ Returns None for RuntimeError values that are not context-length related,
297+ so callers can preserve existing re-raise behavior for unknown runtime
298+ errors.
299+ """
300+ if isinstance (error , RuntimeError ):
301+ if "context_length" in str (error ).lower ():
302+ logger .error ("Prompt too long for request %s: %s" , request_id , error )
303+ error_response = PromptTooLongResponse (model = model_id )
304+ return HTTPException (** error_response .model_dump ())
305+ logger .error ("Unexpected RuntimeError for request %s: %s" , request_id , error )
306+ return None
307+
308+ if isinstance (error , APIConnectionError ):
309+ logger .error (
310+ "Unable to connect to Llama Stack for request %s: %s" , request_id , error
311+ )
312+ error_response = ServiceUnavailableResponse (
313+ backend_name = "Llama Stack" ,
314+ cause = "Unable to connect to the inference backend" ,
315+ )
316+ return HTTPException (** error_response .model_dump ())
317+
318+ if isinstance (error , RateLimitError ):
319+ logger .error ("Rate limit exceeded for request %s: %s" , request_id , error )
320+ error_response = QuotaExceededResponse (
321+ response = "The quota has been exceeded" ,
322+ cause = "Rate limit exceeded, please try again later" ,
323+ )
324+ return HTTPException (** error_response .model_dump ())
325+
326+ if isinstance (error , (APIStatusError , OpenAIAPIStatusError )):
327+ logger .exception ("API error for request %s: %s" , request_id , error )
328+ error_response = handle_known_apistatus_errors (error , model_id )
329+ return HTTPException (** error_response .model_dump ())
330+
331+ return None
332+
333+
280334@router .post ("/infer" , responses = infer_responses )
281335@authorize (Action .RLSAPI_V1_INFER )
282336async def infer_endpoint ( # pylint: disable=R0914
@@ -323,86 +377,37 @@ async def infer_endpoint( # pylint: disable=R0914
323377 start_time = time .monotonic ()
324378 try :
325379 response_text = await retrieve_simple_response (
326- input_source , instructions , tools = mcp_tools
380+ input_source ,
381+ instructions ,
382+ tools = cast (list [Any ], mcp_tools ),
383+ model_id = model_id ,
327384 )
328385 inference_time = time .monotonic () - start_time
329- except RuntimeError as e :
330- if "context_length" in str (e ).lower ():
331- _record_inference_failure (
332- background_tasks ,
333- infer_request ,
334- request ,
335- request_id ,
336- e ,
337- start_time ,
338- model ,
339- provider ,
340- )
341- logger .error ("Prompt too long for request %s: %s" , request_id , e )
342- error_response = PromptTooLongResponse (model = model_id )
343- raise HTTPException (** error_response .model_dump ()) from e
344- _record_inference_failure (
345- background_tasks ,
346- infer_request ,
347- request ,
348- request_id ,
349- e ,
350- start_time ,
351- model ,
352- provider ,
353- )
354- logger .error ("Unexpected RuntimeError for request %s: %s" , request_id , e )
355- raise
356- except APIConnectionError as e :
357- _record_inference_failure (
358- background_tasks ,
359- infer_request ,
360- request ,
361- request_id ,
362- e ,
363- start_time ,
364- model ,
365- provider ,
366- )
367- logger .error (
368- "Unable to connect to Llama Stack for request %s: %s" , request_id , e
369- )
370- error_response = ServiceUnavailableResponse (
371- backend_name = "Llama Stack" ,
372- cause = "Unable to connect to the inference backend" ,
373- )
374- raise HTTPException (** error_response .model_dump ()) from e
375- except RateLimitError as e :
386+ except (
387+ RuntimeError ,
388+ APIConnectionError ,
389+ RateLimitError ,
390+ APIStatusError ,
391+ OpenAIAPIStatusError ,
392+ ) as error :
376393 _record_inference_failure (
377394 background_tasks ,
378395 infer_request ,
379396 request ,
380397 request_id ,
381- e ,
398+ error ,
382399 start_time ,
383400 model ,
384401 provider ,
385402 )
386- logger .error ("Rate limit exceeded for request %s: %s" , request_id , e )
387- error_response = QuotaExceededResponse (
388- response = "The quota has been exceeded" ,
389- cause = "Rate limit exceeded, please try again later" ,
390- )
391- raise HTTPException (** error_response .model_dump ()) from e
392- except (APIStatusError , OpenAIAPIStatusError ) as e :
393- _record_inference_failure (
394- background_tasks ,
395- infer_request ,
396- request ,
403+ mapped_error = _map_inference_error_to_http_exception (
404+ error ,
405+ model_id ,
397406 request_id ,
398- e ,
399- start_time ,
400- model ,
401- provider ,
402407 )
403- logger . exception ( "API error for request %s: %s" , request_id , e )
404- error_response = handle_known_apistatus_errors ( e , model_id )
405- raise HTTPException ( ** error_response . model_dump ()) from e
408+ if mapped_error is not None :
409+ raise mapped_error from error
410+ raise
406411
407412 if not response_text :
408413 logger .warning ("Empty response from LLM for request %s" , request_id )
0 commit comments