Skip to content

Commit 4f7b91d

Browse files
committed
RSPEED-2529: replace string concatenation with Jinja2 template rendering for system prompts
Operators can now use Jinja2 template variables ({{ date }}, {{ os }}, {{ version }}, {{ arch }}) and conditionals in system prompts. Prompts without template syntax pass through unchanged. Uses SandboxedEnvironment for defense-in-depth and catches TemplateSyntaxError so malformed prompts surface a clear ValueError instead of an opaque traceback on every request. The compiled template is cached via lru_cache. Signed-off-by: Major Hayden <major@redhat.com>
1 parent 4380346 commit 4f7b91d

4 files changed

Lines changed: 181 additions & 91 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ dependencies = [
6767
"azure-core>=1.38.0",
6868
"azure-identity>=1.21.0",
6969
"pyasn1>=0.6.2",
70+
# Used for system prompt template variable rendering
71+
"jinja2>=3.1.0",
7072
]
7173

7274

src/app/endpoints/rlsapi_v1.py

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
from the RHEL Lightspeed Command Line Assistant (CLA).
55
"""
66

7+
import functools
78
import time
89
from datetime import datetime
910
from typing import Annotated, Any, Optional, cast
1011

12+
import jinja2
13+
from jinja2.sandbox import SandboxedEnvironment
1114
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
1215
from llama_stack_api.openai_responses import OpenAIResponseObject
1316
from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError
@@ -54,6 +57,7 @@
5457
# Keep this tuple centralized so infer_endpoint can catch all expected backend
5558
# failures in one place while preserving a single telemetry/error-mapping path.
5659
_INFER_HANDLED_EXCEPTIONS = (
60+
ValueError,
5761
RuntimeError,
5862
APIConnectionError,
5963
RateLimitError,
@@ -102,44 +106,66 @@ def _get_rh_identity_context(request: Request) -> tuple[str, str]:
102106

103107

104108
def _build_instructions(systeminfo: RlsapiV1SystemInfo) -> str:
105-
"""Build LLM instructions incorporating date and system context.
109+
"""Build LLM instructions by rendering the system prompt as a Jinja2 template.
106110
107-
Enhances the default system prompt with today's date and RHEL system
108-
information to provide the LLM with relevant context about the user's
109-
environment and current time.
111+
The base prompt is rendered with the context variables ``date``, ``os``,
112+
``version``, and ``arch``. Prompts without template markers pass through
113+
unchanged. The compiled template is cached after the first call.
110114
111115
Args:
112116
systeminfo: System information from the client (OS, version, arch).
113117
114118
Returns:
115-
Instructions string for the LLM, with date and system context.
119+
The rendered instructions string for the LLM.
116120
"""
117-
base_prompt = _get_base_prompt()
118121
date_today = datetime.now().strftime("%B %d, %Y")
119122

120-
context_parts = []
121-
if systeminfo.os:
122-
context_parts.append(f"OS: {systeminfo.os}")
123-
if systeminfo.version:
124-
context_parts.append(f"Version: {systeminfo.version}")
125-
if systeminfo.arch:
126-
context_parts.append(f"Architecture: {systeminfo.arch}")
123+
return _get_prompt_template().render(
124+
date=date_today,
125+
os=systeminfo.os or "",
126+
version=systeminfo.version or "",
127+
arch=systeminfo.arch or "",
128+
)
129+
127130

128-
if not context_parts:
129-
return f"{base_prompt}\n\nToday's date: {date_today}"
131+
@functools.lru_cache(maxsize=1)
132+
def _get_prompt_template() -> jinja2.Template:
133+
"""Compile and cache the system prompt as a Jinja2 template.
130134
131-
system_context = ", ".join(context_parts)
132-
return f"{base_prompt}\n\nToday's date: {date_today}\n\nUser's system: {system_context}"
135+
The template is compiled once on first call and reused for all subsequent
136+
requests since the system prompt does not change at runtime.
133137
138+
Uses SandboxedEnvironment to restrict template capabilities. The template
139+
source is admin-controlled (config file), but sandboxing provides
140+
defense-in-depth: if the configuration surface ever expands (e.g. prompts
141+
from a database or API), unsandboxed templates could expose Python
142+
internals via Jinja2's introspection (``__class__``, ``__subclasses__``).
134143
135-
def _get_base_prompt() -> str:
136-
"""Get the base system prompt with configuration fallback."""
137-
if (
138-
configuration.customization is not None
144+
TemplateSyntaxError is caught and re-raised as ValueError so that
145+
malformed prompts produce a clear, actionable error message instead of
146+
an opaque Jinja2 traceback on every request. Because lru_cache does not
147+
cache exceptions, the error will repeat until the admin fixes the config.
148+
"""
149+
# SandboxedEnvironment disables dangerous operations (getattr on dunders,
150+
# calls to unsafe methods) while still supporting the template variables
151+
# and conditionals used in system prompts ({{ date }}, {% if os %}, etc.).
152+
env = SandboxedEnvironment()
153+
154+
prompt = (
155+
configuration.customization.system_prompt
156+
if configuration.customization is not None
139157
and configuration.customization.system_prompt is not None
140-
):
141-
return configuration.customization.system_prompt
142-
return constants.DEFAULT_SYSTEM_PROMPT
158+
else constants.DEFAULT_SYSTEM_PROMPT
159+
)
160+
161+
try:
162+
return env.from_string(prompt)
163+
except jinja2.TemplateSyntaxError as exc:
164+
# Surface the exact syntax problem so operators can fix their config
165+
# without digging through a full Jinja2 stack trace.
166+
raise ValueError(
167+
f"System prompt contains invalid Jinja2 syntax: {exc}"
168+
) from exc
143169

144170

145171
async def _get_default_model_id() -> str:
@@ -319,7 +345,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
319345
return inference_time
320346

321347

322-
def _map_inference_error_to_http_exception(
348+
def _map_inference_error_to_http_exception( # pylint: disable=too-many-return-statements
323349
error: Exception, model_id: str, request_id: str
324350
) -> Optional[HTTPException]:
325351
"""Map known inference errors to HTTPException.
@@ -328,6 +354,13 @@ def _map_inference_error_to_http_exception(
328354
so callers can preserve existing re-raise behavior for unknown runtime
329355
errors.
330356
"""
357+
if isinstance(error, ValueError):
358+
logger.error(
359+
"Invalid system prompt template for request %s: %s", request_id, error
360+
)
361+
error_response = InternalServerErrorResponse.generic()
362+
return HTTPException(**error_response.model_dump())
363+
331364
if isinstance(error, RuntimeError):
332365
error_message = str(error).lower()
333366
if "context_length" in error_message or "context length" in error_message:
@@ -398,7 +431,6 @@ async def infer_endpoint( # pylint: disable=R0914
398431
logger.info("Processing rlsapi v1 /infer request %s", request_id)
399432

400433
input_source = infer_request.get_input_source()
401-
instructions = _build_instructions(infer_request.context.systeminfo)
402434
model_id = await _get_default_model_id()
403435
provider, model = extract_provider_and_model_from_model_id(model_id)
404436
mcp_tools: list[Any] = await get_mcp_tools(request_headers=request.headers)
@@ -408,6 +440,7 @@ async def infer_endpoint( # pylint: disable=R0914
408440

409441
start_time = time.monotonic()
410442
try:
443+
instructions = _build_instructions(infer_request.context.systeminfo)
411444
response_text = await retrieve_simple_response(
412445
input_source,
413446
instructions,

tests/unit/app/endpoints/test_rlsapi_v1.py

Lines changed: 118 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# pylint: disable=unused-argument
55

66
import re
7-
from typing import Any, Optional
7+
from collections.abc import Callable
8+
from typing import Any
89

910
import pytest
1011
from fastapi import HTTPException, status
@@ -17,6 +18,7 @@
1718
AUTH_DISABLED,
1819
_build_instructions,
1920
_get_default_model_id,
21+
_get_prompt_template,
2022
_get_rh_identity_context,
2123
infer_endpoint,
2224
retrieve_simple_response,
@@ -39,6 +41,26 @@
3941
MOCK_AUTH: AuthTuple = ("mock_user_id", "mock_username", False, "mock_token")
4042

4143

44+
@pytest.fixture(autouse=True)
45+
def _clear_prompt_template_cache() -> None:
46+
"""Clear the lru_cache on _get_prompt_template between tests."""
47+
_get_prompt_template.cache_clear()
48+
49+
50+
@pytest.fixture(name="mock_custom_prompt")
51+
def mock_custom_prompt_fixture(mocker: MockerFixture) -> Callable[[str], None]:
52+
"""Factory fixture that patches configuration with a custom system prompt."""
53+
54+
def _set(prompt: str) -> None:
55+
mock_customization = mocker.Mock()
56+
mock_customization.system_prompt = prompt
57+
mock_config = mocker.Mock()
58+
mock_config.customization = mock_customization
59+
mocker.patch("app.endpoints.rlsapi_v1.configuration", mock_config)
60+
61+
return _set
62+
63+
4264
def _create_mock_request(mocker: MockerFixture, rh_identity: Any = None) -> Any:
4365
"""Create a mock FastAPI Request with optional RH Identity data."""
4466
mock_request = mocker.Mock()
@@ -140,93 +162,124 @@ def mock_generic_runtime_error_fixture(mocker: MockerFixture) -> None:
140162
# --- Test _build_instructions ---
141163

142164

143-
@pytest.mark.parametrize(
144-
("systeminfo_kwargs", "expected_contains", "expected_not_contains"),
145-
[
146-
pytest.param(
147-
{"os": "RHEL", "version": "9.3", "arch": "x86_64"},
148-
["OS: RHEL", "Version: 9.3", "Architecture: x86_64"],
149-
[],
150-
id="full_systeminfo",
151-
),
152-
pytest.param(
153-
{"os": "RHEL", "version": "", "arch": ""},
154-
["OS: RHEL"],
155-
["Version:", "Architecture:"],
156-
id="partial_systeminfo",
157-
),
158-
pytest.param(
159-
{},
160-
[constants.DEFAULT_SYSTEM_PROMPT],
161-
["OS:", "Version:", "Architecture:"],
162-
id="empty_systeminfo",
163-
),
164-
],
165-
)
166-
def test_build_instructions(
167-
systeminfo_kwargs: dict[str, str],
168-
expected_contains: list[str],
169-
expected_not_contains: list[str],
170-
) -> None:
171-
"""Test _build_instructions includes date and system info."""
172-
systeminfo = RlsapiV1SystemInfo(**systeminfo_kwargs)
165+
def test_build_instructions_default_prompt_passes_through() -> None:
166+
"""Test _build_instructions returns default prompt unchanged when no template vars."""
167+
systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64")
173168
result = _build_instructions(systeminfo)
174169

175-
assert re.search(r"Today's date: \w+ \d{2}, \d{4}", result)
176-
for expected in expected_contains:
177-
assert expected in result
178-
for not_expected in expected_not_contains:
179-
assert not_expected not in result
180-
181-
182-
# --- Test _build_instructions with customization.system_prompt ---
170+
assert result == constants.DEFAULT_SYSTEM_PROMPT
183171

184172

185-
@pytest.mark.parametrize(
186-
("custom_prompt", "expected_prompt"),
187-
[
188-
pytest.param(
189-
"You are a RHEL expert.",
190-
"You are a RHEL expert.",
191-
id="customization_system_prompt_set",
192-
),
193-
pytest.param(
194-
None,
195-
constants.DEFAULT_SYSTEM_PROMPT,
196-
id="customization_system_prompt_none",
197-
),
198-
],
199-
)
200-
def test_build_instructions_with_customization(
201-
mocker: MockerFixture,
202-
custom_prompt: Optional[str],
203-
expected_prompt: str,
204-
) -> None:
205-
"""Test _build_instructions uses customization.system_prompt when set."""
173+
def test_build_instructions_with_customization(mocker: MockerFixture) -> None:
174+
"""Test _build_instructions uses customization.system_prompt with template vars."""
175+
template = "Expert assistant.\n\nDate: {{ date }}\nOS: {{ os }}"
206176
mock_customization = mocker.Mock()
207-
mock_customization.system_prompt = custom_prompt
177+
mock_customization.system_prompt = template
208178
mock_config = mocker.Mock()
209179
mock_config.customization = mock_customization
210180
mocker.patch("app.endpoints.rlsapi_v1.configuration", mock_config)
211181

212182
systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64")
213183
result = _build_instructions(systeminfo)
214184

215-
assert expected_prompt in result
185+
assert "Expert assistant." in result
216186
assert "OS: RHEL" in result
187+
assert re.search(r"Date: \w+ \d{2}, \d{4}", result)
217188

218189

219190
def test_build_instructions_no_customization(mocker: MockerFixture) -> None:
220-
"""Test _build_instructions falls back when customization is None."""
191+
"""Test _build_instructions falls back to DEFAULT_SYSTEM_PROMPT."""
221192
mock_config = mocker.Mock()
222193
mock_config.customization = None
223194
mocker.patch("app.endpoints.rlsapi_v1.configuration", mock_config)
224195

225196
systeminfo = RlsapiV1SystemInfo()
226197
result = _build_instructions(systeminfo)
227198

228-
assert result.startswith(constants.DEFAULT_SYSTEM_PROMPT)
229-
assert re.search(r"Today's date: \w+ \d{2}, \d{4}", result)
199+
assert result == constants.DEFAULT_SYSTEM_PROMPT
200+
201+
202+
# --- Test Jinja2 template rendering ---
203+
204+
205+
def test_build_instructions_renders_jinja2_template(
206+
mock_custom_prompt: Callable[[str], None],
207+
) -> None:
208+
"""Test _build_instructions renders Jinja2 template variables instead of appending."""
209+
mock_custom_prompt(
210+
"You are an assistant.\n\nDate: {{ date }}\nOS: {{ os }} {{ version }} ({{ arch }})"
211+
)
212+
213+
systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64")
214+
result = _build_instructions(systeminfo)
215+
216+
assert "OS: RHEL 9.3 (x86_64)" in result
217+
assert re.search(r"Date: \w+ \d{2}, \d{4}", result)
218+
assert "Today's date:" not in result
219+
assert "User's system:" not in result
220+
221+
222+
def test_build_instructions_jinja2_none_values_render_empty(
223+
mock_custom_prompt: Callable[[str], None],
224+
) -> None:
225+
"""Test that None system info values render as empty strings, not 'None'."""
226+
mock_custom_prompt("Assistant.\nOS={{ os }} VER={{ version }} ARCH={{ arch }}")
227+
228+
systeminfo = RlsapiV1SystemInfo()
229+
result = _build_instructions(systeminfo)
230+
231+
assert "None" not in result
232+
assert "OS= VER= ARCH=" in result
233+
234+
235+
def test_build_instructions_jinja2_conditionals(
236+
mock_custom_prompt: Callable[[str], None],
237+
) -> None:
238+
"""Test that Jinja2 conditionals work in system prompt templates."""
239+
mock_custom_prompt(
240+
"Assistant.{% if os %} OS: {{ os }}{% endif %}"
241+
"{% if version %} VER: {{ version }}{% endif %}"
242+
)
243+
244+
systeminfo = RlsapiV1SystemInfo(os="RHEL")
245+
result = _build_instructions(systeminfo)
246+
247+
assert "OS: RHEL" in result
248+
assert "VER:" not in result
249+
250+
251+
def test_build_instructions_plain_prompt_passes_through(
252+
mock_custom_prompt: Callable[[str], None],
253+
) -> None:
254+
"""Test that prompts without Jinja2 syntax pass through unchanged."""
255+
plain_prompt = "You are an expert RHEL assistant."
256+
mock_custom_prompt(plain_prompt)
257+
258+
systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64")
259+
result = _build_instructions(systeminfo)
260+
261+
assert result == plain_prompt
262+
263+
264+
@pytest.mark.parametrize(
265+
"bad_template",
266+
[
267+
pytest.param("Hello {{ unclosed", id="unclosed_variable"),
268+
pytest.param("{% if %}", id="if_without_condition"),
269+
pytest.param("{% endfor %}", id="endfor_without_for"),
270+
],
271+
)
272+
def test_build_instructions_malformed_template_raises_value_error(
273+
mock_custom_prompt: Callable[[str], None],
274+
bad_template: str,
275+
) -> None:
276+
"""Test that invalid Jinja2 syntax in system prompt raises ValueError."""
277+
mock_custom_prompt(bad_template)
278+
279+
systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64")
280+
281+
with pytest.raises(ValueError, match="invalid Jinja2 syntax"):
282+
_build_instructions(systeminfo)
230283

231284

232285
# --- Test _get_default_model_id ---

0 commit comments

Comments
 (0)