Skip to content

Commit 8e6f0a8

Browse files
committed
refactor(rlsapi): simplify infer flow and error mapping
Signed-off-by: Major Hayden <major@redhat.com>
1 parent 579818f commit 8e6f0a8

1 file changed

Lines changed: 90 additions & 85 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 90 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -103,13 +103,7 @@ def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
103103
Returns:
104104
Instructions string for the LLM, with system context if available.
105105
"""
106-
if (
107-
configuration.customization is not None
108-
and configuration.customization.system_prompt is not None
109-
):
110-
base_prompt = configuration.customization.system_prompt
111-
else:
112-
base_prompt = constants.DEFAULT_SYSTEM_PROMPT
106+
base_prompt = _get_base_prompt()
113107

114108
context_parts = []
115109
if systeminfo.os:
@@ -126,6 +120,16 @@ def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
126120
return f"{base_prompt}\n\nUser's system: {system_context}"
127121

128122

123+
def _get_base_prompt() -> str:
    """Return the system prompt used as the base of the LLM instructions.

    Returns:
        ``configuration.customization.system_prompt`` when a customization
        section exists and its prompt is not None; otherwise
        ``constants.DEFAULT_SYSTEM_PROMPT``.
    """
    customization = configuration.customization
    # Fall back when either the customization section or its prompt is unset.
    if customization is None or customization.system_prompt is None:
        return constants.DEFAULT_SYSTEM_PROMPT
    return customization.system_prompt
131+
132+
129133
def _get_default_model_id() -> str:
130134
"""Get the default model ID from configuration.
131135
@@ -162,7 +166,10 @@ def _get_default_model_id() -> str:
162166

163167

164168
async def retrieve_simple_response(
165-
question: str, instructions: str, tools: list | None = None
169+
question: str,
170+
instructions: str,
171+
tools: list[Any] | None = None,
172+
model_id: str | None = None,
166173
) -> str:
167174
"""Retrieve a simple response from the LLM for a stateless query.
168175
@@ -173,22 +180,23 @@ async def retrieve_simple_response(
173180
question: The combined user input (question + context).
174181
instructions: System instructions for the LLM.
175182
tools: Optional list of MCP tool definitions for the LLM.
183+
model_id: Fully qualified model identifier in provider/model format.
184+
When omitted, the configured default model is used.
176185
177186
Returns:
178187
The LLM-generated response text.
179188
180189
Raises:
181190
APIConnectionError: If the Llama Stack service is unreachable.
182-
HTTPException: 503 if no model is configured.
191+
HTTPException: 503 if no default model is configured.
183192
"""
184193
client = AsyncLlamaStackClientHolder().get_client()
185-
model_id = _get_default_model_id()
186-
187-
logger.debug("Using model %s for rlsapi v1 inference", model_id)
194+
resolved_model_id = model_id or _get_default_model_id()
195+
logger.debug("Using model %s for rlsapi v1 inference", resolved_model_id)
188196

189197
response = await client.responses.create(
190198
input=question,
191-
model=model_id,
199+
model=resolved_model_id,
192200
instructions=instructions,
193201
tools=tools or [],
194202
stream=False,
@@ -205,6 +213,13 @@ def _get_cla_version(request: Request) -> str:
205213
return request.headers.get("User-Agent", "")
206214

207215

216+
def _get_configured_default_model_name() -> str:
    """Return the configured default model name for telemetry payloads.

    Returns:
        ``configuration.inference.default_model`` when it is set to a truthy
        value; an empty string when the inference section is None or the
        default model is unset/empty.
    """
    inference = configuration.inference
    if inference is None:
        return ""
    default_model = inference.default_model
    # A falsy default model (None or "") is normalized to an empty string.
    return default_model if default_model else ""
221+
222+
208223
def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-positional-arguments
209224
background_tasks: BackgroundTasks,
210225
infer_request: RlsapiV1InferRequest,
@@ -222,11 +237,7 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position
222237
question=infer_request.question,
223238
response=response_text,
224239
inference_time=inference_time,
225-
model=(
226-
(configuration.inference.default_model or "")
227-
if configuration.inference
228-
else ""
229-
),
240+
model=_get_configured_default_model_name(),
230241
org_id=org_id,
231242
system_id=system_id,
232243
request_id=request_id,
@@ -277,6 +288,49 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
277288
return inference_time
278289

279290

291+
def _map_inference_error_to_http_exception(
    error: Exception, model_id: str, request_id: str
) -> HTTPException | None:
    """Translate a known inference failure into an HTTPException.

    The error is logged here; the caller decides whether to raise the
    returned exception or re-raise the original one.

    Args:
        error: Exception raised while calling the inference backend.
        model_id: Model identifier passed to the prompt-too-long response
            and to ``handle_known_apistatus_errors``.
        request_id: Request identifier included in every log message.

    Returns:
        An HTTPException built from the matching error response model, or
        None when no mapping applies. None is returned for a RuntimeError
        whose message does not mention ``context_length`` and for any
        exception type not handled below.
    """
    # Branch order is significant: more specific error types are tested
    # before the generic API-status branch.
    if isinstance(error, RuntimeError):
        is_context_overflow = "context_length" in str(error).lower()
        if not is_context_overflow:
            logger.error(
                "Unexpected RuntimeError for request %s: %s", request_id, error
            )
            return None
        logger.error("Prompt too long for request %s: %s", request_id, error)
        payload = PromptTooLongResponse(model=model_id)
    elif isinstance(error, APIConnectionError):
        logger.error(
            "Unable to connect to Llama Stack for request %s: %s", request_id, error
        )
        payload = ServiceUnavailableResponse(
            backend_name="Llama Stack",
            cause="Unable to connect to the inference backend",
        )
    elif isinstance(error, RateLimitError):
        logger.error("Rate limit exceeded for request %s: %s", request_id, error)
        payload = QuotaExceededResponse(
            response="The quota has been exceeded",
            cause="Rate limit exceeded, please try again later",
        )
    elif isinstance(error, (APIStatusError, OpenAIAPIStatusError)):
        # logger.exception also records the active traceback.
        logger.exception("API error for request %s: %s", request_id, error)
        payload = handle_known_apistatus_errors(error, model_id)
    else:
        return None

    # Every mapped response model is expanded into HTTPException kwargs.
    return HTTPException(**payload.model_dump())
332+
333+
280334
@router.post("/infer", responses=infer_responses)
281335
@authorize(Action.RLSAPI_V1_INFER)
282336
async def infer_endpoint( # pylint: disable=R0914
@@ -323,86 +377,37 @@ async def infer_endpoint( # pylint: disable=R0914
323377
start_time = time.monotonic()
324378
try:
325379
response_text = await retrieve_simple_response(
326-
input_source, instructions, tools=mcp_tools
380+
input_source,
381+
instructions,
382+
tools=cast(list[Any], mcp_tools),
383+
model_id=model_id,
327384
)
328385
inference_time = time.monotonic() - start_time
329-
except RuntimeError as e:
330-
if "context_length" in str(e).lower():
331-
_record_inference_failure(
332-
background_tasks,
333-
infer_request,
334-
request,
335-
request_id,
336-
e,
337-
start_time,
338-
model,
339-
provider,
340-
)
341-
logger.error("Prompt too long for request %s: %s", request_id, e)
342-
error_response = PromptTooLongResponse(model=model_id)
343-
raise HTTPException(**error_response.model_dump()) from e
344-
_record_inference_failure(
345-
background_tasks,
346-
infer_request,
347-
request,
348-
request_id,
349-
e,
350-
start_time,
351-
model,
352-
provider,
353-
)
354-
logger.error("Unexpected RuntimeError for request %s: %s", request_id, e)
355-
raise
356-
except APIConnectionError as e:
357-
_record_inference_failure(
358-
background_tasks,
359-
infer_request,
360-
request,
361-
request_id,
362-
e,
363-
start_time,
364-
model,
365-
provider,
366-
)
367-
logger.error(
368-
"Unable to connect to Llama Stack for request %s: %s", request_id, e
369-
)
370-
error_response = ServiceUnavailableResponse(
371-
backend_name="Llama Stack",
372-
cause="Unable to connect to the inference backend",
373-
)
374-
raise HTTPException(**error_response.model_dump()) from e
375-
except RateLimitError as e:
386+
except (
387+
RuntimeError,
388+
APIConnectionError,
389+
RateLimitError,
390+
APIStatusError,
391+
OpenAIAPIStatusError,
392+
) as error:
376393
_record_inference_failure(
377394
background_tasks,
378395
infer_request,
379396
request,
380397
request_id,
381-
e,
398+
error,
382399
start_time,
383400
model,
384401
provider,
385402
)
386-
logger.error("Rate limit exceeded for request %s: %s", request_id, e)
387-
error_response = QuotaExceededResponse(
388-
response="The quota has been exceeded",
389-
cause="Rate limit exceeded, please try again later",
390-
)
391-
raise HTTPException(**error_response.model_dump()) from e
392-
except (APIStatusError, OpenAIAPIStatusError) as e:
393-
_record_inference_failure(
394-
background_tasks,
395-
infer_request,
396-
request,
403+
mapped_error = _map_inference_error_to_http_exception(
404+
error,
405+
model_id,
397406
request_id,
398-
e,
399-
start_time,
400-
model,
401-
provider,
402407
)
403-
logger.exception("API error for request %s: %s", request_id, e)
404-
error_response = handle_known_apistatus_errors(e, model_id)
405-
raise HTTPException(**error_response.model_dump()) from e
408+
if mapped_error is not None:
409+
raise mapped_error from error
410+
raise
406411

407412
if not response_text:
408413
logger.warning("Empty response from LLM for request %s", request_id)

0 commit comments

Comments
 (0)