Skip to content

Commit acf7023

Browse files
committed
fix(rlsapi): improve exception handling and prevent sensitive data leakage
Add missing exception handlers for RuntimeError (context_length → 413) and OpenAIAPIStatusError. Use handle_known_apistatus_errors() for smarter status-code mapping instead of generic 500s. Replace raw str(e) in client-facing cause fields with safe, generic messages while preserving full details in server-side logs. Extract the common error bookkeeping into a _record_inference_failure() helper to reduce duplication.

Signed-off-by: Major Hayden <major@redhat.com>
1 parent 227e504 commit acf7023

2 files changed

Lines changed: 126 additions & 41 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 71 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
1111
from llama_stack_api.openai_responses import OpenAIResponseObject
1212
from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError
13+
from openai._exceptions import APIStatusError as OpenAIAPIStatusError
1314

1415
import constants
1516
import metrics
@@ -23,6 +24,7 @@
2324
from models.responses import (
2425
ForbiddenResponse,
2526
InternalServerErrorResponse,
27+
PromptTooLongResponse,
2628
QuotaExceededResponse,
2729
ServiceUnavailableResponse,
2830
UnauthorizedResponse,
@@ -31,6 +33,7 @@
3133
from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
3234
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
3335
from observability import InferenceEventData, build_inference_event, send_splunk_event
36+
from utils.query import handle_known_apistatus_errors
3437
from utils.responses import extract_text_from_response_output_item, get_mcp_tools
3538
from utils.suid import get_suid
3639
from log import get_logger
@@ -73,6 +76,7 @@ def _get_rh_identity_context(request: Request) -> tuple[str, str]:
7376
examples=["missing header", "missing token"]
7477
),
7578
403: ForbiddenResponse.openapi_response(examples=["endpoint"]),
79+
413: PromptTooLongResponse.openapi_response(),
7680
422: UnprocessableEntityResponse.openapi_response(),
7781
429: QuotaExceededResponse.openapi_response(),
7882
500: InternalServerErrorResponse.openapi_response(examples=["generic"]),
@@ -229,6 +233,41 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position
229233
background_tasks.add_task(send_splunk_event, event, sourcetype)
230234

231235

236+
def _record_inference_failure(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    background_tasks: BackgroundTasks,
    infer_request: RlsapiV1InferRequest,
    request: Request,
    request_id: str,
    error: Exception,
    start_time: float,
) -> float:
    """Do the shared bookkeeping for a failed inference call.

    Increments the LLM failure counter and queues an ``infer_error`` Splunk
    event carrying the stringified exception and the elapsed time.

    Args:
        background_tasks: FastAPI background tasks used to send the event.
        infer_request: The inference request that failed.
        request: The FastAPI request object.
        request_id: Unique identifier of the request.
        error: The exception raised by the inference call.
        start_time: ``time.monotonic()`` reading taken when inference began.

    Returns:
        Seconds elapsed between ``start_time`` and the failure being recorded.
    """
    # Measure first so the reported duration excludes the bookkeeping below.
    elapsed = time.monotonic() - start_time
    metrics.llm_calls_failures_total.inc()

    # The full exception text is only sent to the server-side event sink.
    failure_message = str(error)
    _queue_splunk_event(
        background_tasks,
        infer_request,
        request,
        request_id,
        failure_message,
        elapsed,
        "infer_error",
    )
    return elapsed
269+
270+
232271
@router.post("/infer", responses=infer_responses)
233272
@authorize(Action.RLSAPI_V1_INFER)
234273
async def infer_endpoint(
@@ -265,6 +304,7 @@ async def infer_endpoint(
265304

266305
input_source = infer_request.get_input_source()
267306
instructions = _build_instructions(infer_request.context.systeminfo)
307+
model_id = _get_default_model_id()
268308
mcp_tools = get_mcp_tools(configuration.mcp_servers)
269309
logger.debug(
270310
"Request %s: Combined input source length: %d", request_id, len(input_source)
@@ -276,58 +316,48 @@ async def infer_endpoint(
276316
input_source, instructions, tools=mcp_tools
277317
)
278318
inference_time = time.monotonic() - start_time
319+
except RuntimeError as e:
320+
if "context_length" in str(e).lower():
321+
_record_inference_failure(
322+
background_tasks, infer_request, request, request_id, e, start_time
323+
)
324+
logger.error("Prompt too long for request %s: %s", request_id, e)
325+
error_response = PromptTooLongResponse(model=model_id)
326+
raise HTTPException(**error_response.model_dump()) from e
327+
_record_inference_failure(
328+
background_tasks, infer_request, request, request_id, e, start_time
329+
)
330+
logger.error("Unexpected RuntimeError for request %s: %s", request_id, e)
331+
raise
279332
except APIConnectionError as e:
280-
inference_time = time.monotonic() - start_time
281-
metrics.llm_calls_failures_total.inc()
333+
_record_inference_failure(
334+
background_tasks, infer_request, request, request_id, e, start_time
335+
)
282336
logger.error(
283337
"Unable to connect to Llama Stack for request %s: %s", request_id, e
284338
)
285-
_queue_splunk_event(
286-
background_tasks,
287-
infer_request,
288-
request,
289-
request_id,
290-
str(e),
291-
inference_time,
292-
"infer_error",
293-
)
294-
response = ServiceUnavailableResponse(
339+
error_response = ServiceUnavailableResponse(
295340
backend_name="Llama Stack",
296-
cause=str(e),
341+
cause="Unable to connect to the inference backend",
297342
)
298-
raise HTTPException(**response.model_dump()) from e
343+
raise HTTPException(**error_response.model_dump()) from e
299344
except RateLimitError as e:
300-
inference_time = time.monotonic() - start_time
301-
metrics.llm_calls_failures_total.inc()
345+
_record_inference_failure(
346+
background_tasks, infer_request, request, request_id, e, start_time
347+
)
302348
logger.error("Rate limit exceeded for request %s: %s", request_id, e)
303-
_queue_splunk_event(
304-
background_tasks,
305-
infer_request,
306-
request,
307-
request_id,
308-
str(e),
309-
inference_time,
310-
"infer_error",
349+
error_response = QuotaExceededResponse(
350+
response="The quota has been exceeded",
351+
cause="Rate limit exceeded, please try again later",
311352
)
312-
response = QuotaExceededResponse(
313-
response="The quota has been exceeded", cause=str(e)
353+
raise HTTPException(**error_response.model_dump()) from e
354+
except (APIStatusError, OpenAIAPIStatusError) as e:
355+
_record_inference_failure(
356+
background_tasks, infer_request, request, request_id, e, start_time
314357
)
315-
raise HTTPException(**response.model_dump()) from e
316-
except APIStatusError as e:
317-
inference_time = time.monotonic() - start_time
318-
metrics.llm_calls_failures_total.inc()
319358
logger.exception("API error for request %s: %s", request_id, e)
320-
_queue_splunk_event(
321-
background_tasks,
322-
infer_request,
323-
request,
324-
request_id,
325-
str(e),
326-
inference_time,
327-
"infer_error",
328-
)
329-
response = InternalServerErrorResponse.generic()
330-
raise HTTPException(**response.model_dump()) from e
359+
error_response = handle_known_apistatus_errors(e, model_id)
360+
raise HTTPException(**error_response.model_dump()) from e
331361

332362
if not response_text:
333363
logger.warning("Empty response from LLM for request %s", request_id)

tests/unit/app/endpoints/test_rlsapi_v1.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
126126
)
127127

128128

129+
@pytest.fixture(name="mock_generic_runtime_error")
def mock_generic_runtime_error_fixture(mocker: MockerFixture) -> None:
    """Make responses.create() fail with a RuntimeError unrelated to context length."""
    failing_create = mocker.AsyncMock(
        side_effect=RuntimeError("something went wrong"),
    )
    _setup_responses_mock(mocker, failing_create)
136+
137+
129138
# --- Test _build_instructions ---
130139

131140

@@ -656,3 +665,49 @@ async def test_infer_endpoint_calls_get_mcp_tools(
656665
)
657666

658667
mock_get_mcp_tools.assert_called_once_with(mock_configuration.mcp_servers)
668+
669+
670+
@pytest.mark.asyncio
async def test_infer_generic_runtime_error_reraises(
    mocker: MockerFixture,
    mock_configuration: AppConfig,
    mock_generic_runtime_error: None,
    mock_auth_resolvers: None,
) -> None:
    """Check that /infer propagates a RuntimeError that is not a context-length error."""
    endpoint_kwargs = {
        "infer_request": RlsapiV1InferRequest(question="Test question"),
        "request": _create_mock_request(mocker),
        "background_tasks": _create_mock_background_tasks(mocker),
        "auth": MOCK_AUTH,
    }

    # The original exception (and its message) must surface unchanged.
    with pytest.raises(RuntimeError, match="something went wrong"):
        await infer_endpoint(**endpoint_kwargs)
689+
690+
691+
@pytest.mark.asyncio
async def test_infer_generic_runtime_error_records_failure(
    mocker: MockerFixture,
    mock_configuration: AppConfig,
    mock_generic_runtime_error: None,
    mock_auth_resolvers: None,
) -> None:
    """Check that a non-context-length RuntimeError still queues an error event."""
    tasks = _create_mock_background_tasks(mocker)
    endpoint_kwargs = {
        "infer_request": RlsapiV1InferRequest(question="Test question"),
        "request": _create_mock_request(mocker),
        "background_tasks": tasks,
        "auth": MOCK_AUTH,
    }

    with pytest.raises(RuntimeError):
        await infer_endpoint(**endpoint_kwargs)

    # Exactly one background task is expected; its third positional
    # argument is the event sourcetype.
    tasks.add_task.assert_called_once()
    positional_args = tasks.add_task.call_args[0]
    assert positional_args[2] == "infer_error"

0 commit comments

Comments
 (0)