Skip to content

Commit 87301b6

Browse files
committed
fix(rlsapi): improve exception handling and prevent sensitive data leakage
Add missing exception handlers for RuntimeError (context_length → 413) and OpenAIAPIStatusError.
Use handle_known_apistatus_errors() for smarter status code mapping instead of generic 500s.
Replace raw str(e) in client-facing cause fields with safe generic messages while preserving full details in server-side logs.
Extract common error bookkeeping into a _record_inference_failure() helper to reduce duplication.

Signed-off-by: Major Hayden <major@redhat.com>
1 parent 227e504 commit 87301b6

1 file changed

Lines changed: 67 additions & 41 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 67 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
1111
from llama_stack_api.openai_responses import OpenAIResponseObject
1212
from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError
13+
from openai._exceptions import APIStatusError as OpenAIAPIStatusError
1314

1415
import constants
1516
import metrics
@@ -23,6 +24,7 @@
2324
from models.responses import (
2425
ForbiddenResponse,
2526
InternalServerErrorResponse,
27+
PromptTooLongResponse,
2628
QuotaExceededResponse,
2729
ServiceUnavailableResponse,
2830
UnauthorizedResponse,
@@ -31,6 +33,7 @@
3133
from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
3234
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
3335
from observability import InferenceEventData, build_inference_event, send_splunk_event
36+
from utils.query import handle_known_apistatus_errors
3437
from utils.responses import extract_text_from_response_output_item, get_mcp_tools
3538
from utils.suid import get_suid
3639
from log import get_logger
@@ -73,6 +76,7 @@ def _get_rh_identity_context(request: Request) -> tuple[str, str]:
7376
examples=["missing header", "missing token"]
7477
),
7578
403: ForbiddenResponse.openapi_response(examples=["endpoint"]),
79+
413: PromptTooLongResponse.openapi_response(),
7680
422: UnprocessableEntityResponse.openapi_response(),
7781
429: QuotaExceededResponse.openapi_response(),
7882
500: InternalServerErrorResponse.openapi_response(examples=["generic"]),
@@ -229,6 +233,41 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position
229233
background_tasks.add_task(send_splunk_event, event, sourcetype)
230234

231235

236+
def _record_inference_failure(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    background_tasks: BackgroundTasks,
    infer_request: RlsapiV1InferRequest,
    request: Request,
    request_id: str,
    error: Exception,
    start_time: float,
) -> float:
    """Do the shared bookkeeping for a failed inference call.

    Measures how long the call ran before failing, increments the
    ``llm_calls_failures_total`` metric, and hands the failure details to
    ``_queue_splunk_event`` with the ``"infer_error"`` tag.

    Args:
        background_tasks: FastAPI background tasks, forwarded to
            ``_queue_splunk_event``.
        infer_request: The inference request that failed.
        request: The FastAPI request object.
        request_id: Unique identifier for the request.
        error: The exception that caused the failure; its ``str()`` form is
            what gets passed along with the event.
        start_time: ``time.monotonic()`` reading taken when inference started.

    Returns:
        Seconds elapsed between ``start_time`` and the moment this helper ran.
    """
    # Take the clock reading first so the bookkeeping below is not counted.
    elapsed = time.monotonic() - start_time
    metrics.llm_calls_failures_total.inc()

    # Positional order matches the call this helper replaced in the handlers.
    event_args = (
        infer_request,
        request,
        request_id,
        str(error),
        elapsed,
        "infer_error",
    )
    _queue_splunk_event(background_tasks, *event_args)
    return elapsed
269+
270+
232271
@router.post("/infer", responses=infer_responses)
233272
@authorize(Action.RLSAPI_V1_INFER)
234273
async def infer_endpoint(
@@ -265,6 +304,7 @@ async def infer_endpoint(
265304

266305
input_source = infer_request.get_input_source()
267306
instructions = _build_instructions(infer_request.context.systeminfo)
307+
model_id = _get_default_model_id()
268308
mcp_tools = get_mcp_tools(configuration.mcp_servers)
269309
logger.debug(
270310
"Request %s: Combined input source length: %d", request_id, len(input_source)
@@ -276,58 +316,44 @@ async def infer_endpoint(
276316
input_source, instructions, tools=mcp_tools
277317
)
278318
inference_time = time.monotonic() - start_time
319+
except RuntimeError as e:
320+
if "context_length" in str(e).lower():
321+
_record_inference_failure(
322+
background_tasks, infer_request, request, request_id, e, start_time
323+
)
324+
logger.error("Prompt too long for request %s: %s", request_id, e)
325+
error_response = PromptTooLongResponse(model=model_id)
326+
raise HTTPException(**error_response.model_dump()) from e
327+
raise
279328
except APIConnectionError as e:
280-
inference_time = time.monotonic() - start_time
281-
metrics.llm_calls_failures_total.inc()
329+
_record_inference_failure(
330+
background_tasks, infer_request, request, request_id, e, start_time
331+
)
282332
logger.error(
283333
"Unable to connect to Llama Stack for request %s: %s", request_id, e
284334
)
285-
_queue_splunk_event(
286-
background_tasks,
287-
infer_request,
288-
request,
289-
request_id,
290-
str(e),
291-
inference_time,
292-
"infer_error",
293-
)
294-
response = ServiceUnavailableResponse(
335+
error_response = ServiceUnavailableResponse(
295336
backend_name="Llama Stack",
296-
cause=str(e),
337+
cause="Unable to connect to the inference backend",
297338
)
298-
raise HTTPException(**response.model_dump()) from e
339+
raise HTTPException(**error_response.model_dump()) from e
299340
except RateLimitError as e:
300-
inference_time = time.monotonic() - start_time
301-
metrics.llm_calls_failures_total.inc()
341+
_record_inference_failure(
342+
background_tasks, infer_request, request, request_id, e, start_time
343+
)
302344
logger.error("Rate limit exceeded for request %s: %s", request_id, e)
303-
_queue_splunk_event(
304-
background_tasks,
305-
infer_request,
306-
request,
307-
request_id,
308-
str(e),
309-
inference_time,
310-
"infer_error",
345+
error_response = QuotaExceededResponse(
346+
response="The quota has been exceeded",
347+
cause="Rate limit exceeded, please try again later",
311348
)
312-
response = QuotaExceededResponse(
313-
response="The quota has been exceeded", cause=str(e)
349+
raise HTTPException(**error_response.model_dump()) from e
350+
except (APIStatusError, OpenAIAPIStatusError) as e:
351+
_record_inference_failure(
352+
background_tasks, infer_request, request, request_id, e, start_time
314353
)
315-
raise HTTPException(**response.model_dump()) from e
316-
except APIStatusError as e:
317-
inference_time = time.monotonic() - start_time
318-
metrics.llm_calls_failures_total.inc()
319354
logger.exception("API error for request %s: %s", request_id, e)
320-
_queue_splunk_event(
321-
background_tasks,
322-
infer_request,
323-
request,
324-
request_id,
325-
str(e),
326-
inference_time,
327-
"infer_error",
328-
)
329-
response = InternalServerErrorResponse.generic()
330-
raise HTTPException(**response.model_dump()) from e
355+
error_response = handle_known_apistatus_errors(e, model_id)
356+
raise HTTPException(**error_response.model_dump()) from e
331357

332358
if not response_text:
333359
logger.warning("Empty response from LLM for request %s", request_id)

0 commit comments

Comments
 (0)