Skip to content

Commit 9241e25

Browse files
committed
refactor(rlsapi): tighten infer exception typing
Signed-off-by: Major Hayden <major@redhat.com>
1 parent 8e6f0a8 commit 9241e25

1 file changed

Lines changed: 14 additions & 10 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@
5050

5151
# Default values when RH Identity auth is not configured
5252
AUTH_DISABLED = "auth_disabled"
53+
# Keep this tuple centralized so infer_endpoint can catch all expected backend
54+
# failures in one place while preserving a single telemetry/error-mapping path.
55+
_INFER_HANDLED_EXCEPTIONS = (
56+
RuntimeError,
57+
APIConnectionError,
58+
RateLimitError,
59+
APIStatusError,
60+
OpenAIAPIStatusError,
61+
)
5362

5463

5564
def _get_rh_identity_context(request: Request) -> tuple[str, str]:
@@ -203,7 +212,7 @@ async def retrieve_simple_response(
203212
store=False,
204213
)
205214
response = cast(OpenAIResponseObject, response)
206-
extract_token_usage(response.usage, model_id)
215+
extract_token_usage(response.usage, resolved_model_id)
207216

208217
return extract_text_from_response_items(response.output)
209218

@@ -298,7 +307,8 @@ def _map_inference_error_to_http_exception(
298307
errors.
299308
"""
300309
if isinstance(error, RuntimeError):
301-
if "context_length" in str(error).lower():
310+
error_message = str(error).lower()
311+
if "context_length" in error_message or "context length" in error_message:
302312
logger.error("Prompt too long for request %s: %s", request_id, error)
303313
error_response = PromptTooLongResponse(model=model_id)
304314
return HTTPException(**error_response.model_dump())
@@ -369,7 +379,7 @@ async def infer_endpoint( # pylint: disable=R0914
369379
instructions = _build_instructions(infer_request.context.systeminfo)
370380
model_id = _get_default_model_id()
371381
provider, model = extract_provider_and_model_from_model_id(model_id)
372-
mcp_tools = await get_mcp_tools(request_headers=request.headers)
382+
mcp_tools: list[Any] = await get_mcp_tools(request_headers=request.headers)
373383
logger.debug(
374384
"Request %s: Combined input source length: %d", request_id, len(input_source)
375385
)
@@ -383,13 +393,7 @@ async def infer_endpoint( # pylint: disable=R0914
383393
model_id=model_id,
384394
)
385395
inference_time = time.monotonic() - start_time
386-
except (
387-
RuntimeError,
388-
APIConnectionError,
389-
RateLimitError,
390-
APIStatusError,
391-
OpenAIAPIStatusError,
392-
) as error:
396+
except _INFER_HANDLED_EXCEPTIONS as error:
393397
_record_inference_failure(
394398
background_tasks,
395399
infer_request,

0 commit comments

Comments
 (0)