Skip to content

Commit 8500949

Browse files
committed
refactor(rlsapi): simplify infer flow and error mapping
Signed-off-by: Major Hayden <major@redhat.com>
1 parent 6528bc2 commit 8500949

1 file changed

Lines changed: 96 additions & 58 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 96 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,7 @@ def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
9999
Returns:
100100
Instructions string for the LLM, with system context if available.
101101
"""
102-
if (
103-
configuration.customization is not None
104-
and configuration.customization.system_prompt is not None
105-
):
106-
base_prompt = configuration.customization.system_prompt
107-
else:
108-
base_prompt = constants.DEFAULT_SYSTEM_PROMPT
102+
base_prompt = _get_base_prompt()
109103

110104
context_parts = []
111105
if systeminfo.os:
@@ -122,6 +116,16 @@ def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
122116
return f"{base_prompt}\n\nUser's system: {system_context}"
123117

124118

119+
def _get_base_prompt() -> str:
120+
"""Get the base system prompt with configuration fallback."""
121+
if (
122+
configuration.customization is not None
123+
and configuration.customization.system_prompt is not None
124+
):
125+
return configuration.customization.system_prompt
126+
return constants.DEFAULT_SYSTEM_PROMPT
127+
128+
125129
def _get_default_model_id() -> str:
126130
"""Get the default model ID from configuration.
127131
@@ -158,7 +162,10 @@ def _get_default_model_id() -> str:
158162

159163

160164
async def retrieve_simple_response(
161-
question: str, instructions: str, tools: list | None = None
165+
question: str,
166+
instructions: str,
167+
tools: list[Any] | None = None,
168+
model_id: str | None = None,
162169
) -> str:
163170
"""Retrieve a simple response from the LLM for a stateless query.
164171
@@ -169,22 +176,23 @@ async def retrieve_simple_response(
169176
question: The combined user input (question + context).
170177
instructions: System instructions for the LLM.
171178
tools: Optional list of MCP tool definitions for the LLM.
179+
model_id: Fully qualified model identifier in provider/model format.
180+
When omitted, the configured default model is used.
172181
173182
Returns:
174183
The LLM-generated response text.
175184
176185
Raises:
177186
APIConnectionError: If the Llama Stack service is unreachable.
178-
HTTPException: 503 if no model is configured.
187+
HTTPException: 503 if no default model is configured.
179188
"""
180189
client = AsyncLlamaStackClientHolder().get_client()
181-
model_id = _get_default_model_id()
182-
183-
logger.debug("Using model %s for rlsapi v1 inference", model_id)
190+
resolved_model_id = model_id or _get_default_model_id()
191+
logger.debug("Using model %s for rlsapi v1 inference", resolved_model_id)
184192

185193
response = await client.responses.create(
186194
input=question,
187-
model=model_id,
195+
model=resolved_model_id,
188196
instructions=instructions,
189197
tools=tools or [],
190198
stream=False,
@@ -200,6 +208,13 @@ def _get_cla_version(request: Request) -> str:
200208
return request.headers.get("User-Agent", "")
201209

202210

211+
def _get_configured_default_model_name() -> str:
212+
"""Get configured default model name for telemetry payloads."""
213+
if configuration.inference is None:
214+
return ""
215+
return configuration.inference.default_model or ""
216+
217+
203218
def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-positional-arguments
204219
background_tasks: BackgroundTasks,
205220
infer_request: RlsapiV1InferRequest,
@@ -217,11 +232,7 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position
217232
question=infer_request.question,
218233
response=response_text,
219234
inference_time=inference_time,
220-
model=(
221-
(configuration.inference.default_model or "")
222-
if configuration.inference
223-
else ""
224-
),
235+
model=_get_configured_default_model_name(),
225236
org_id=org_id,
226237
system_id=system_id,
227238
request_id=request_id,
@@ -270,6 +281,49 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
270281
return inference_time
271282

272283

284+
def _map_inference_error_to_http_exception(
285+
error: Exception, model_id: str, request_id: str
286+
) -> HTTPException | None:
287+
"""Map known inference errors to HTTPException.
288+
289+
Returns None for RuntimeError values that are not context-length related,
290+
so callers can preserve existing re-raise behavior for unknown runtime
291+
errors.
292+
"""
293+
if isinstance(error, RuntimeError):
294+
if "context_length" in str(error).lower():
295+
logger.error("Prompt too long for request %s: %s", request_id, error)
296+
error_response = PromptTooLongResponse(model=model_id)
297+
return HTTPException(**error_response.model_dump())
298+
logger.error("Unexpected RuntimeError for request %s: %s", request_id, error)
299+
return None
300+
301+
if isinstance(error, APIConnectionError):
302+
logger.error(
303+
"Unable to connect to Llama Stack for request %s: %s", request_id, error
304+
)
305+
error_response = ServiceUnavailableResponse(
306+
backend_name="Llama Stack",
307+
cause="Unable to connect to the inference backend",
308+
)
309+
return HTTPException(**error_response.model_dump())
310+
311+
if isinstance(error, RateLimitError):
312+
logger.error("Rate limit exceeded for request %s: %s", request_id, error)
313+
error_response = QuotaExceededResponse(
314+
response="The quota has been exceeded",
315+
cause="Rate limit exceeded, please try again later",
316+
)
317+
return HTTPException(**error_response.model_dump())
318+
319+
if isinstance(error, (APIStatusError, OpenAIAPIStatusError)):
320+
logger.exception("API error for request %s: %s", request_id, error)
321+
error_response = handle_known_apistatus_errors(error, model_id)
322+
return HTTPException(**error_response.model_dump())
323+
324+
return None
325+
326+
273327
@router.post("/infer", responses=infer_responses)
274328
@authorize(Action.RLSAPI_V1_INFER)
275329
async def infer_endpoint(
@@ -315,51 +369,35 @@ async def infer_endpoint(
315369
start_time = time.monotonic()
316370
try:
317371
response_text = await retrieve_simple_response(
318-
input_source, instructions, tools=mcp_tools
372+
input_source,
373+
instructions,
374+
tools=cast(list[Any], mcp_tools),
375+
model_id=model_id,
319376
)
320377
inference_time = time.monotonic() - start_time
321-
except RuntimeError as e:
322-
if "context_length" in str(e).lower():
323-
_record_inference_failure(
324-
background_tasks, infer_request, request, request_id, e, start_time
325-
)
326-
logger.error("Prompt too long for request %s: %s", request_id, e)
327-
error_response = PromptTooLongResponse(model=model_id)
328-
raise HTTPException(**error_response.model_dump()) from e
329-
_record_inference_failure(
330-
background_tasks, infer_request, request, request_id, e, start_time
331-
)
332-
logger.error("Unexpected RuntimeError for request %s: %s", request_id, e)
333-
raise
334-
except APIConnectionError as e:
335-
_record_inference_failure(
336-
background_tasks, infer_request, request, request_id, e, start_time
337-
)
338-
logger.error(
339-
"Unable to connect to Llama Stack for request %s: %s", request_id, e
340-
)
341-
error_response = ServiceUnavailableResponse(
342-
backend_name="Llama Stack",
343-
cause="Unable to connect to the inference backend",
344-
)
345-
raise HTTPException(**error_response.model_dump()) from e
346-
except RateLimitError as e:
378+
except (
379+
RuntimeError,
380+
APIConnectionError,
381+
RateLimitError,
382+
APIStatusError,
383+
OpenAIAPIStatusError,
384+
) as error:
347385
_record_inference_failure(
348-
background_tasks, infer_request, request, request_id, e, start_time
386+
background_tasks,
387+
infer_request,
388+
request,
389+
request_id,
390+
error,
391+
start_time,
349392
)
350-
logger.error("Rate limit exceeded for request %s: %s", request_id, e)
351-
error_response = QuotaExceededResponse(
352-
response="The quota has been exceeded",
353-
cause="Rate limit exceeded, please try again later",
393+
mapped_error = _map_inference_error_to_http_exception(
394+
error,
395+
model_id,
396+
request_id,
354397
)
355-
raise HTTPException(**error_response.model_dump()) from e
356-
except (APIStatusError, OpenAIAPIStatusError) as e:
357-
_record_inference_failure(
358-
background_tasks, infer_request, request, request_id, e, start_time
359-
)
360-
logger.exception("API error for request %s: %s", request_id, e)
361-
error_response = handle_known_apistatus_errors(e, model_id)
362-
raise HTTPException(**error_response.model_dump()) from e
398+
if mapped_error is not None:
399+
raise mapped_error from error
400+
raise
363401

364402
if not response_text:
365403
logger.warning("Empty response from LLM for request %s", request_id)

0 commit comments

Comments
 (0)