Skip to content

Commit 153d6ae

Browse files
committed
fix: validate model exists in Llama Stack before rlsapi_v1 inference
Add check_model_configured() call to the /infer handler so a misconfigured default_model/default_provider gets a clear 404 instead of an opaque 500 from the inference call. Matches the existing pattern in responses.py. Extract validation into _resolve_validated_model_id() to keep infer_endpoint complexity at B(9). Signed-off-by: Major Hayden <major@redhat.com>
1 parent 6767256 commit 153d6ae

2 files changed

Lines changed: 67 additions & 1 deletion

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from models.responses import (
3030
ForbiddenResponse,
3131
InternalServerErrorResponse,
32+
NotFoundResponse,
3233
PromptTooLongResponse,
3334
QuotaExceededResponse,
3435
ServiceUnavailableResponse,
@@ -47,6 +48,7 @@
4748
from utils.quota import check_tokens_available
4849
from utils.responses import (
4950
build_turn_summary,
51+
check_model_configured,
5052
extract_text_from_response_items,
5153
extract_token_usage,
5254
get_mcp_tools,
@@ -236,6 +238,28 @@ async def _get_default_model_id() -> str:
236238
return model.id
237239

238240

241+
async def _resolve_validated_model_id() -> str:
242+
"""Resolve and validate the default model against Llama Stack.
243+
244+
Combines model resolution with existence validation so callers get
245+
either a known-good model ID or a clear 404 error.
246+
247+
Returns:
248+
The validated model identifier string in "provider/model" format.
249+
250+
Raises:
251+
HTTPException: 404 if the resolved model does not exist in Llama Stack.
252+
HTTPException: 503 if Llama Stack is unreachable during resolution or validation.
253+
"""
254+
model_id = await _get_default_model_id()
255+
client = AsyncLlamaStackClientHolder().get_client()
256+
if not await check_model_configured(client, model_id):
257+
_, model_name = extract_provider_and_model_from_model_id(model_id)
258+
error_response = NotFoundResponse(resource="model", resource_id=model_name)
259+
raise HTTPException(**error_response.model_dump())
260+
return model_id
261+
262+
239263
async def retrieve_simple_response(
240264
question: str,
241265
instructions: str,
@@ -668,7 +692,7 @@ async def infer_endpoint( # pylint: disable=R0914
668692
if blocked_response is not None:
669693
return blocked_response
670694

671-
model_id = await _get_default_model_id()
695+
model_id = await _resolve_validated_model_id()
672696
provider, model = extract_provider_and_model_from_model_id(model_id)
673697
mcp_tools: list[Any] = await get_mcp_tools(request_headers=request.headers)
674698

tests/unit/app/endpoints/test_rlsapi_v1.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,19 @@ def mock_shield_passed_fixture(mocker: MockerFixture) -> None:
146146
)
147147

148148

149+
@pytest.fixture(autouse=True, name="mock_model_configured")
150+
def mock_model_configured_fixture(mocker: MockerFixture) -> None:
151+
"""Mock model existence check to pass for all endpoint tests by default.
152+
153+
Individual tests can override this by patching check_model_configured
154+
with a different return value.
155+
"""
156+
mocker.patch(
157+
"app.endpoints.rlsapi_v1.check_model_configured",
158+
new=mocker.AsyncMock(return_value=True),
159+
)
160+
161+
149162
@pytest.fixture(name="mock_api_connection_error")
150163
def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
151164
"""Mock responses.create() to raise APIConnectionError."""
@@ -522,6 +535,35 @@ async def test_infer_endpoint_configuration_not_loaded(
522535
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
523536

524537

538+
@pytest.mark.asyncio
539+
async def test_infer_model_not_found_returns_404(
540+
mocker: MockerFixture,
541+
mock_configuration: AppConfig,
542+
mock_llm_response: None,
543+
mock_auth_resolvers: None,
544+
mock_request_factory: Callable[..., Any],
545+
mock_background_tasks: Any,
546+
) -> None:
547+
"""Test /infer returns HTTP 404 when configured model does not exist in Llama Stack."""
548+
mocker.patch(
549+
"app.endpoints.rlsapi_v1.check_model_configured",
550+
new=mocker.AsyncMock(return_value=False),
551+
)
552+
553+
infer_request = RlsapiV1InferRequest(question="How do I list files?")
554+
mock_request = mock_request_factory()
555+
556+
with pytest.raises(HTTPException) as exc_info:
557+
await infer_endpoint(
558+
infer_request=infer_request,
559+
request=mock_request,
560+
background_tasks=mock_background_tasks,
561+
auth=MOCK_AUTH,
562+
)
563+
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
564+
assert "model" in str(exc_info.value.detail).lower()
565+
566+
525567
@pytest.mark.asyncio
526568
async def test_infer_minimal_request(
527569
mocker: MockerFixture,

0 commit comments

Comments
 (0)