|
8 | 8 | from enum import Enum |
9 | 9 | from typing import Annotated, Any |
10 | 10 |
|
11 | | -from fastapi import APIRouter, Depends, Response, status |
12 | | -from llama_stack_client import APIConnectionError |
| 11 | +from fastapi import APIRouter, Depends, HTTPException, Response, status |
| 12 | +from llama_stack_client import APIConnectionError, APIStatusError |
13 | 13 |
|
14 | 14 | from authentication import get_auth_dependency |
15 | 15 | from authentication.interface import AuthTuple |
16 | 16 | from authorization.middleware import authorize |
17 | 17 | from client import AsyncLlamaStackClientHolder |
| 18 | +from configuration import configuration |
18 | 19 | from log import get_logger |
19 | 20 | from models.config import Action |
20 | 21 | from models.responses import ( |
@@ -95,6 +96,106 @@ async def get_providers_health_statuses() -> list[ProviderHealthStatus]: |
95 | 96 | ] |
96 | 97 |
|
97 | 98 |
|
| 99 | +def _model_in_registry(models: list, expected_id: str) -> bool: |
| 100 | + """Check if a model with the given ID exists in the model list.""" |
| 101 | + return any(model.id == expected_id for model in models) |
| 102 | + |
| 103 | + |
| 104 | +async def _reload_and_check_model( |
| 105 | + client_holder: AsyncLlamaStackClientHolder, |
| 106 | + expected_model_id: str, |
| 107 | +) -> tuple[bool, str]: |
| 108 | + """Attempt to reload the library client and recheck model availability. |
| 109 | +
|
| 110 | + Only called for library mode clients when the default model is missing |
| 111 | + from the registry after initial lookup. |
| 112 | +
|
| 113 | + Returns: |
| 114 | + A tuple of (found, reason) where found is True if the model was |
| 115 | + found after reloading the client. |
| 116 | + """ |
| 117 | + logger.warning( |
| 118 | + "Default model %s not found, attempting client reload", |
| 119 | + expected_model_id, |
| 120 | + ) |
| 121 | + try: |
| 122 | + await client_holder.reload_library_client() |
| 123 | + client = client_holder.get_client() |
| 124 | + models = await client.models.list() |
| 125 | + if _model_in_registry(models, expected_model_id): |
| 126 | + logger.info( |
| 127 | + "Default model %s found after client reload", |
| 128 | + expected_model_id, |
| 129 | + ) |
| 130 | + return True, ( |
| 131 | + f"Default model {expected_model_id} is available after reload" |
| 132 | + ) |
| 133 | + except ( |
| 134 | + RuntimeError, |
| 135 | + HTTPException, |
| 136 | + APIConnectionError, |
| 137 | + APIStatusError, |
| 138 | + ) as err: |
| 139 | + logger.error("Client reload failed: %s", err) |
| 140 | + return False, "" |
| 141 | + |
| 142 | + |
| 143 | +async def check_default_model_available() -> tuple[bool, str]: |
| 144 | + """Check that the configured default model is registered in the model registry. |
| 145 | +
|
| 146 | + Verifies the default model from configuration can be found in the Llama |
| 147 | + Stack client's model list. This catches cases where a pod started |
| 148 | + successfully and providers report healthy, but model registration failed |
| 149 | + during initialization. |
| 150 | +
|
| 151 | + If the model is missing and the client is running in library mode, attempts |
| 152 | + a client reload to re-register models before reporting failure. |
| 153 | +
|
| 154 | + Returns: |
| 155 | + A tuple of (available, reason) where available is True if the default |
| 156 | + model was found or no default model is configured, and reason describes |
| 157 | + the outcome. |
| 158 | + """ |
| 159 | + inference = configuration.inference |
| 160 | + if inference is None or not inference.default_model or not inference.default_provider: |
| 161 | + return True, "No default model configured" |
| 162 | + |
| 163 | + default_model = inference.default_model |
| 164 | + default_provider = inference.default_provider |
| 165 | + |
| 166 | + expected_model_id = f"{default_provider}/{default_model}" |
| 167 | + |
| 168 | + try: |
| 169 | + client_holder = AsyncLlamaStackClientHolder() |
| 170 | + client = client_holder.get_client() |
| 171 | + models = await client.models.list() |
| 172 | + |
| 173 | + if _model_in_registry(models, expected_model_id): |
| 174 | + return True, f"Default model {expected_model_id} is available" |
| 175 | + |
| 176 | + # Model not found - attempt self-healing reload for library clients |
| 177 | + if client_holder.is_library_client: |
| 178 | + found, reason = await _reload_and_check_model( |
| 179 | + client_holder, expected_model_id |
| 180 | + ) |
| 181 | + if found: |
| 182 | + return True, reason |
| 183 | + |
| 184 | + registered_ids = [m.id for m in models] |
| 185 | + logger.error( |
| 186 | + "Default model %s not found in registry. Registered models: %s", |
| 187 | + expected_model_id, |
| 188 | + registered_ids, |
| 189 | + ) |
| 190 | + return False, f"Default model {expected_model_id} not found in model registry" |
| 191 | + except RuntimeError as e: |
| 192 | + logger.warning("Client not initialized, skipping model check: %s", e) |
| 193 | + return False, f"Client not initialized: {e!s}" |
| 194 | + except (APIConnectionError, APIStatusError) as e: |
| 195 | + logger.error("Error checking model availability: %s", e) |
| 196 | + return False, f"Error checking model availability: {e!s}" |
| 197 | + |
| 198 | + |
98 | 199 | @router.get("/readiness", responses=get_readiness_responses) |
99 | 200 | @authorize(Action.INFO) |
100 | 201 | async def readiness_probe_get_method( |
@@ -134,11 +235,21 @@ async def readiness_probe_get_method( |
134 | 235 | unhealthy_provider_names = [p.provider_id for p in unhealthy_providers] |
135 | 236 | reason = f"Providers not healthy: {', '.join(unhealthy_provider_names)}" |
136 | 237 | response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE |
137 | | - else: |
138 | | - ready = True |
139 | | - reason = "All providers are healthy" |
| 238 | + return ReadinessResponse( |
| 239 | + ready=ready, reason=reason, providers=unhealthy_providers |
| 240 | + ) |
| 241 | + |
| 242 | + # Check that the default model is registered in the model registry |
| 243 | + model_available, model_reason = await check_default_model_available() |
| 244 | + if not model_available: |
| 245 | + response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE |
| 246 | + return ReadinessResponse( |
| 247 | + ready=False, reason=model_reason, providers=unhealthy_providers |
| 248 | + ) |
140 | 249 |
|
141 | | - return ReadinessResponse(ready=ready, reason=reason, providers=unhealthy_providers) |
| 250 | + return ReadinessResponse( |
| 251 | + ready=True, reason="All providers are healthy", providers=unhealthy_providers |
| 252 | + ) |
142 | 253 |
|
143 | 254 |
|
144 | 255 | @router.get("/liveness", responses=get_liveness_responses) |
|
0 commit comments