Skip to content

Commit 31b094c

Browse files
committed
fix(rlsapi_v1): handle RuntimeError for context length in infer endpoint
- Add RuntimeError catch block matching query.py and streaming_query.py pattern - Return 413 with PromptTooLongResponse when context_length error detected - Log non-context-length RuntimeErrors and return 500 with InternalServerErrorResponse instead of re-raising them - Add unit tests for both context_length and other RuntimeError scenarios Signed-off-by: Major Hayden <major@redhat.com>
1 parent 81e303a commit 31b094c

6 files changed

Lines changed: 118 additions & 56 deletions

File tree

src/app/endpoints/query.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,13 @@ async def retrieve_response( # pylint: disable=too-many-locals
257257
response = await client.responses.create(**responses_params.model_dump())
258258
response = cast(OpenAIResponseObject, response)
259259

260-
except RuntimeError as e: # library mode wraps 413 into runtime error
260+
except RuntimeError as e: # library mode wraps HTTP errors as RuntimeError
261261
if "context_length" in str(e).lower():
262262
error_response = PromptTooLongResponse(model=responses_params.model)
263263
raise HTTPException(**error_response.model_dump()) from e
264-
raise e
264+
logger.exception("RuntimeError during inference")
265+
error_response = InternalServerErrorResponse.generic()
266+
raise HTTPException(**error_response.model_dump()) from e
265267
except APIConnectionError as e:
266268
error_response = ServiceUnavailableResponse(
267269
backend_name="Llama Stack",

src/app/endpoints/rlsapi_v1.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from models.responses import (
2525
ForbiddenResponse,
2626
InternalServerErrorResponse,
27+
PromptTooLongResponse,
2728
QuotaExceededResponse,
2829
ServiceUnavailableResponse,
2930
UnauthorizedResponse,
@@ -270,6 +271,26 @@ async def infer_endpoint(
270271
input_source, instructions, tools=mcp_tools
271272
)
272273
inference_time = time.monotonic() - start_time
274+
except RuntimeError as e:
275+
# Library mode wraps HTTP errors as RuntimeError
276+
inference_time = time.monotonic() - start_time
277+
metrics.llm_calls_failures_total.inc()
278+
_queue_splunk_event(
279+
background_tasks,
280+
infer_request,
281+
request,
282+
request_id,
283+
str(e),
284+
inference_time,
285+
"infer_error",
286+
)
287+
if "context_length" in str(e).lower():
288+
logger.error("Prompt too long for request %s: %s", request_id, e)
289+
error_response = PromptTooLongResponse(model=_get_default_model_id())
290+
raise HTTPException(**error_response.model_dump()) from e
291+
logger.exception("RuntimeError during inference for request %s", request_id)
292+
response = InternalServerErrorResponse.generic()
293+
raise HTTPException(**response.model_dump()) from e
273294
except APIConnectionError as e:
274295
inference_time = time.monotonic() - start_time
275296
metrics.llm_calls_failures_total.inc()

src/app/endpoints/streaming_query.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,13 @@ async def retrieve_response_generator(
269269
return response_generator(response, context, turn_summary), turn_summary
270270

271271
# Handle known LLS client errors only at stream creation time and shield execution
272-
except RuntimeError as e: # library mode wraps 413 into runtime error
272+
except RuntimeError as e: # library mode wraps HTTP errors as RuntimeError
273273
if "context_length" in str(e).lower():
274274
error_response = PromptTooLongResponse(model=responses_params.model)
275275
raise HTTPException(**error_response.model_dump()) from e
276-
raise e
276+
logger.exception("RuntimeError during streaming inference")
277+
error_response = InternalServerErrorResponse.generic()
278+
raise HTTPException(**error_response.model_dump()) from e
277279
except APIConnectionError as e:
278280
error_response = ServiceUnavailableResponse(
279281
backend_name="Llama Stack",
@@ -407,9 +409,9 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
407409
chunk_id = 0
408410
media_type = context.query_request.media_type or MEDIA_TYPE_JSON
409411
text_parts: list[str] = []
410-
mcp_calls: dict[int, tuple[str, str]] = (
411-
{}
412-
) # output_index -> (mcp_call_id, mcp_call_name)
412+
mcp_calls: dict[
413+
int, tuple[str, str]
414+
] = {} # output_index -> (mcp_call_id, mcp_call_name)
413415
latest_response_object: Optional[OpenAIResponseObject] = None
414416

415417
logger.debug("Starting streaming response (Responses API) processing")

tests/unit/app/endpoints/test_query.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any
88

99
import pytest
10-
from fastapi import HTTPException, Request
10+
from fastapi import HTTPException, Request, status
1111
from llama_stack_api.openai_responses import OpenAIResponseObject
1212
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
1313
from pytest_mock import MockerFixture
@@ -110,9 +110,7 @@ async def test_successful_query_no_conversation(
110110
mocker: MockerFixture,
111111
) -> None:
112112
"""Test successful query without existing conversation."""
113-
query_request = QueryRequest(
114-
query="What is Kubernetes?"
115-
) # pyright: ignore[reportCallIssue]
113+
query_request = QueryRequest(query="What is Kubernetes?") # pyright: ignore[reportCallIssue]
116114

117115
mocker.patch("app.endpoints.query.configuration", setup_configuration)
118116
mocker.patch("app.endpoints.query.check_configuration_loaded")
@@ -386,9 +384,7 @@ async def test_query_azure_token_refresh(
386384
mocker: MockerFixture,
387385
) -> None:
388386
"""Test query refreshes Azure token when needed."""
389-
query_request = QueryRequest(
390-
query="What is Kubernetes?"
391-
) # pyright: ignore[reportCallIssue]
387+
query_request = QueryRequest(query="What is Kubernetes?") # pyright: ignore[reportCallIssue]
392388

393389
mocker.patch("app.endpoints.query.configuration", setup_configuration)
394390
mocker.patch("app.endpoints.query.check_configuration_loaded")
@@ -659,8 +655,9 @@ async def test_retrieve_response_runtime_error_other(
659655
side_effect=RuntimeError("Some other error")
660656
)
661657

662-
with pytest.raises(RuntimeError):
658+
with pytest.raises(HTTPException) as exc_info:
663659
await retrieve_response(mock_client, mock_responses_params)
660+
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
664661

665662
@pytest.mark.asyncio
666663
async def test_retrieve_response_with_tool_calls(

tests/unit/app/endpoints/test_rlsapi_v1.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,26 @@ def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
126126
)
127127

128128

129+
@pytest.fixture(name="mock_runtime_error_context_length")
130+
def mock_runtime_error_context_length_fixture(mocker: MockerFixture) -> None:
131+
"""Mock responses.create() to raise RuntimeError with context_length message."""
132+
_setup_responses_mock(
133+
mocker,
134+
mocker.AsyncMock(
135+
side_effect=RuntimeError("context_length exceeded maximum tokens")
136+
),
137+
)
138+
139+
140+
@pytest.fixture(name="mock_runtime_error_other")
141+
def mock_runtime_error_other_fixture(mocker: MockerFixture) -> None:
142+
"""Mock responses.create() to raise RuntimeError with non-context_length message."""
143+
_setup_responses_mock(
144+
mocker,
145+
mocker.AsyncMock(side_effect=RuntimeError("Some other runtime error")),
146+
)
147+
148+
129149
# --- Test _build_instructions ---
130150

131151

@@ -400,6 +420,51 @@ async def test_infer_api_connection_error_returns_503(
400420
assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
401421

402422

423+
@pytest.mark.asyncio
424+
async def test_infer_runtime_error_context_length_returns_413(
425+
mocker: MockerFixture,
426+
mock_configuration: AppConfig,
427+
mock_runtime_error_context_length: None,
428+
mock_auth_resolvers: None,
429+
) -> None:
430+
"""Test /infer returns 413 when LLM raises RuntimeError with context_length."""
431+
infer_request = RlsapiV1InferRequest(question="Test question")
432+
mock_request = _create_mock_request(mocker)
433+
mock_background_tasks = _create_mock_background_tasks(mocker)
434+
435+
with pytest.raises(HTTPException) as exc_info:
436+
await infer_endpoint(
437+
infer_request=infer_request,
438+
request=mock_request,
439+
background_tasks=mock_background_tasks,
440+
auth=MOCK_AUTH,
441+
)
442+
443+
assert exc_info.value.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
444+
445+
446+
@pytest.mark.asyncio
447+
async def test_infer_runtime_error_other_reraises(
448+
mocker: MockerFixture,
449+
mock_configuration: AppConfig,
450+
mock_runtime_error_other: None,
451+
mock_auth_resolvers: None,
452+
) -> None:
453+
"""Test /infer returns 500 for RuntimeError when not context_length related."""
454+
infer_request = RlsapiV1InferRequest(question="Test question")
455+
mock_request = _create_mock_request(mocker)
456+
mock_background_tasks = _create_mock_background_tasks(mocker)
457+
458+
with pytest.raises(HTTPException) as exc_info:
459+
await infer_endpoint(
460+
infer_request=infer_request,
461+
request=mock_request,
462+
background_tasks=mock_background_tasks,
463+
auth=MOCK_AUTH,
464+
)
465+
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
466+
467+
403468
@pytest.mark.asyncio
404469
async def test_infer_empty_llm_response_returns_fallback(
405470
mocker: MockerFixture,

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 16 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any
77

88
import pytest
9-
from fastapi import HTTPException, Request
9+
from fastapi import HTTPException, Request, status
1010
from fastapi.responses import StreamingResponse
1111
from llama_stack_api.openai_responses import (
1212
OpenAIResponseObject,
@@ -257,20 +257,14 @@ class TestOLSCompatibilityIntegration:
257257

258258
def test_media_type_validation(self) -> None:
259259
"""Test that media type validation works correctly."""
260-
valid_request = QueryRequest(
261-
query="test", media_type="application/json"
262-
) # pyright: ignore[reportCallIssue]
260+
valid_request = QueryRequest(query="test", media_type="application/json") # pyright: ignore[reportCallIssue]
263261
assert valid_request.media_type == "application/json"
264262

265-
valid_request = QueryRequest(
266-
query="test", media_type="text/plain"
267-
) # pyright: ignore[reportCallIssue]
263+
valid_request = QueryRequest(query="test", media_type="text/plain") # pyright: ignore[reportCallIssue]
268264
assert valid_request.media_type == "text/plain"
269265

270266
with pytest.raises(ValueError, match="media_type must be either"):
271-
QueryRequest(
272-
query="test", media_type="invalid/type"
273-
) # pyright: ignore[reportCallIssue]
267+
QueryRequest(query="test", media_type="invalid/type") # pyright: ignore[reportCallIssue]
274268

275269
def test_ols_end_event_structure(self) -> None:
276270
"""Test that end event follows OLS structure."""
@@ -322,9 +316,7 @@ async def test_successful_streaming_query(
322316
mocker: MockerFixture,
323317
) -> None:
324318
"""Test successful streaming query."""
325-
query_request = QueryRequest(
326-
query="What is Kubernetes?"
327-
) # pyright: ignore[reportCallIssue]
319+
query_request = QueryRequest(query="What is Kubernetes?") # pyright: ignore[reportCallIssue]
328320

329321
mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration)
330322
mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")
@@ -574,9 +566,7 @@ async def test_streaming_query_azure_token_refresh(
574566
mocker: MockerFixture,
575567
) -> None:
576568
"""Test streaming query refreshes Azure token when needed."""
577-
query_request = QueryRequest(
578-
query="What is Kubernetes?"
579-
) # pyright: ignore[reportCallIssue]
569+
query_request = QueryRequest(query="What is Kubernetes?") # pyright: ignore[reportCallIssue]
580570

581571
mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration)
582572
mocker.patch("app.endpoints.streaming_query.check_configuration_loaded")
@@ -679,9 +669,7 @@ async def test_retrieve_response_generator_success(
679669

680670
mock_context = mocker.Mock(spec=ResponseGeneratorContext)
681671
mock_context.client = mock_client
682-
mock_context.query_request = QueryRequest(
683-
query="test"
684-
) # pyright: ignore[reportCallIssue]
672+
mock_context.query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
685673

686674
async def mock_response_gen() -> AsyncIterator[str]:
687675
yield "test"
@@ -769,9 +757,7 @@ async def test_retrieve_response_generator_connection_error(
769757

770758
mock_context = mocker.Mock(spec=ResponseGeneratorContext)
771759
mock_context.client = mock_client
772-
mock_context.query_request = QueryRequest(
773-
query="test"
774-
) # pyright: ignore[reportCallIssue]
760+
mock_context.query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
775761

776762
mocker.patch(
777763
"app.endpoints.streaming_query.run_shield_moderation",
@@ -822,9 +808,7 @@ async def test_retrieve_response_generator_api_status_error(
822808

823809
mock_context = mocker.Mock(spec=ResponseGeneratorContext)
824810
mock_context.client = mock_client
825-
mock_context.query_request = QueryRequest(
826-
query="test"
827-
) # pyright: ignore[reportCallIssue]
811+
mock_context.query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
828812

829813
mocker.patch(
830814
"app.endpoints.streaming_query.run_shield_moderation",
@@ -872,9 +856,7 @@ async def test_retrieve_response_generator_runtime_error_context_length(
872856

873857
mock_context = mocker.Mock(spec=ResponseGeneratorContext)
874858
mock_context.client = mock_client
875-
mock_context.query_request = QueryRequest(
876-
query="test"
877-
) # pyright: ignore[reportCallIssue]
859+
mock_context.query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
878860

879861
mocker.patch(
880862
"app.endpoints.streaming_query.run_shield_moderation",
@@ -919,9 +901,7 @@ async def test_retrieve_response_generator_runtime_error_other(
919901

920902
mock_context = mocker.Mock(spec=ResponseGeneratorContext)
921903
mock_context.client = mock_client
922-
mock_context.query_request = QueryRequest(
923-
query="test"
924-
) # pyright: ignore[reportCallIssue]
904+
mock_context.query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
925905

926906
mocker.patch(
927907
"app.endpoints.streaming_query.run_shield_moderation",
@@ -932,8 +912,9 @@ async def test_retrieve_response_generator_runtime_error_other(
932912
side_effect=RuntimeError("Some other error")
933913
)
934914

935-
with pytest.raises(RuntimeError):
915+
with pytest.raises(HTTPException) as exc_info:
936916
await retrieve_response_generator(mock_responses_params, mock_context)
917+
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
937918

938919

939920
class TestGenerateResponse:
@@ -950,9 +931,7 @@ async def mock_generator() -> AsyncIterator[str]:
950931
mock_context = mocker.Mock(spec=ResponseGeneratorContext)
951932
mock_context.conversation_id = "conv_123"
952933
mock_context.user_id = "user_123"
953-
mock_context.query_request = QueryRequest(
954-
query="test"
955-
) # pyright: ignore[reportCallIssue]
934+
mock_context.query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
956935
mock_context.started_at = "2024-01-01T00:00:00Z"
957936
mock_context.skip_userid_check = False
958937

@@ -1047,9 +1026,7 @@ async def mock_generator() -> AsyncIterator[str]:
10471026

10481027
mock_context = mocker.Mock(spec=ResponseGeneratorContext)
10491028
mock_context.conversation_id = "conv_123"
1050-
mock_context.query_request = QueryRequest(
1051-
query="test"
1052-
) # pyright: ignore[reportCallIssue]
1029+
mock_context.query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
10531030
mock_context.started_at = "2024-01-01T00:00:00Z"
10541031
mock_context.skip_userid_check = False
10551032

@@ -1082,9 +1059,7 @@ async def mock_generator() -> AsyncIterator[str]:
10821059

10831060
mock_context = mocker.Mock(spec=ResponseGeneratorContext)
10841061
mock_context.conversation_id = "conv_123"
1085-
mock_context.query_request = QueryRequest(
1086-
query="test"
1087-
) # pyright: ignore[reportCallIssue]
1062+
mock_context.query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
10881063
mock_context.started_at = "2024-01-01T00:00:00Z"
10891064
mock_context.skip_userid_check = False
10901065

0 commit comments

Comments
 (0)