Skip to content

Commit a812eb6

Browse files
committed
Make providers health check in readiness endpoint
1 parent e1c179c commit a812eb6

3 files changed

Lines changed: 263 additions & 17 deletions

File tree

src/app/endpoints/health.py

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,89 @@
88
import logging
99
from typing import Any
1010

11-
from fastapi import APIRouter
12-
13-
from models.responses import ReadinessResponse, LivenessResponse, NotAvailableResponse
11+
from llama_stack.providers.datatypes import HealthStatus
1412

13+
from fastapi import APIRouter, status, Response
14+
from client import get_llama_stack_client
15+
from configuration import configuration
16+
from models.responses import (
17+
LivenessResponse,
18+
ReadinessResponse,
19+
ProviderHealthStatus,
20+
)
1521

1622
logger = logging.getLogger(__name__)
1723
router = APIRouter(tags=["health"])
1824

1925

26+
def get_providers_health_statuses() -> list[ProviderHealthStatus]:
    """Collect the health status reported by each Llama Stack provider.

    A client is built from the configured Llama Stack settings and asked for
    its provider list. Every provider's ``health`` mapping is turned into a
    ``ProviderHealthStatus``; a missing ``status`` key becomes ``"unknown"``
    and a missing ``message`` key becomes an empty string.

    Returns:
        One entry per listed provider. If anything raises along the way, a
        single entry with provider_id ``"unknown"`` and status
        ``HealthStatus.ERROR.value`` is returned instead, carrying the
        error text in its message.
    """
    try:
        stack_config = configuration.llama_stack_configuration
        stack_client = get_llama_stack_client(stack_config)

        listed = stack_client.providers.list()
        logger.debug("Found %d providers", len(listed))

        statuses = []
        for item in listed:
            statuses.append(
                ProviderHealthStatus(
                    provider_id=item.provider_id,
                    status=str(item.health.get("status", "unknown")),
                    message=str(item.health.get("message", "")),
                )
            )
        return statuses

    except Exception as e:  # pylint: disable=broad-exception-caught
        # The probe must always produce an answer, so any failure (e.g. no
        # providers defined) is reported as one synthetic error entry.
        logger.error("Failed to check providers health: %s", e)
        failure = ProviderHealthStatus(
            provider_id="unknown",
            status=HealthStatus.ERROR.value,
            message=f"Failed to initialize health check: {str(e)}",
        )
        return [failure]
60+
61+
2062
# Response documentation passed to the readiness route: both the ready (200)
# and not-ready (503) outcomes are described by the same ReadinessResponse model.
get_readiness_responses: dict[int | str, dict[str, Any]] = {
    200: {"description": "Service is ready", "model": ReadinessResponse},
    503: {"description": "Service is not ready", "model": ReadinessResponse},
}
3072

3173

3274
@router.get("/readiness", responses=get_readiness_responses)
def readiness_probe_get_method(response: Response) -> ReadinessResponse:
    """Report service readiness based on provider health.

    Only providers whose status equals ``HealthStatus.ERROR.value`` count as
    unhealthy; any other status (including "not implemented") does not
    affect readiness.

    Args:
        response: The outgoing HTTP response; its status code is set to 503
            when at least one provider is unhealthy.

    Returns:
        A ReadinessResponse whose ``providers`` list holds the unhealthy
        providers (empty when the service is ready).
    """
    failing = [
        entry
        for entry in get_providers_health_statuses()
        if entry.status == HealthStatus.ERROR.value
    ]

    # Guard clause: nothing is failing, so the default status code stands.
    if not failing:
        return ReadinessResponse(
            ready=True, reason="All providers are healthy", providers=failing
        )

    failing_ids = ", ".join([entry.provider_id for entry in failing])
    reason = f"Providers not healthy: {failing_ids}"
    response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
    return ReadinessResponse(ready=False, reason=reason, providers=failing)
3694

3795

3896
get_liveness_responses: dict[int | str, dict[str, Any]] = {

src/models/responses.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,29 +78,56 @@ class InfoResponse(BaseModel):
7878
}
7979

8080

81+
class ProviderHealthStatus(BaseModel):
    """Model representing the health status of a provider.

    Attributes:
        provider_id: The ID of the provider.
        status: The health status as a plain string, e.g. "Error" as used in
            the ReadinessResponse example.
            NOTE(review): presumably one of the llama_stack ``HealthStatus``
            enum values; confirm the exact strings against that enum.
        message: Optional message about the health status.
    """

    provider_id: str
    status: str
    message: Optional[str] = None
93+
94+
8195
class ReadinessResponse(BaseModel):
82-
"""Model representing a response to a readiness request.
96+
"""Model representing response to a readiness request.
8397
8498
Attributes:
85-
ready: The readiness of the service.
99+
ready: If service is ready.
86100
reason: The reason for the readiness.
101+
providers: List of unhealthy providers in case of readiness failure.
87102
88103
Example:
89104
```python
90-
readiness_response = ReadinessResponse(ready=True, reason="service is ready")
105+
readiness_response = ReadinessResponse(
106+
ready=False,
107+
reason="Service is not ready",
108+
providers=[
109+
ProviderHealthStatus(
110+
provider_id="ollama",
111+
status="Error",
112+
message="Server is unavailable"
113+
)
114+
]
115+
)
91116
```
92117
"""
93118

94119
ready: bool
95120
reason: str
121+
providers: list[ProviderHealthStatus]
96122

97123
# provides examples for /docs endpoint
98124
model_config = {
99125
"json_schema_extra": {
100126
"examples": [
101127
{
102128
"ready": True,
103-
"reason": "service is ready",
129+
"reason": "Service is ready",
130+
"providers": [],
104131
}
105132
]
106133
}
Lines changed: 167 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,177 @@
1-
from app.endpoints.health import readiness_probe_get_method, liveness_probe_get_method
1+
from unittest.mock import Mock
22

3+
from app.endpoints.health import (
4+
readiness_probe_get_method,
5+
liveness_probe_get_method,
6+
get_providers_health_statuses,
7+
)
8+
from models.responses import ProviderHealthStatus, ReadinessResponse
9+
from llama_stack.providers.datatypes import HealthStatus
310

4-
def test_readiness_probe(mocker):
5-
"""Test the readiness endpoint handler."""
6-
response = readiness_probe_get_method()
11+
12+
def test_readiness_probe_fails_due_to_unhealthy_providers(mocker):
    """Readiness is refused with HTTP 503 when a provider reports an error."""
    failing_provider = ProviderHealthStatus(
        provider_id="test_provider",
        status=HealthStatus.ERROR.value,
        message="Provider is down",
    )
    # Replace the provider lookup so the handler sees one failing provider.
    mocker.patch(
        "app.endpoints.health.get_providers_health_statuses",
        return_value=[failing_provider],
    )

    # Stand-in for the Response object whose status code the handler sets.
    http_response = Mock()

    result = readiness_probe_get_method(http_response)

    assert result.ready is False
    assert "test_provider" in result.reason
    assert "Providers not healthy" in result.reason
    assert http_response.status_code == 503
35+
36+
37+
def test_readiness_probe_success_when_all_providers_healthy(mocker):
    """Readiness is granted when no provider reports an error status."""
    healthy = ProviderHealthStatus(
        provider_id="provider1",
        status=HealthStatus.OK.value,
        message="Provider is healthy",
    )
    # A provider without a health check must not count as unhealthy.
    unchecked = ProviderHealthStatus(
        provider_id="provider2",
        status=HealthStatus.NOT_IMPLEMENTED.value,
        message="Provider does not implement health check",
    )
    mocker.patch(
        "app.endpoints.health.get_providers_health_statuses",
        return_value=[healthy, unchecked],
    )

    # Stand-in for the Response object passed to the handler.
    http_response = Mock()

    result = readiness_probe_get_method(http_response)

    assert result is not None
    assert isinstance(result, ReadinessResponse)
    assert result.ready is True
    assert result.reason == "All providers are healthy"
    # Only unhealthy providers are echoed back, so the list is empty here.
    assert len(result.providers) == 0
1066

1167

12-
def test_liveness_probe(mocker):
68+
def test_liveness_probe():
    """The liveness handler returns a response that reports alive."""
    result = liveness_probe_get_method()
    assert result is not None
    assert result.alive is True
73+
74+
75+
class TestProviderHealthStatus:
    """Unit tests for constructing ProviderHealthStatus instances."""

    def test_provider_health_status_creation(self):
        """All three fields are stored exactly as supplied."""
        fields = {
            "provider_id": "test_provider",
            "status": "ok",
            "message": "All good",
        }
        model = ProviderHealthStatus(**fields)
        assert model.provider_id == "test_provider"
        assert model.status == "ok"
        assert model.message == "All good"

    def test_provider_health_status_optional_fields(self):
        """Omitting the message leaves it as None."""
        model = ProviderHealthStatus(provider_id="test_provider", status="ok")
        assert model.provider_id == "test_provider"
        assert model.status == "ok"
        assert model.message is None
93+
94+
95+
class TestGetProvidersHealthStatuses:
    """Unit tests for get_providers_health_statuses."""

    def test_get_providers_health_statuses(self, mocker):
        """Each listed provider is mapped to a matching status entry, in order."""
        # Patch the names the function looks up in the health module.
        client_factory = mocker.patch("app.endpoints.health.get_llama_stack_client")
        config_stub = mocker.patch("app.endpoints.health.configuration")
        config_stub.llama_stack_configuration = mocker.Mock()

        # (provider_id, status, message) triples: one healthy provider, one
        # without a health check, and one reporting an error.
        expected = [
            ("provider1", HealthStatus.OK.value, "All good"),
            (
                "provider2",
                HealthStatus.NOT_IMPLEMENTED.value,
                "Provider does not implement health check",
            ),
            ("unhealthy_provider", HealthStatus.ERROR.value, "Connection failed"),
        ]

        fake_providers = []
        for provider_id, state, note in expected:
            fake = mocker.Mock()
            fake.provider_id = provider_id
            fake.health = {"status": state, "message": note}
            fake_providers.append(fake)

        fake_client = mocker.Mock()
        fake_client.providers.list.return_value = fake_providers
        client_factory.return_value = fake_client

        result = get_providers_health_statuses()

        assert len(result) == 3
        for entry, (provider_id, state, note) in zip(result, expected):
            assert entry.provider_id == provider_id
            assert entry.status == state
            assert entry.message == note

    def test_get_providers_health_statuses_connection_error(self, mocker):
        """A client creation failure yields one synthetic 'unknown' error entry."""
        client_factory = mocker.patch("app.endpoints.health.get_llama_stack_client")
        config_stub = mocker.patch("app.endpoints.health.configuration")
        config_stub.llama_stack_configuration = mocker.Mock()

        # Creating the client blows up before any provider can be listed.
        client_factory.side_effect = Exception("Connection error")

        result = get_providers_health_statuses()

        assert len(result) == 1
        only_entry = result[0]
        assert only_entry.provider_id == "unknown"
        assert only_entry.status == HealthStatus.ERROR.value
        assert (
            only_entry.message
            == "Failed to initialize health check: Connection error"
        )

0 commit comments

Comments
 (0)