Skip to content

Commit 1ee30aa

Browse files
authored
Merge pull request #1006 from major/fix/rlsapi-cleanup
fix(rlsapi): add error handling, metrics, and system context
2 parents 66ba6de + 48a7c8f commit 1ee30aa

3 files changed

Lines changed: 123 additions & 22 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99

1010
from fastapi import APIRouter, Depends, HTTPException
1111
from llama_stack.apis.agents.openai_responses import OpenAIResponseObject
12-
from llama_stack_client import APIConnectionError
12+
from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError
1313

1414
import constants
15+
import metrics
1516
from authentication import get_auth_dependency
1617
from authentication.interface import AuthTuple
1718
from authorization.middleware import authorize
@@ -20,11 +21,13 @@
2021
from models.config import Action
2122
from models.responses import (
2223
ForbiddenResponse,
24+
InternalServerErrorResponse,
25+
QuotaExceededResponse,
2326
ServiceUnavailableResponse,
2427
UnauthorizedResponse,
2528
UnprocessableEntityResponse,
2629
)
27-
from models.rlsapi.requests import RlsapiV1InferRequest
30+
from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
2831
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
2932
from utils.responses import extract_text_from_response_output_item
3033
from utils.suid import get_suid
@@ -40,10 +43,41 @@
4043
),
4144
403: ForbiddenResponse.openapi_response(examples=["endpoint"]),
4245
422: UnprocessableEntityResponse.openapi_response(),
46+
429: QuotaExceededResponse.openapi_response(),
47+
500: InternalServerErrorResponse.openapi_response(examples=["generic"]),
4348
503: ServiceUnavailableResponse.openapi_response(),
4449
}
4550

4651

52+
def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
    """Build LLM instructions incorporating system context when available.

    Enhances the default system prompt with RHEL system information to provide
    the LLM with relevant context about the user's environment.

    The system information is supplied by the client, so every value is
    flattened to a single line before it is embedded: runs of whitespace
    (including newlines) are collapsed to one space and surrounding
    whitespace is stripped. This keeps a crafted value from adding extra
    lines to the system instructions, and a value that is empty after
    normalisation is omitted instead of producing a dangling label.

    Args:
        systeminfo: System information from the client (OS, version, arch).

    Returns:
        Instructions string for the LLM, with system context if available.
    """
    base_prompt = constants.DEFAULT_SYSTEM_PROMPT

    labelled_values = (
        ("OS", systeminfo.os),
        ("Version", systeminfo.version),
        ("Architecture", systeminfo.arch),
    )

    context_parts = []
    for label, raw_value in labelled_values:
        if not raw_value:
            continue
        # str.split() with no arguments splits on any whitespace run, so the
        # re-joined value can never contain a newline or leading/trailing
        # padding.
        value = " ".join(str(raw_value).split())
        if value:
            context_parts.append(f"{label}: {value}")

    if not context_parts:
        return base_prompt

    system_context = ", ".join(context_parts)
    return f"{base_prompt}\n\nUser's system: {system_context}"
79+
80+
4781
def _get_default_model_id() -> str:
4882
"""Get the default model ID from configuration.
4983
@@ -77,14 +111,15 @@ def _get_default_model_id() -> str:
77111
)
78112

79113

80-
async def retrieve_simple_response(question: str) -> str:
114+
async def retrieve_simple_response(question: str, instructions: str) -> str:
81115
"""Retrieve a simple response from the LLM for a stateless query.
82116
83117
Uses the Responses API for simple stateless inference, consistent with
84118
other endpoints (query_v2, streaming_query_v2).
85119
86120
Args:
87121
question: The combined user input (question + context).
122+
instructions: System instructions for the LLM.
88123
89124
Returns:
90125
The LLM-generated response text.
@@ -101,7 +136,7 @@ async def retrieve_simple_response(question: str) -> str:
101136
response = await client.responses.create(
102137
input=question,
103138
model=model_id,
104-
instructions=constants.DEFAULT_SYSTEM_PROMPT,
139+
instructions=instructions,
105140
stream=False,
106141
store=False,
107142
)
@@ -144,15 +179,16 @@ async def infer_endpoint(
144179

145180
logger.info("Processing rlsapi v1 /infer request %s", request_id)
146181

147-
# Combine all input sources (question, stdin, attachments, terminal)
148182
input_source = infer_request.get_input_source()
183+
instructions = _build_instructions(infer_request.context.systeminfo)
149184
logger.debug(
150185
"Request %s: Combined input source length: %d", request_id, len(input_source)
151186
)
152187

153188
try:
154-
response_text = await retrieve_simple_response(input_source)
189+
response_text = await retrieve_simple_response(input_source, instructions)
155190
except APIConnectionError as e:
191+
metrics.llm_calls_failures_total.inc()
156192
logger.error(
157193
"Unable to connect to Llama Stack for request %s: %s", request_id, e
158194
)
@@ -161,6 +197,18 @@ async def infer_endpoint(
161197
cause=str(e),
162198
)
163199
raise HTTPException(**response.model_dump()) from e
200+
except RateLimitError as e:
201+
metrics.llm_calls_failures_total.inc()
202+
logger.error("Rate limit exceeded for request %s: %s", request_id, e)
203+
response = QuotaExceededResponse(
204+
response="The quota has been exceeded", cause=str(e)
205+
)
206+
raise HTTPException(**response.model_dump()) from e
207+
except APIStatusError as e:
208+
metrics.llm_calls_failures_total.inc()
209+
logger.exception("API error for request %s: %s", request_id, e)
210+
response = InternalServerErrorResponse.generic()
211+
raise HTTPException(**response.model_dump()) from e
164212

165213
if not response_text:
166214
logger.warning("Empty response from LLM for request %s", request_id)

src/models/rlsapi/requests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class RlsapiV1InferRequest(ConfigurationBase):
126126
Attributes:
127127
question: User question string.
128128
context: Context with system info, terminal output, etc. (defaults provided).
129-
skip_rag: Whether to skip RAG retrieval (default False).
129+
skip_rag: Reserved for future use. RAG retrieval is not yet implemented.
130130
131131
Example:
132132
```python
@@ -152,7 +152,7 @@ class RlsapiV1InferRequest(ConfigurationBase):
152152
)
153153
skip_rag: bool = Field(
154154
default=False,
155-
description="Whether to skip RAG retrieval",
155+
description="Reserved for future use. RAG retrieval is not yet implemented.",
156156
examples=[False, True],
157157
)
158158

tests/unit/app/endpoints/test_rlsapi_v1.py

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import constants
1515
from app.endpoints.rlsapi_v1 import (
16+
_build_instructions,
1617
_get_default_model_id,
1718
infer_endpoint,
1819
retrieve_simple_response,
@@ -30,7 +31,7 @@
3031
from tests.unit.utils.auth_helpers import mock_authorization_resolvers
3132
from utils.suid import check_suid
3233

33-
MOCK_AUTH: AuthTuple = ("test_user_id", "test_user", True, "test_token")
34+
MOCK_AUTH: AuthTuple = ("mock_user_id", "mock_username", False, "mock_token")
3435

3536

3637
def _setup_responses_mock(mocker: MockerFixture, create_behavior: Any) -> None:
@@ -87,6 +88,12 @@ def mock_empty_llm_response_fixture(mocker: MockerFixture) -> None:
8788
_setup_responses_mock(mocker, mocker.AsyncMock(return_value=mock_response))
8889

8990

91+
@pytest.fixture(name="mock_auth_resolvers")
def mock_auth_resolvers_fixture(mocker: MockerFixture) -> None:
    """Install mocked authorization resolvers for an endpoint test.

    Delegates to the shared ``mock_authorization_resolvers`` helper using the
    test's ``mocker``; the fixture itself yields no value.
    """
    mock_authorization_resolvers(mocker)
95+
96+
9097
@pytest.fixture(name="mock_api_connection_error")
9198
def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
9299
"""Mock responses.create() to raise APIConnectionError."""
@@ -96,6 +103,47 @@ def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
96103
)
97104

98105

106+
# --- Test _build_instructions ---
107+
108+
109+
@pytest.mark.parametrize(
    ("systeminfo_kwargs", "expected_contains", "expected_not_contains"),
    [
        pytest.param(
            {"os": "RHEL", "version": "9.3", "arch": "x86_64"},
            ["OS: RHEL", "Version: 9.3", "Architecture: x86_64"],
            [],
            id="full_systeminfo",
        ),
        pytest.param(
            {"os": "RHEL", "version": "", "arch": ""},
            ["OS: RHEL"],
            ["Version:", "Architecture:"],
            id="partial_systeminfo",
        ),
        pytest.param(
            {},
            [constants.DEFAULT_SYSTEM_PROMPT],
            ["OS:", "Version:", "Architecture:"],
            id="empty_systeminfo",
        ),
    ],
)
def test_build_instructions(
    systeminfo_kwargs: dict[str, str],
    expected_contains: list[str],
    expected_not_contains: list[str],
) -> None:
    """Check which fragments _build_instructions emits for each system info.

    Every case builds a ``RlsapiV1SystemInfo`` from the given keyword
    arguments, then verifies that all required fragments are present in the
    generated instructions and that none of the forbidden ones appear.
    """
    instructions = _build_instructions(RlsapiV1SystemInfo(**systeminfo_kwargs))

    # Fragments that must appear in the generated instructions.
    for fragment in expected_contains:
        assert fragment in instructions

    # Labels for fields that were not supplied must be absent.
    for fragment in expected_not_contains:
        assert fragment not in instructions
145+
146+
99147
# --- Test _get_default_model_id ---
100148

101149

@@ -151,7 +199,9 @@ async def test_retrieve_simple_response_success(
151199
mock_configuration: AppConfig, mock_llm_response: None
152200
) -> None:
153201
"""Test retrieve_simple_response returns LLM response text."""
154-
response = await retrieve_simple_response("How do I list files?")
202+
response = await retrieve_simple_response(
203+
"How do I list files?", constants.DEFAULT_SYSTEM_PROMPT
204+
)
155205
assert response == "This is a test LLM response."
156206

157207

@@ -160,7 +210,9 @@ async def test_retrieve_simple_response_empty_output(
160210
mock_configuration: AppConfig, mock_empty_llm_response: None
161211
) -> None:
162212
"""Test retrieve_simple_response handles empty LLM output."""
163-
response = await retrieve_simple_response("Test question")
213+
response = await retrieve_simple_response(
214+
"Test question", constants.DEFAULT_SYSTEM_PROMPT
215+
)
164216
assert response == ""
165217

166218

@@ -170,18 +222,19 @@ async def test_retrieve_simple_response_api_connection_error(
170222
) -> None:
171223
"""Test retrieve_simple_response propagates APIConnectionError."""
172224
with pytest.raises(APIConnectionError):
173-
await retrieve_simple_response("Test question")
225+
await retrieve_simple_response("Test question", constants.DEFAULT_SYSTEM_PROMPT)
174226

175227

176228
# --- Test infer_endpoint ---
177229

178230

179231
@pytest.mark.asyncio
180232
async def test_infer_minimal_request(
181-
mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
233+
mock_configuration: AppConfig,
234+
mock_llm_response: None,
235+
mock_auth_resolvers: None,
182236
) -> None:
183237
"""Test /infer endpoint returns valid response with LLM text."""
184-
mock_authorization_resolvers(mocker)
185238
request = RlsapiV1InferRequest(question="How do I list files?")
186239

187240
response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
@@ -194,10 +247,11 @@ async def test_infer_minimal_request(
194247

195248
@pytest.mark.asyncio
196249
async def test_infer_full_context_request(
197-
mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
250+
mock_configuration: AppConfig,
251+
mock_llm_response: None,
252+
mock_auth_resolvers: None,
198253
) -> None:
199254
"""Test /infer endpoint handles full context (stdin, attachments, terminal)."""
200-
mock_authorization_resolvers(mocker)
201255
request = RlsapiV1InferRequest(
202256
question="Why did this command fail?",
203257
context=RlsapiV1Context(
@@ -217,10 +271,11 @@ async def test_infer_full_context_request(
217271

218272
@pytest.mark.asyncio
219273
async def test_infer_generates_unique_request_ids(
220-
mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
274+
mock_configuration: AppConfig,
275+
mock_llm_response: None,
276+
mock_auth_resolvers: None,
221277
) -> None:
222278
"""Test that each /infer call generates a unique request_id."""
223-
mock_authorization_resolvers(mocker)
224279
request = RlsapiV1InferRequest(question="How do I list files?")
225280

226281
response1 = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
@@ -231,12 +286,11 @@ async def test_infer_generates_unique_request_ids(
231286

232287
@pytest.mark.asyncio
233288
async def test_infer_api_connection_error_returns_503(
234-
mocker: MockerFixture,
235289
mock_configuration: AppConfig,
236290
mock_api_connection_error: None,
291+
mock_auth_resolvers: None,
237292
) -> None:
238293
"""Test /infer endpoint returns 503 when LLM service is unavailable."""
239-
mock_authorization_resolvers(mocker)
240294
request = RlsapiV1InferRequest(question="Test question")
241295

242296
with pytest.raises(HTTPException) as exc_info:
@@ -247,12 +301,11 @@ async def test_infer_api_connection_error_returns_503(
247301

248302
@pytest.mark.asyncio
249303
async def test_infer_empty_llm_response_returns_fallback(
250-
mocker: MockerFixture,
251304
mock_configuration: AppConfig,
252305
mock_empty_llm_response: None,
306+
mock_auth_resolvers: None,
253307
) -> None:
254308
"""Test /infer endpoint returns fallback text when LLM returns empty response."""
255-
mock_authorization_resolvers(mocker)
256309
request = RlsapiV1InferRequest(question="Test question")
257310

258311
response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)

0 commit comments

Comments
 (0)