Skip to content

Commit 7d37c15

Browse files
committed
fix: add default model availability check to readiness probe
The /readiness endpoint only verified provider health but did not check that the configured default model was registered in the Llama Stack model registry. This allowed pods where model registration failed during startup to pass readiness and serve 404s on every inference request. Add check_default_model_available() that verifies the configured default model exists in client.models.list(). When the model is missing, /readiness returns 503 so Kubernetes removes the pod from the service load balancer. Ref: RSPEED-2959 Signed-off-by: Major Hayden <major@redhat.com>
1 parent ca125c4 commit 7d37c15

2 files changed

Lines changed: 251 additions & 6 deletions

File tree

src/app/endpoints/health.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
from typing import Annotated, Any
1010

1111
from fastapi import APIRouter, Depends, Response, status
12-
from llama_stack_client import APIConnectionError
12+
from llama_stack_client import APIConnectionError, APIStatusError
1313

1414
from authentication import get_auth_dependency
1515
from authentication.interface import AuthTuple
1616
from authorization.middleware import authorize
1717
from client import AsyncLlamaStackClientHolder
18+
from configuration import configuration
1819
from log import get_logger
1920
from models.config import Action
2021
from models.responses import (
@@ -95,6 +96,53 @@ async def get_providers_health_statuses() -> list[ProviderHealthStatus]:
9596
]
9697

9798

99+
async def check_default_model_available() -> tuple[bool, str]:
    """Report whether the configured default model appears in the model list.

    The expected identifier is built as ``"<default_provider>/<default_model>"``
    from the inference section of the configuration and compared against the
    ``id`` of every entry returned by the Llama Stack client's
    ``models.list()`` call.

    Returns:
        A ``(available, reason)`` tuple. ``available`` is True when the
        identifier is present in the list, and also when there is nothing to
        verify (no inference section, or the default model or provider is
        unset). It is False when the identifier is absent or when the client
        raises ``APIConnectionError`` or ``APIStatusError``; both failures
        are logged. ``reason`` is a human-readable description of the outcome.
    """
    inference = configuration.inference
    if inference is None:
        return True, "No inference configuration"

    model_name = inference.default_model
    provider_name = inference.default_provider
    # Both halves are needed to build the identifier; without them there is
    # nothing to verify, so the check passes.
    if not (model_name and provider_name):
        return True, "No default model configured"

    wanted_id = f"{provider_name}/{model_name}"

    try:
        llama_client = AsyncLlamaStackClientHolder().get_client()
        registry = await llama_client.models.list()

        if any(entry.id == wanted_id for entry in registry):
            return True, f"Default model {wanted_id} is available"

        # Log every registered id so a naming mismatch is easy to spot.
        logger.error(
            "Default model %s not found in registry. Registered models: %s",
            wanted_id,
            [entry.id for entry in registry],
        )
        return False, f"Default model {wanted_id} not found in model registry"
    except APIConnectionError as conn_err:
        logger.error("Failed to check model availability: %s", conn_err)
        return False, f"Failed to check model availability: {conn_err!s}"
    except APIStatusError as status_err:
        logger.error("API error checking model availability: %s", status_err)
        return False, f"API error checking model availability: {status_err!s}"
144+
145+
98146
@router.get("/readiness", responses=get_readiness_responses)
99147
@authorize(Action.INFO)
100148
async def readiness_probe_get_method(
@@ -134,11 +182,21 @@ async def readiness_probe_get_method(
134182
unhealthy_provider_names = [p.provider_id for p in unhealthy_providers]
135183
reason = f"Providers not healthy: {', '.join(unhealthy_provider_names)}"
136184
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
137-
else:
138-
ready = True
139-
reason = "All providers are healthy"
185+
return ReadinessResponse(
186+
ready=ready, reason=reason, providers=unhealthy_providers
187+
)
188+
189+
# Check that the default model is registered in the model registry
190+
model_available, model_reason = await check_default_model_available()
191+
if not model_available:
192+
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
193+
return ReadinessResponse(
194+
ready=False, reason=model_reason, providers=unhealthy_providers
195+
)
140196

141-
return ReadinessResponse(ready=ready, reason=reason, providers=unhealthy_providers)
197+
return ReadinessResponse(
198+
ready=True, reason="All providers are healthy", providers=unhealthy_providers
199+
)
142200

143201

144202
@router.get("/liveness", responses=get_liveness_responses)

tests/unit/app/endpoints/test_health.py

Lines changed: 188 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Unit tests for the /health REST API endpoint."""
22

33
import pytest
4-
from llama_stack_client import APIConnectionError
4+
from llama_stack_client import APIConnectionError, APIStatusError
55
from pytest_mock import MockerFixture
66

77
from app.endpoints.health import (
88
HealthStatus,
9+
check_default_model_available,
910
get_providers_health_statuses,
1011
liveness_probe_get_method,
1112
readiness_probe_get_method,
@@ -207,3 +208,189 @@ async def test_get_providers_health_statuses_connection_error(
207208
assert (
208209
result[0].message == "Failed to initialize health check: Connection error."
209210
)
211+
212+
213+
class TestCheckDefaultModelAvailable:
    """Exercise check_default_model_available() across each of its outcomes."""

    @staticmethod
    def _use_default_model(mocker: MockerFixture) -> None:
        """Patch the health module's configuration with a default model."""
        settings = mocker.patch("app.endpoints.health.configuration")
        settings.inference.default_model = "publishers/google/models/gemini-2.5-flash"
        settings.inference.default_provider = "google-vertex"

    @pytest.mark.asyncio
    async def test_model_available(self, mocker: MockerFixture) -> None:
        """A registry entry matching provider/model yields an available result."""
        self._use_default_model(mocker)

        fake_client = mocker.AsyncMock()
        holder_cls = mocker.patch("app.endpoints.health.AsyncLlamaStackClientHolder")
        holder_cls.return_value.get_client.return_value = fake_client
        fake_client.models.list.return_value = [
            mocker.Mock(id="google-vertex/publishers/google/models/gemini-2.5-flash")
        ]

        outcome = await check_default_model_available()

        assert outcome[0] is True
        assert "is available" in outcome[1]

    @pytest.mark.asyncio
    async def test_model_not_found(self, mocker: MockerFixture) -> None:
        """A registry lacking the default model yields an unavailable result."""
        self._use_default_model(mocker)

        fake_client = mocker.AsyncMock()
        holder_cls = mocker.patch("app.endpoints.health.AsyncLlamaStackClientHolder")
        holder_cls.return_value.get_client.return_value = fake_client
        fake_client.models.list.return_value = [
            mocker.Mock(id="some-other-provider/some-other-model")
        ]

        outcome = await check_default_model_available()

        assert outcome[0] is False
        assert "not found in model registry" in outcome[1]

    @pytest.mark.asyncio
    async def test_no_inference_config(self, mocker: MockerFixture) -> None:
        """A missing inference section is treated as nothing to verify."""
        settings = mocker.patch("app.endpoints.health.configuration")
        settings.inference = None

        outcome = await check_default_model_available()

        assert outcome[0] is True
        assert outcome[1] == "No inference configuration"

    @pytest.mark.asyncio
    async def test_no_default_model_configured(self, mocker: MockerFixture) -> None:
        """Unset default model and provider are treated as nothing to verify."""
        settings = mocker.patch("app.endpoints.health.configuration")
        settings.inference.default_provider = None
        settings.inference.default_model = None

        outcome = await check_default_model_available()

        assert outcome[0] is True
        assert outcome[1] == "No default model configured"

    @pytest.mark.asyncio
    async def test_connection_error(self, mocker: MockerFixture) -> None:
        """A connection error from models.list() yields an unavailable result."""
        self._use_default_model(mocker)

        fake_client = mocker.AsyncMock()
        holder_cls = mocker.patch("app.endpoints.health.AsyncLlamaStackClientHolder")
        holder_cls.return_value.get_client.return_value = fake_client
        fake_client.models.list.side_effect = APIConnectionError(request=mocker.Mock())

        outcome = await check_default_model_available()

        assert outcome[0] is False
        assert "Failed to check model availability" in outcome[1]

    @pytest.mark.asyncio
    async def test_api_status_error(self, mocker: MockerFixture) -> None:
        """An API status error from models.list() yields an unavailable result."""
        self._use_default_model(mocker)

        fake_client = mocker.AsyncMock()
        holder_cls = mocker.patch("app.endpoints.health.AsyncLlamaStackClientHolder")
        holder_cls.return_value.get_client.return_value = fake_client

        # Minimal stand-in for the HTTP response the error object wraps.
        http_response = mocker.Mock(status_code=500, headers={})
        fake_client.models.list.side_effect = APIStatusError(
            message="Internal error",
            response=http_response,
            body=None,
        )

        outcome = await check_default_model_available()

        assert outcome[0] is False
        assert "API error checking model availability" in outcome[1]
328+
329+
330+
@pytest.mark.asyncio
async def test_readiness_probe_fails_when_model_not_available(
    mocker: MockerFixture,
) -> None:
    """Healthy providers plus a missing default model must produce a 503."""
    mock_authorization_resolvers(mocker)

    # Every provider reports OK, so only the model check can fail the probe.
    healthy_provider = ProviderHealthStatus(
        provider_id="provider1",
        status=HealthStatus.OK.value,
        message="Provider is healthy",
    )
    providers_stub = mocker.patch("app.endpoints.health.get_providers_health_statuses")
    providers_stub.return_value = [healthy_provider]

    model_check_stub = mocker.patch(
        "app.endpoints.health.check_default_model_available"
    )
    model_check_stub.return_value = (
        False,
        "Default model google-vertex/publishers/google/models/gemini-2.5-flash "
        "not found in model registry",
    )

    auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
    http_response = mocker.Mock()

    result = await readiness_probe_get_method(auth=auth, response=http_response)

    assert result.ready is False
    assert "not found in model registry" in result.reason
    assert http_response.status_code == 503
365+
366+
367+
@pytest.mark.asyncio
async def test_readiness_probe_succeeds_with_healthy_providers_and_model(
    mocker: MockerFixture,
) -> None:
    """Test readiness returns 200 when providers are healthy and default model is available.

    Besides the response body, this verifies that the model availability check
    was actually awaited and that the probe did not mark the HTTP response as
    503, so a regression that reports ``ready=True`` alongside an error status
    is caught.
    """
    mock_authorization_resolvers(mocker)

    mock_get_providers = mocker.patch(
        "app.endpoints.health.get_providers_health_statuses"
    )
    mock_get_providers.return_value = [
        ProviderHealthStatus(
            provider_id="provider1",
            status=HealthStatus.OK.value,
            message="Provider is healthy",
        )
    ]

    mock_check_model = mocker.patch(
        "app.endpoints.health.check_default_model_available"
    )
    mock_check_model.return_value = (True, "Default model is available")

    mock_response = mocker.Mock()
    auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")

    response = await readiness_probe_get_method(auth=auth, response=mock_response)

    assert response.ready is True
    assert response.reason == "All providers are healthy"
    # The success path must consult the model check exactly once...
    mock_check_model.assert_awaited_once()
    # ...and must leave the response status untouched by the 503 branches.
    assert mock_response.status_code != 503

0 commit comments

Comments
 (0)