Skip to content

Commit aa41ccd

Browse files
authored
Merge pull request #1636 from major/fix/readiness-check-model-availability
RSPEED-2959: Add default model availability check to readiness probe
2 parents b601f34 + f625b60 commit aa41ccd

5 files changed

Lines changed: 425 additions & 5 deletions

File tree

src/app/endpoints/health.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from authentication.interface import AuthTuple
1616
from authorization.middleware import authorize
1717
from client import AsyncLlamaStackClientHolder
18+
from configuration import configuration
1819
from log import get_logger
1920
from models.config import Action
2021
from models.responses import (
@@ -95,6 +96,31 @@ async def get_providers_health_statuses() -> list[ProviderHealthStatus]:
9596
]
9697

9798

99+
async def check_default_model_available() -> tuple[bool, str]:
100+
"""Check that the configured default model is registered in the model registry.
101+
102+
Retrieves the default model and provider from configuration and delegates
103+
the availability check to the client holder.
104+
105+
Returns:
106+
A tuple of (available, reason) where available is True if the default
107+
model was found or no default model is configured, and reason describes
108+
the outcome.
109+
"""
110+
inference = configuration.inference
111+
if (
112+
inference is None
113+
or not inference.default_model
114+
or not inference.default_provider
115+
):
116+
return True, "No default model configured"
117+
118+
expected_model_id = f"{inference.default_provider}/{inference.default_model}"
119+
120+
client_holder = AsyncLlamaStackClientHolder()
121+
return await client_holder.check_model_available(expected_model_id)
122+
123+
98124
@router.get("/readiness", responses=get_readiness_responses)
99125
@authorize(Action.INFO)
100126
async def readiness_probe_get_method(
@@ -134,11 +160,21 @@ async def readiness_probe_get_method(
134160
unhealthy_provider_names = [p.provider_id for p in unhealthy_providers]
135161
reason = f"Providers not healthy: {', '.join(unhealthy_provider_names)}"
136162
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
137-
else:
138-
ready = True
139-
reason = "All providers are healthy"
163+
return ReadinessResponse(
164+
ready=ready, reason=reason, providers=unhealthy_providers
165+
)
166+
167+
# Check that the default model is registered in the model registry
168+
model_available, model_reason = await check_default_model_available()
169+
if not model_available:
170+
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
171+
return ReadinessResponse(
172+
ready=False, reason=model_reason, providers=unhealthy_providers
173+
)
140174

141-
return ReadinessResponse(ready=ready, reason=reason, providers=unhealthy_providers)
175+
return ReadinessResponse(
176+
ready=True, reason="All providers are healthy", providers=unhealthy_providers
177+
)
142178

143179

144180
@router.get("/liveness", responses=get_liveness_responses)

src/client.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import yaml
99
from fastapi import HTTPException
1010
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
11-
from llama_stack_client import APIConnectionError, AsyncLlamaStackClient
11+
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
1212

1313
from configuration import configuration
1414
from llama_stack_configuration import YamlDumper, enrich_byok_rag, enrich_solr
@@ -141,6 +141,76 @@ async def reload_library_client(self) -> AsyncLlamaStackClient:
141141
self._lsc = client
142142
return client
143143

144+
async def check_model_available(self, model_id: str) -> tuple[bool, str]:
145+
"""Check if a model is available in the registry, attempting reload if needed.
146+
147+
Verifies the model can be found in the Llama Stack client's model
148+
list. If the model is missing and the client is running in library
149+
mode, attempts a client reload to re-register models before
150+
reporting failure.
151+
152+
The reload re-runs the full Stack initialization pipeline, which
153+
re-attempts model registration with providers. This handles the
154+
case where a transient provider failure (e.g. Vertex AI network
155+
blip) caused model registration to fail on startup. Since
156+
Kubernetes readiness probe failures only remove the pod from
157+
service endpoints without restarting it, the reload provides a
158+
self-healing path.
159+
160+
Args:
161+
model_id: The expected model identifier to look up.
162+
163+
Returns:
164+
A tuple of (available, reason) where available is True if the
165+
model was found, and reason describes the outcome.
166+
"""
167+
try:
168+
client = self.get_client()
169+
models = await client.models.list()
170+
except RuntimeError as e:
171+
logger.warning("Client not initialized, skipping model check: %s", e)
172+
return False, f"Client not initialized: {e!s}"
173+
except (APIConnectionError, APIStatusError) as e:
174+
logger.error("Error checking model availability: %s", e)
175+
return False, f"Error checking model availability: {e!s}"
176+
177+
if any(m.id == model_id for m in models):
178+
return True, f"Model {model_id} is available"
179+
180+
# Model not found - attempt self-healing reload for library clients.
181+
# In server mode there is no library client to reload, so we can
182+
# only detect the missing model and report failure.
183+
if self.is_library_client:
184+
logger.warning(
185+
"Model %s not found, attempting client reload",
186+
model_id,
187+
)
188+
try:
189+
await self.reload_library_client()
190+
client = self.get_client()
191+
reloaded_models = await client.models.list()
192+
if any(m.id == model_id for m in reloaded_models):
193+
logger.info(
194+
"Model %s found after client reload",
195+
model_id,
196+
)
197+
return True, f"Model {model_id} is available after reload"
198+
except (
199+
RuntimeError,
200+
HTTPException,
201+
APIConnectionError,
202+
APIStatusError,
203+
) as err:
204+
logger.error("Client reload failed: %s", err)
205+
206+
registered_ids = [m.id for m in models]
207+
logger.error(
208+
"Model %s not found in registry. Registered models: %s",
209+
model_id,
210+
registered_ids,
211+
)
212+
return False, f"Model {model_id} not found in model registry"
213+
144214
def update_provider_data(self, updates: dict[str, str]) -> AsyncLlamaStackClient:
145215
"""Update provider data headers for service client.
146216

tests/integration/endpoints/test_health_integration.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ async def test_health_readiness(
159159
mock_llama_stack_client_health: AsyncMockType,
160160
test_response: Response,
161161
test_auth: AuthTuple,
162+
mocker: MockerFixture,
162163
) -> None:
163164
"""Test that readiness probe endpoint returns readiness status.
164165
@@ -180,6 +181,12 @@ async def test_health_readiness(
180181
"""
181182
_ = mock_llama_stack_client_health
182183

184+
# Mock check_default_model_available since configuration is not loaded
185+
mock_check_model = mocker.patch(
186+
"app.endpoints.health.check_default_model_available"
187+
)
188+
mock_check_model.return_value = (True, "Default model is available")
189+
183190
result = await readiness_probe_get_method(auth=test_auth, response=test_response)
184191

185192
# Verify that service returns readiness response

tests/unit/app/endpoints/test_health.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""Unit tests for the /health REST API endpoint."""
22

3+
from typing import Any
4+
35
import pytest
46
from llama_stack_client import APIConnectionError
57
from pytest_mock import MockerFixture
68

79
from app.endpoints.health import (
810
HealthStatus,
11+
check_default_model_available,
912
get_providers_health_statuses,
1013
liveness_probe_get_method,
1114
readiness_probe_get_method,
@@ -72,6 +75,12 @@ async def test_readiness_probe_success_when_all_providers_healthy(
7275
),
7376
]
7477

78+
# Mock check_default_model_available so it doesn't hit uninitialized client
79+
mock_check_model = mocker.patch(
80+
"app.endpoints.health.check_default_model_available"
81+
)
82+
mock_check_model.return_value = (True, "Default model is available")
83+
7584
# Mock the Response object and auth
7685
mock_response = mocker.Mock()
7786

@@ -87,6 +96,43 @@ async def test_readiness_probe_success_when_all_providers_healthy(
8796
assert len(response.providers) == 0
8897

8998

99+
@pytest.mark.asyncio
100+
async def test_readiness_probe_fails_when_model_not_available(
101+
mocker: MockerFixture,
102+
) -> None:
103+
"""Test readiness returns 503 when providers are healthy but default model is missing."""
104+
mock_authorization_resolvers(mocker)
105+
106+
mock_get_providers = mocker.patch(
107+
"app.endpoints.health.get_providers_health_statuses"
108+
)
109+
mock_get_providers.return_value = [
110+
ProviderHealthStatus(
111+
provider_id="provider1",
112+
status=HealthStatus.OK.value,
113+
message="Provider is healthy",
114+
)
115+
]
116+
117+
mock_check_model = mocker.patch(
118+
"app.endpoints.health.check_default_model_available"
119+
)
120+
mock_check_model.return_value = (
121+
False,
122+
"Default model google-vertex/publishers/google/models/gemini-2.5-flash "
123+
"not found in model registry",
124+
)
125+
126+
mock_response = mocker.Mock()
127+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
128+
129+
response = await readiness_probe_get_method(auth=auth, response=mock_response)
130+
131+
assert response.ready is False
132+
assert "not found in model registry" in response.reason
133+
assert mock_response.status_code == 503
134+
135+
90136
@pytest.mark.asyncio
91137
async def test_liveness_probe(mocker: MockerFixture) -> None:
92138
"""Test the liveness endpoint handler."""
@@ -207,3 +253,87 @@ async def test_get_providers_health_statuses_connection_error(
207253
assert (
208254
result[0].message == "Failed to initialize health check: Connection error."
209255
)
256+
257+
258+
class TestCheckDefaultModelAvailable:
259+
"""Test cases for the check_default_model_available function.
260+
261+
The model availability logic (registry lookup, reload, error handling)
262+
is tested in tests/unit/test_client.py (TestCheckModelAvailable). These
263+
tests verify only the config lookup and delegation in health.py.
264+
"""
265+
266+
EXPECTED_MODEL_ID = "google-vertex/publishers/google/models/gemini-2.5-flash"
267+
268+
@pytest.fixture
269+
def inference_config(self, mocker: MockerFixture) -> Any:
270+
"""Patch configuration with default model and provider."""
271+
mock_config = mocker.patch("app.endpoints.health.configuration")
272+
mock_config.inference.default_model = (
273+
"publishers/google/models/gemini-2.5-flash"
274+
)
275+
mock_config.inference.default_provider = "google-vertex"
276+
return mock_config
277+
278+
@pytest.mark.asyncio
279+
async def test_no_inference_config(self, mocker: MockerFixture) -> None:
280+
"""Test returns True when no inference configuration exists."""
281+
mock_config = mocker.patch("app.endpoints.health.configuration")
282+
mock_config.inference = None
283+
284+
available, reason = await check_default_model_available()
285+
286+
assert available is True
287+
assert reason == "No default model configured"
288+
289+
@pytest.mark.asyncio
290+
async def test_no_default_model_configured(self, mocker: MockerFixture) -> None:
291+
"""Test returns True when no default model is configured."""
292+
mock_config = mocker.patch("app.endpoints.health.configuration")
293+
mock_config.inference.default_model = None
294+
mock_config.inference.default_provider = None
295+
296+
available, reason = await check_default_model_available()
297+
298+
assert available is True
299+
assert reason == "No default model configured"
300+
301+
@pytest.mark.asyncio
302+
@pytest.mark.usefixtures("inference_config")
303+
async def test_delegates_to_client_holder(
304+
self,
305+
mocker: MockerFixture,
306+
) -> None:
307+
"""Test delegates to client holder with correct model ID."""
308+
mock_holder = mocker.patch("app.endpoints.health.AsyncLlamaStackClientHolder")
309+
mock_holder.return_value.check_model_available = mocker.AsyncMock(
310+
return_value=(True, f"Model {self.EXPECTED_MODEL_ID} is available")
311+
)
312+
313+
available, reason = await check_default_model_available()
314+
315+
assert available is True
316+
assert "is available" in reason
317+
mock_holder.return_value.check_model_available.assert_awaited_once_with(
318+
self.EXPECTED_MODEL_ID
319+
)
320+
321+
@pytest.mark.asyncio
322+
@pytest.mark.usefixtures("inference_config")
323+
async def test_returns_holder_failure(
324+
self,
325+
mocker: MockerFixture,
326+
) -> None:
327+
"""Test passes through failure result from client holder."""
328+
mock_holder = mocker.patch("app.endpoints.health.AsyncLlamaStackClientHolder")
329+
mock_holder.return_value.check_model_available = mocker.AsyncMock(
330+
return_value=(
331+
False,
332+
f"Model {self.EXPECTED_MODEL_ID} not found in model registry",
333+
)
334+
)
335+
336+
available, reason = await check_default_model_available()
337+
338+
assert available is False
339+
assert "not found in model registry" in reason

0 commit comments

Comments
 (0)