Skip to content

Commit b452672

Browse files
committed
feat(rlsapi): include system info context in LLM instructions
Enhance the LLM instructions with the user's RHEL system information (OS, version, architecture) when available. This gives the LLM better context about the environment the user is asking questions about, enabling more relevant and accurate responses. Signed-off-by: Major Hayden <major@redhat.com>
1 parent ba1c5cb commit b452672

2 files changed

Lines changed: 101 additions & 18 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
UnauthorizedResponse,
2828
UnprocessableEntityResponse,
2929
)
30-
from models.rlsapi.requests import RlsapiV1InferRequest
30+
from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
3131
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
3232
from utils.responses import extract_text_from_response_output_item
3333
from utils.suid import get_suid
@@ -49,6 +49,35 @@
4949
}
5050

5151

52+
def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
53+
"""Build LLM instructions incorporating system context when available.
54+
55+
Enhances the default system prompt with RHEL system information to provide
56+
the LLM with relevant context about the user's environment.
57+
58+
Args:
59+
systeminfo: System information from the client (OS, version, arch).
60+
61+
Returns:
62+
Instructions string for the LLM, with system context if available.
63+
"""
64+
base_prompt = constants.DEFAULT_SYSTEM_PROMPT
65+
66+
context_parts = []
67+
if systeminfo.os:
68+
context_parts.append(f"OS: {systeminfo.os}")
69+
if systeminfo.version:
70+
context_parts.append(f"Version: {systeminfo.version}")
71+
if systeminfo.arch:
72+
context_parts.append(f"Architecture: {systeminfo.arch}")
73+
74+
if not context_parts:
75+
return base_prompt
76+
77+
system_context = ", ".join(context_parts)
78+
return f"{base_prompt}\n\nUser's system: {system_context}"
79+
80+
5281
def _get_default_model_id() -> str:
5382
"""Get the default model ID from configuration.
5483
@@ -82,14 +111,15 @@ def _get_default_model_id() -> str:
82111
)
83112

84113

85-
async def retrieve_simple_response(question: str) -> str:
114+
async def retrieve_simple_response(question: str, instructions: str) -> str:
86115
"""Retrieve a simple response from the LLM for a stateless query.
87116
88117
Uses the Responses API for simple stateless inference, consistent with
89118
other endpoints (query_v2, streaming_query_v2).
90119
91120
Args:
92121
question: The combined user input (question + context).
122+
instructions: System instructions for the LLM.
93123
94124
Returns:
95125
The LLM-generated response text.
@@ -106,7 +136,7 @@ async def retrieve_simple_response(question: str) -> str:
106136
response = await client.responses.create(
107137
input=question,
108138
model=model_id,
109-
instructions=constants.DEFAULT_SYSTEM_PROMPT,
139+
instructions=instructions,
110140
stream=False,
111141
store=False,
112142
)
@@ -149,14 +179,14 @@ async def infer_endpoint(
149179

150180
logger.info("Processing rlsapi v1 /infer request %s", request_id)
151181

152-
# Combine all input sources (question, stdin, attachments, terminal)
153182
input_source = infer_request.get_input_source()
183+
instructions = _build_instructions(infer_request.context.systeminfo)
154184
logger.debug(
155185
"Request %s: Combined input source length: %d", request_id, len(input_source)
156186
)
157187

158188
try:
159-
response_text = await retrieve_simple_response(input_source)
189+
response_text = await retrieve_simple_response(input_source, instructions)
160190
except APIConnectionError as e:
161191
metrics.llm_calls_failures_total.inc()
162192
logger.error(

tests/unit/app/endpoints/test_rlsapi_v1.py

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import constants
1515
from app.endpoints.rlsapi_v1 import (
16+
_build_instructions,
1617
_get_default_model_id,
1718
infer_endpoint,
1819
retrieve_simple_response,
@@ -87,6 +88,12 @@ def mock_empty_llm_response_fixture(mocker: MockerFixture) -> None:
8788
_setup_responses_mock(mocker, mocker.AsyncMock(return_value=mock_response))
8889

8990

91+
@pytest.fixture(name="mock_auth_resolvers")
92+
def mock_auth_resolvers_fixture(mocker: MockerFixture) -> None:
93+
"""Mock authorization resolvers for endpoint tests."""
94+
mock_authorization_resolvers(mocker)
95+
96+
9097
@pytest.fixture(name="mock_api_connection_error")
9198
def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
9299
"""Mock responses.create() to raise APIConnectionError."""
@@ -96,6 +103,47 @@ def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
96103
)
97104

98105

106+
# --- Test _build_instructions ---
107+
108+
109+
@pytest.mark.parametrize(
110+
("systeminfo_kwargs", "expected_contains", "expected_not_contains"),
111+
[
112+
pytest.param(
113+
{"os": "RHEL", "version": "9.3", "arch": "x86_64"},
114+
["OS: RHEL", "Version: 9.3", "Architecture: x86_64"],
115+
[],
116+
id="full_systeminfo",
117+
),
118+
pytest.param(
119+
{"os": "RHEL", "version": "", "arch": ""},
120+
["OS: RHEL"],
121+
["Version:", "Architecture:"],
122+
id="partial_systeminfo",
123+
),
124+
pytest.param(
125+
{},
126+
[constants.DEFAULT_SYSTEM_PROMPT],
127+
["OS:", "Version:", "Architecture:"],
128+
id="empty_systeminfo",
129+
),
130+
],
131+
)
132+
def test_build_instructions(
133+
systeminfo_kwargs: dict[str, str],
134+
expected_contains: list[str],
135+
expected_not_contains: list[str],
136+
) -> None:
137+
"""Test _build_instructions with various system info combinations."""
138+
systeminfo = RlsapiV1SystemInfo(**systeminfo_kwargs)
139+
result = _build_instructions(systeminfo)
140+
141+
for expected in expected_contains:
142+
assert expected in result
143+
for not_expected in expected_not_contains:
144+
assert not_expected not in result
145+
146+
99147
# --- Test _get_default_model_id ---
100148

101149

@@ -151,7 +199,9 @@ async def test_retrieve_simple_response_success(
151199
mock_configuration: AppConfig, mock_llm_response: None
152200
) -> None:
153201
"""Test retrieve_simple_response returns LLM response text."""
154-
response = await retrieve_simple_response("How do I list files?")
202+
response = await retrieve_simple_response(
203+
"How do I list files?", constants.DEFAULT_SYSTEM_PROMPT
204+
)
155205
assert response == "This is a test LLM response."
156206

157207

@@ -160,7 +210,9 @@ async def test_retrieve_simple_response_empty_output(
160210
mock_configuration: AppConfig, mock_empty_llm_response: None
161211
) -> None:
162212
"""Test retrieve_simple_response handles empty LLM output."""
163-
response = await retrieve_simple_response("Test question")
213+
response = await retrieve_simple_response(
214+
"Test question", constants.DEFAULT_SYSTEM_PROMPT
215+
)
164216
assert response == ""
165217

166218

@@ -170,18 +222,19 @@ async def test_retrieve_simple_response_api_connection_error(
170222
) -> None:
171223
"""Test retrieve_simple_response propagates APIConnectionError."""
172224
with pytest.raises(APIConnectionError):
173-
await retrieve_simple_response("Test question")
225+
await retrieve_simple_response("Test question", constants.DEFAULT_SYSTEM_PROMPT)
174226

175227

176228
# --- Test infer_endpoint ---
177229

178230

179231
@pytest.mark.asyncio
180232
async def test_infer_minimal_request(
181-
mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
233+
mock_configuration: AppConfig,
234+
mock_llm_response: None,
235+
mock_auth_resolvers: None,
182236
) -> None:
183237
"""Test /infer endpoint returns valid response with LLM text."""
184-
mock_authorization_resolvers(mocker)
185238
request = RlsapiV1InferRequest(question="How do I list files?")
186239

187240
response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
@@ -194,10 +247,11 @@ async def test_infer_minimal_request(
194247

195248
@pytest.mark.asyncio
196249
async def test_infer_full_context_request(
197-
mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
250+
mock_configuration: AppConfig,
251+
mock_llm_response: None,
252+
mock_auth_resolvers: None,
198253
) -> None:
199254
"""Test /infer endpoint handles full context (stdin, attachments, terminal)."""
200-
mock_authorization_resolvers(mocker)
201255
request = RlsapiV1InferRequest(
202256
question="Why did this command fail?",
203257
context=RlsapiV1Context(
@@ -217,10 +271,11 @@ async def test_infer_full_context_request(
217271

218272
@pytest.mark.asyncio
219273
async def test_infer_generates_unique_request_ids(
220-
mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None
274+
mock_configuration: AppConfig,
275+
mock_llm_response: None,
276+
mock_auth_resolvers: None,
221277
) -> None:
222278
"""Test that each /infer call generates a unique request_id."""
223-
mock_authorization_resolvers(mocker)
224279
request = RlsapiV1InferRequest(question="How do I list files?")
225280

226281
response1 = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)
@@ -231,12 +286,11 @@ async def test_infer_generates_unique_request_ids(
231286

232287
@pytest.mark.asyncio
233288
async def test_infer_api_connection_error_returns_503(
234-
mocker: MockerFixture,
235289
mock_configuration: AppConfig,
236290
mock_api_connection_error: None,
291+
mock_auth_resolvers: None,
237292
) -> None:
238293
"""Test /infer endpoint returns 503 when LLM service is unavailable."""
239-
mock_authorization_resolvers(mocker)
240294
request = RlsapiV1InferRequest(question="Test question")
241295

242296
with pytest.raises(HTTPException) as exc_info:
@@ -247,12 +301,11 @@ async def test_infer_api_connection_error_returns_503(
247301

248302
@pytest.mark.asyncio
249303
async def test_infer_empty_llm_response_returns_fallback(
250-
mocker: MockerFixture,
251304
mock_configuration: AppConfig,
252305
mock_empty_llm_response: None,
306+
mock_auth_resolvers: None,
253307
) -> None:
254308
"""Test /infer endpoint returns fallback text when LLM returns empty response."""
255-
mock_authorization_resolvers(mocker)
256309
request = RlsapiV1InferRequest(question="Test question")
257310

258311
response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH)

0 commit comments

Comments
 (0)