Skip to content

Commit f625b60

Browse files
committed
fix: add default model availability check to readiness probe
The /readiness endpoint only verified provider health but did not check that the configured default model was registered in the Llama Stack model registry. This allowed pods where model registration failed during startup to pass readiness and serve 404s on every inference request.

Add check_default_model_available(), which verifies that the configured default model exists in client.models.list(). When the model is missing, /readiness returns 503 so Kubernetes removes the pod from the service load balancer.

Ref: RSPEED-2959
Signed-off-by: Major Hayden <major@redhat.com>
1 parent ca125c4 commit f625b60

5 files changed

Lines changed: 425 additions & 5 deletions

File tree

src/app/endpoints/health.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from authentication.interface import AuthTuple
1616
from authorization.middleware import authorize
1717
from client import AsyncLlamaStackClientHolder
18+
from configuration import configuration
1819
from log import get_logger
1920
from models.config import Action
2021
from models.responses import (
@@ -95,6 +96,31 @@ async def get_providers_health_statuses() -> list[ProviderHealthStatus]:
9596
]
9697

9798

99+
async def check_default_model_available() -> tuple[bool, str]:
    """Report whether the configured default model is present in the registry.

    Reads the default provider and model from ``configuration.inference``,
    builds the ``<provider>/<model>`` identifier and hands the lookup over to
    ``AsyncLlamaStackClientHolder.check_model_available``.

    Returns:
        A ``(available, reason)`` tuple. ``available`` is True when the model
        was found, or when there is nothing to check because no inference
        section, default model, or default provider is configured. ``reason``
        is a human-readable description of the outcome.
    """
    not_configured = (True, "No default model configured")

    inference_cfg = configuration.inference
    if inference_cfg is None:
        return not_configured

    # Both halves of the identifier are required; a missing one means there
    # is no default model to verify, which is not a readiness failure.
    model_name = inference_cfg.default_model
    if not model_name:
        return not_configured
    provider_name = inference_cfg.default_provider
    if not provider_name:
        return not_configured

    holder = AsyncLlamaStackClientHolder()
    return await holder.check_model_available(f"{provider_name}/{model_name}")
122+
123+
98124
@router.get("/readiness", responses=get_readiness_responses)
99125
@authorize(Action.INFO)
100126
async def readiness_probe_get_method(
@@ -134,11 +160,21 @@ async def readiness_probe_get_method(
134160
unhealthy_provider_names = [p.provider_id for p in unhealthy_providers]
135161
reason = f"Providers not healthy: {', '.join(unhealthy_provider_names)}"
136162
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
137-
else:
138-
ready = True
139-
reason = "All providers are healthy"
163+
return ReadinessResponse(
164+
ready=ready, reason=reason, providers=unhealthy_providers
165+
)
166+
167+
# Check that the default model is registered in the model registry
168+
model_available, model_reason = await check_default_model_available()
169+
if not model_available:
170+
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
171+
return ReadinessResponse(
172+
ready=False, reason=model_reason, providers=unhealthy_providers
173+
)
140174

141-
return ReadinessResponse(ready=ready, reason=reason, providers=unhealthy_providers)
175+
return ReadinessResponse(
176+
ready=True, reason="All providers are healthy", providers=unhealthy_providers
177+
)
142178

143179

144180
@router.get("/liveness", responses=get_liveness_responses)

src/client.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import yaml
99
from fastapi import HTTPException
1010
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
11-
from llama_stack_client import APIConnectionError, AsyncLlamaStackClient
11+
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
1212

1313
from configuration import configuration
1414
from llama_stack_configuration import YamlDumper, enrich_byok_rag, enrich_solr
@@ -141,6 +141,76 @@ async def reload_library_client(self) -> AsyncLlamaStackClient:
141141
self._lsc = client
142142
return client
143143

144+
async def check_model_available(self, model_id: str) -> tuple[bool, str]:
    """Check if a model is available in the registry, attempting reload if needed.

    Looks for ``model_id`` in the Llama Stack client's model list. If the
    model is missing and the holder is running in library mode, the library
    client is reloaded and the list is queried a second time before failure
    is reported. In server mode there is no library client to reload, so a
    missing model is only detected and reported.

    The reload is meant as a self-healing path: a failing Kubernetes
    readiness probe only removes the pod from the service endpoints without
    restarting it, so a model whose registration failed at startup (e.g.
    because of a transient provider outage) would otherwise stay missing.

    Args:
        model_id: The expected model identifier to look up.

    Returns:
        A tuple of (available, reason) where available is True if the
        model was found (either immediately or after a reload), and reason
        describes the outcome. Client initialization errors and API
        connection/status errors during the first listing are reported as
        ``(False, reason)`` rather than raised.
    """
    try:
        client = self.get_client()
        models = await client.models.list()
    except RuntimeError as e:
        logger.warning("Client not initialized, skipping model check: %s", e)
        return False, f"Client not initialized: {e!s}"
    except (APIConnectionError, APIStatusError) as e:
        logger.error("Error checking model availability: %s", e)
        return False, f"Error checking model availability: {e!s}"

    if any(m.id == model_id for m in models):
        return True, f"Model {model_id} is available"

    # Model not found - attempt self-healing reload for library clients.
    if self.is_library_client:
        logger.warning(
            "Model %s not found, attempting client reload",
            model_id,
        )
        try:
            await self.reload_library_client()
            client = self.get_client()
            # Rebind ``models`` so that the failure log below reports the
            # registry as it looks after the reload, not the stale
            # pre-reload listing.
            models = await client.models.list()
        except (
            RuntimeError,
            HTTPException,
            APIConnectionError,
            APIStatusError,
        ) as err:
            # Best effort: a failed reload falls through to the normal
            # "not found" report using the last successful listing.
            logger.error("Client reload failed: %s", err)
        else:
            if any(m.id == model_id for m in models):
                logger.info(
                    "Model %s found after client reload",
                    model_id,
                )
                return True, f"Model {model_id} is available after reload"

    registered_ids = [m.id for m in models]
    logger.error(
        "Model %s not found in registry. Registered models: %s",
        model_id,
        registered_ids,
    )
    return False, f"Model {model_id} not found in model registry"
213+
144214
def update_provider_data(self, updates: dict[str, str]) -> AsyncLlamaStackClient:
145215
"""Update provider data headers for service client.
146216

tests/integration/endpoints/test_health_integration.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ async def test_health_readiness(
159159
mock_llama_stack_client_health: AsyncMockType,
160160
test_response: Response,
161161
test_auth: AuthTuple,
162+
mocker: MockerFixture,
162163
) -> None:
163164
"""Test that readiness probe endpoint returns readiness status.
164165
@@ -180,6 +181,12 @@ async def test_health_readiness(
180181
"""
181182
_ = mock_llama_stack_client_health
182183

184+
# Mock check_default_model_available since configuration is not loaded
185+
mock_check_model = mocker.patch(
186+
"app.endpoints.health.check_default_model_available"
187+
)
188+
mock_check_model.return_value = (True, "Default model is available")
189+
183190
result = await readiness_probe_get_method(auth=test_auth, response=test_response)
184191

185192
# Verify that service returns readiness response

tests/unit/app/endpoints/test_health.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""Unit tests for the /health REST API endpoint."""
22

3+
from typing import Any
4+
35
import pytest
46
from llama_stack_client import APIConnectionError
57
from pytest_mock import MockerFixture
68

79
from app.endpoints.health import (
810
HealthStatus,
11+
check_default_model_available,
912
get_providers_health_statuses,
1013
liveness_probe_get_method,
1114
readiness_probe_get_method,
@@ -72,6 +75,12 @@ async def test_readiness_probe_success_when_all_providers_healthy(
7275
),
7376
]
7477

78+
# Mock check_default_model_available so it doesn't hit uninitialized client
79+
mock_check_model = mocker.patch(
80+
"app.endpoints.health.check_default_model_available"
81+
)
82+
mock_check_model.return_value = (True, "Default model is available")
83+
7584
# Mock the Response object and auth
7685
mock_response = mocker.Mock()
7786

@@ -87,6 +96,43 @@ async def test_readiness_probe_success_when_all_providers_healthy(
8796
assert len(response.providers) == 0
8897

8998

99+
@pytest.mark.asyncio
100+
async def test_readiness_probe_fails_when_model_not_available(
101+
mocker: MockerFixture,
102+
) -> None:
103+
"""Test readiness returns 503 when providers are healthy but default model is missing."""
104+
mock_authorization_resolvers(mocker)
105+
106+
mock_get_providers = mocker.patch(
107+
"app.endpoints.health.get_providers_health_statuses"
108+
)
109+
mock_get_providers.return_value = [
110+
ProviderHealthStatus(
111+
provider_id="provider1",
112+
status=HealthStatus.OK.value,
113+
message="Provider is healthy",
114+
)
115+
]
116+
117+
mock_check_model = mocker.patch(
118+
"app.endpoints.health.check_default_model_available"
119+
)
120+
mock_check_model.return_value = (
121+
False,
122+
"Default model google-vertex/publishers/google/models/gemini-2.5-flash "
123+
"not found in model registry",
124+
)
125+
126+
mock_response = mocker.Mock()
127+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
128+
129+
response = await readiness_probe_get_method(auth=auth, response=mock_response)
130+
131+
assert response.ready is False
132+
assert "not found in model registry" in response.reason
133+
assert mock_response.status_code == 503
134+
135+
90136
@pytest.mark.asyncio
91137
async def test_liveness_probe(mocker: MockerFixture) -> None:
92138
"""Test the liveness endpoint handler."""
@@ -207,3 +253,87 @@ async def test_get_providers_health_statuses_connection_error(
207253
assert (
208254
result[0].message == "Failed to initialize health check: Connection error."
209255
)
256+
257+
258+
class TestCheckDefaultModelAvailable:
    """Tests for check_default_model_available in the health endpoint module.

    Only the configuration lookup and the delegation to the client holder
    are exercised here. The availability logic itself (registry lookup,
    reload, error handling) is tested in tests/unit/test_client.py
    (TestCheckModelAvailable).
    """

    EXPECTED_MODEL_ID = "google-vertex/publishers/google/models/gemini-2.5-flash"

    @pytest.fixture
    def inference_config(self, mocker: MockerFixture) -> Any:
        """Patch the health module's configuration with a default provider and model."""
        patched = mocker.patch("app.endpoints.health.configuration")
        patched.inference.default_provider = "google-vertex"
        patched.inference.default_model = "publishers/google/models/gemini-2.5-flash"
        return patched

    @pytest.mark.asyncio
    async def test_no_inference_config(self, mocker: MockerFixture) -> None:
        """A missing inference section counts as available (nothing to check)."""
        patched = mocker.patch("app.endpoints.health.configuration")
        patched.inference = None

        is_available, why = await check_default_model_available()

        assert is_available is True
        assert why == "No default model configured"

    @pytest.mark.asyncio
    async def test_no_default_model_configured(self, mocker: MockerFixture) -> None:
        """An inference section without default model/provider counts as available."""
        patched = mocker.patch("app.endpoints.health.configuration")
        patched.inference.default_provider = None
        patched.inference.default_model = None

        is_available, why = await check_default_model_available()

        assert is_available is True
        assert why == "No default model configured"

    @pytest.mark.asyncio
    @pytest.mark.usefixtures("inference_config")
    async def test_delegates_to_client_holder(
        self,
        mocker: MockerFixture,
    ) -> None:
        """The holder is awaited exactly once with the provider/model identifier."""
        holder_cls = mocker.patch("app.endpoints.health.AsyncLlamaStackClientHolder")
        check_mock = mocker.AsyncMock(
            return_value=(True, f"Model {self.EXPECTED_MODEL_ID} is available")
        )
        holder_cls.return_value.check_model_available = check_mock

        is_available, why = await check_default_model_available()

        assert is_available is True
        assert "is available" in why
        check_mock.assert_awaited_once_with(self.EXPECTED_MODEL_ID)

    @pytest.mark.asyncio
    @pytest.mark.usefixtures("inference_config")
    async def test_returns_holder_failure(
        self,
        mocker: MockerFixture,
    ) -> None:
        """A failure tuple from the holder is returned to the caller unchanged."""
        holder_cls = mocker.patch("app.endpoints.health.AsyncLlamaStackClientHolder")
        failure = (
            False,
            f"Model {self.EXPECTED_MODEL_ID} not found in model registry",
        )
        holder_cls.return_value.check_model_available = mocker.AsyncMock(
            return_value=failure
        )

        is_available, why = await check_default_model_available()

        assert is_available is False
        assert "not found in model registry" in why

0 commit comments

Comments
 (0)