Skip to content

Commit 4476220

Browse files
authored
Merge pull request #1232 from radofuchs/LCORE_1206_OLS_test_alignment
LCORE-1206: add tests for too long question
2 parents 37829bf + 7935999 commit 4476220

5 files changed

Lines changed: 288 additions & 7 deletions

File tree

tests/e2e/features/environment.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
from tests.e2e.utils.prow_utils import restore_llama_stack_pod
1717
from behave.runner import Context
1818

19+
from tests.e2e.utils.llama_stack_shields import (
20+
register_shield,
21+
unregister_shield,
22+
)
1923
from tests.e2e.utils.utils import (
2024
create_config_backup,
2125
is_prow_environment,
@@ -169,6 +173,27 @@ def before_scenario(context: Context, scenario: Scenario) -> None:
169173
scenario.skip("Skipped in library mode (no separate llama-stack container)")
170174
return
171175

176+
# @disable-shields: unregister shield via client.shields.delete("llama-guard").
177+
# Only in server mode: in library mode there is no separate Llama Stack to call,
178+
# and unregistering in the test process would not affect the app's in-process instance.
179+
if "disable-shields" in scenario.effective_tags:
180+
if context.is_library_mode:
181+
scenario.skip(
182+
"Shield unregister/register only applies in server mode (Llama Stack as a "
183+
"separate service). In library mode the app's shields cannot be disabled from e2e."
184+
)
185+
return
186+
try:
187+
saved = unregister_shield("llama-guard")
188+
context.llama_guard_provider_id = saved[0] if saved else None
189+
context.llama_guard_provider_shield_id = saved[1] if saved else None
190+
print("Unregistered shield llama-guard for this scenario")
191+
except Exception as e: # pylint: disable=broad-exception-caught
192+
scenario.skip(
193+
f"Could not unregister shield (is Llama Stack reachable?): {e}"
194+
)
195+
return
196+
172197
mode_dir = "library-mode" if context.is_library_mode else "server-mode"
173198

174199
if "InvalidFeedbackStorageConfig" in scenario.effective_tags:
@@ -217,6 +242,52 @@ def after_scenario(context: Context, scenario: Scenario) -> None:
217242
switch_config(context.feature_config)
218243
restart_container("lightspeed-stack")
219244

245+
# @disable-shields: re-register shield only if we unregistered one (avoid creating a shield that did not exist)
246+
if "disable-shields" in scenario.effective_tags:
247+
provider_id = getattr(context, "llama_guard_provider_id", None)
248+
provider_shield_id = getattr(context, "llama_guard_provider_shield_id", None)
249+
if provider_id is not None and provider_shield_id is not None:
250+
try:
251+
register_shield(
252+
"llama-guard",
253+
provider_id=provider_id,
254+
provider_shield_id=provider_shield_id,
255+
)
256+
print("Re-registered shield llama-guard")
257+
except Exception as e: # pylint: disable=broad-exception-caught
258+
print(f"Warning: Could not re-register shield: {e}")
259+
260+
261+
def _print_llama_stack_diagnostics() -> None:
262+
"""Print container state, health, and recent logs to diagnose why llama-stack did not recover."""
263+
print("--- llama-stack diagnostics ---")
264+
for label, cmd in [
265+
("State", ["docker", "inspect", "--format={{.State}}", "llama-stack"]),
266+
("Health", ["docker", "inspect", "--format={{.State.Health}}", "llama-stack"]),
267+
]:
268+
try:
269+
r = subprocess.run(
270+
cmd, capture_output=True, text=True, timeout=5, check=False
271+
)
272+
print(f" {label}: {r.stdout.strip() if r.stdout else r.stderr or 'N/A'}")
273+
except subprocess.TimeoutExpired:
274+
print(f" {label}: (inspect timed out)")
275+
try:
276+
r = subprocess.run(
277+
["docker", "logs", "--tail", "40", "llama-stack"],
278+
capture_output=True,
279+
text=True,
280+
timeout=10,
281+
check=False,
282+
)
283+
out = (r.stdout or "") + (r.stderr or "")
284+
print(" Logs (last 40 lines):")
285+
for line in out.strip().splitlines():
286+
print(f" {line}")
287+
except subprocess.TimeoutExpired:
288+
print(" Logs: (timed out)")
289+
print("--- end diagnostics ---")
290+
220291

221292
def _restore_llama_stack(context: Context) -> None:
222293
"""Restore Llama Stack connection after disruption."""
@@ -263,9 +334,15 @@ def _restore_llama_stack(context: Context) -> None:
263334
time.sleep(5)
264335
else:
265336
print("Warning: Llama Stack may not be fully healthy after restoration")
337+
_print_llama_stack_diagnostics()
266338

267339
except subprocess.CalledProcessError as e:
268340
print(f"Warning: Could not restore Llama Stack connection: {e}")
341+
if e.stderr:
342+
print(f" docker start stderr: {e.stderr}")
343+
if e.stdout:
344+
print(f" docker start stdout: {e.stdout}")
345+
_print_llama_stack_diagnostics()
269346

270347

271348
def before_feature(context: Context, feature: Feature) -> None:

tests/e2e/features/query.feature

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,20 @@ Scenario: Check if LLM responds for query request with error for missing query
216216
}
217217
"""
218218
Then The status code of the response is 200
219+
220+
Scenario: Check if query with shields returns 413 when question is too long for model context
221+
Given The system is in default state
222+
And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva
223+
When I use "query" to ask question with too-long query and authorization header
224+
Then The status code of the response is 413
225+
And The body of the response contains Prompt is too long
226+
227+
#https://issues.redhat.com/browse/LCORE-1387
228+
@skip
229+
@disable-shields
230+
Scenario: Check if query without shields returns 413 when question is too long for model context
231+
Given The system is in default state
232+
And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva
233+
When I use "query" to ask question with too-long query and authorization header
234+
Then The status code of the response is 413
235+
And The body of the response contains Prompt is too long

tests/e2e/features/steps/llm_query_response.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
def wait_for_complete_response(context: Context) -> None:
1616
"""Wait for the response to be complete."""
1717
context.response_data = _parse_streaming_response(context.response.text)
18-
print(context.response_data)
1918
context.response.raise_for_status()
2019
assert context.response_data["finished"] is True
2120

@@ -31,10 +30,21 @@ def ask_question(context: Context, endpoint: str) -> None:
3130
json_str = replace_placeholders(context, context.text or "{}")
3231

3332
data = json.loads(json_str)
34-
print(f"Request data: {data}")
3533
context.response = requests.post(url, json=data, timeout=DEFAULT_LLM_TIMEOUT)
3634

3735

36+
def _read_streamed_response(response: requests.Response) -> str:
37+
"""Read a streaming response body, tolerating premature close (e.g. after error event)."""
38+
chunks = []
39+
try:
40+
for line in response.iter_lines(decode_unicode=True):
41+
if line is not None:
42+
chunks.append(line + "\n")
43+
except requests.exceptions.ChunkedEncodingError:
44+
pass # Server may close stream after sending an error event
45+
return "".join(chunks)
46+
47+
3848
@step('I use "{endpoint}" to ask question with authorization header')
3949
def ask_question_authorized(context: Context, endpoint: str) -> None:
4050
"""Call the service REST API endpoint with question."""
@@ -46,10 +56,40 @@ def ask_question_authorized(context: Context, endpoint: str) -> None:
4656
json_str = replace_placeholders(context, context.text or "{}")
4757

4858
data = json.loads(json_str)
49-
print(f"Request data: {data}")
50-
context.response = requests.post(
51-
url, json=data, headers=context.auth_headers, timeout=DEFAULT_LLM_TIMEOUT
52-
)
59+
if endpoint == "streaming_query":
60+
resp = requests.post(
61+
url,
62+
json=data,
63+
headers=context.auth_headers,
64+
timeout=DEFAULT_LLM_TIMEOUT,
65+
stream=True,
66+
)
67+
# Consume stream so server close after error event does not raise
68+
body = _read_streamed_response(resp)
69+
resp._content = body.encode(resp.encoding or "utf-8")
70+
context.response = resp
71+
else:
72+
context.response = requests.post(
73+
url, json=data, headers=context.auth_headers, timeout=DEFAULT_LLM_TIMEOUT
74+
)
75+
76+
77+
# Query length chosen to exceed typical model context windows (e.g. 128k tokens)
78+
_TOO_LONG_QUERY_LENGTH = 80_000
79+
80+
81+
@step('I use "{endpoint}" to ask question with too-long query and authorization header')
82+
def ask_question_too_long_authorized(context: Context, endpoint: str) -> None:
83+
"""Call the query endpoint with a query string that exceeds model context (expect 413)."""
84+
long_query = "what is openshift?" * _TOO_LONG_QUERY_LENGTH
85+
payload = {
86+
"query": long_query,
87+
"model": context.default_model,
88+
"provider": context.default_provider,
89+
}
90+
context.text = json.dumps(payload)
91+
print(f"Request: query length={len(long_query)}, model={context.default_model}")
92+
ask_question_authorized(context, endpoint)
5393

5494

5595
@step("I store conversation details")
@@ -72,7 +112,6 @@ def ask_question_in_same_conversation(context: Context, endpoint: str) -> None:
72112
headers = context.auth_headers if hasattr(context, "auth_headers") else {}
73113
data["conversation_id"] = context.response_data["conversation_id"]
74114

75-
print(f"Request data: {data}")
76115
context.response = requests.post(
77116
url, json=data, headers=headers, timeout=DEFAULT_LLM_TIMEOUT
78117
)
@@ -142,6 +181,29 @@ def check_streamed_fragments_in_response(context: Context) -> None:
142181
), f"Fragment '{expected}' not found in LLM response: '{response}'"
143182

144183

184+
@then("The streamed response contains error message {message}")
185+
def check_streamed_response_error_message(context: Context, message: str) -> None:
186+
"""Check that the streamed SSE response contains an error event with the given message.
187+
188+
Parses the response body as SSE, asserts that an event with event type 'error' is
189+
present, and that its 'response' or 'cause' field contains the given message.
190+
Use for streaming endpoints when the error is delivered in the stream (e.g. 200 + error event).
191+
"""
192+
assert context.response is not None, "Request needs to be performed first"
193+
print(context.response.text)
194+
parsed = _parse_streaming_response(context.response.text)
195+
stream_error = parsed.get("stream_error")
196+
assert (
197+
stream_error is not None
198+
), "No error event in stream. Expected an SSE event with event type 'error'."
199+
response_text = str(stream_error.get("response", ""))
200+
cause_text = str(stream_error.get("cause", ""))
201+
assert message in response_text or message in cause_text, (
202+
f"Expected error message '{message}' not found in stream error event: "
203+
f"response={response_text!r}, cause={cause_text!r}"
204+
)
205+
206+
145207
@then("The streamed response is equal to the full response")
146208
def compare_streamed_responses(context: Context) -> None:
147209
"""Check that streamed response is equal to complete response.
@@ -171,6 +233,9 @@ def _parse_streaming_response(response_text: str) -> dict:
171233
full_response_split = []
172234
finished = False
173235
first_token = True
236+
stream_error = (
237+
None # {"status_code": int, "response": str, "cause": str} if event "error"
238+
)
174239

175240
for line in lines:
176241
if line.startswith("data: "):
@@ -190,6 +255,8 @@ def _parse_streaming_response(response_text: str) -> dict:
190255
full_response = data["data"]["token"]
191256
elif event == "end":
192257
finished = True
258+
elif event == "error":
259+
stream_error = data.get("data") or {}
193260
except json.JSONDecodeError:
194261
continue # Skip malformed lines
195262

@@ -198,4 +265,5 @@ def _parse_streaming_response(response_text: str) -> dict:
198265
"response": "".join(full_response_split),
199266
"response_complete": full_response,
200267
"finished": finished,
268+
"stream_error": stream_error,
201269
}

tests/e2e/features/streaming_query.feature

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,18 @@ Feature: streaming_query endpoint API tests
178178
}
179179
}
180180
"""
181+
182+
Scenario: Check if streaming_query with shields returns 413 when question is too long for model context
183+
Given The system is in default state
184+
And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva
185+
When I use "streaming_query" to ask question with too-long query and authorization header
186+
Then The status code of the response is 413
187+
And The body of the response contains Prompt is too long
188+
189+
@disable-shields
190+
Scenario: Check if streaming_query without shields returns 200 and error in stream when question is too long for model context
191+
Given The system is in default state
192+
And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva
193+
When I use "streaming_query" to ask question with too-long query and authorization header
194+
Then The status code of the response is 200
195+
And The streamed response contains error message Prompt is too long
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""E2E helpers to unregister and re-register Llama Stack shields via the client API.
2+
3+
Used by the @disable-shields tag: before the scenario we call client.shields.delete()
4+
to unregister the shield; after the scenario we call client.shields.register()
5+
to restore it. Only applies in server mode (Llama Stack as a separate service).
6+
Requires E2E_LLAMA_STACK_URL or E2E_LLAMA_HOSTNAME/E2E_LLAMA_PORT.
7+
"""
8+
9+
import asyncio
10+
import os
11+
from typing import Optional
12+
13+
from llama_stack_client import (
14+
APIConnectionError,
15+
AsyncLlamaStackClient,
16+
APIStatusError,
17+
)
18+
19+
20+
def _get_llama_stack_client() -> AsyncLlamaStackClient:
21+
"""Build an AsyncLlamaStackClient from env (for e2e test use)."""
22+
base_url = os.getenv("E2E_LLAMA_STACK_URL")
23+
if not base_url:
24+
host = os.getenv("E2E_LLAMA_HOSTNAME", "localhost")
25+
port = os.getenv("E2E_LLAMA_PORT", "8321")
26+
base_url = f"http://{host}:{port}"
27+
api_key = os.getenv("E2E_LLAMA_STACK_API_KEY", "xyzzy")
28+
timeout = int(os.getenv("E2E_LLAMA_STACK_TIMEOUT", "60"))
29+
return AsyncLlamaStackClient(base_url=base_url, api_key=api_key, timeout=timeout)
30+
31+
32+
async def _unregister_shield_async(identifier: str) -> Optional[tuple[str, str]]:
33+
"""Unregister a shield by identifier; return (provider_id, provider_shield_id) for restore."""
34+
client = _get_llama_stack_client()
35+
try:
36+
shields = await client.shields.list()
37+
provider_id = None
38+
provider_shield_id = None
39+
found = False
40+
for shield in shields:
41+
if getattr(shield, "identifier", None) == identifier:
42+
provider_id = getattr(shield, "provider_id", None)
43+
provider_shield_id = getattr(
44+
shield, "provider_resource_id", None
45+
) or getattr(shield, "provider_shield_id", None)
46+
found = True
47+
break
48+
if not found:
49+
# Shield not registered; nothing to delete, scenario can proceed
50+
return None
51+
try:
52+
await client.shields.delete(identifier)
53+
except APIConnectionError:
54+
raise
55+
except APIStatusError as e:
56+
# 400 "not found": shield already absent, scenario can proceed
57+
if e.status_code == 400 and "not found" in str(e).lower():
58+
return None
59+
raise
60+
if provider_id is not None and provider_shield_id is not None:
61+
return (provider_id, provider_shield_id)
62+
return None
63+
finally:
64+
await client.close()
65+
66+
67+
async def _register_shield_async(
68+
shield_id: str,
69+
provider_id: str,
70+
provider_shield_id: str,
71+
) -> None:
72+
"""Register a shield (restore after unregister)."""
73+
client = _get_llama_stack_client()
74+
try:
75+
await client.shields.register(
76+
shield_id=shield_id,
77+
provider_id=provider_id,
78+
provider_shield_id=provider_shield_id,
79+
)
80+
finally:
81+
await client.close()
82+
83+
84+
def unregister_shield(
85+
identifier: str = "llama-guard",
86+
) -> Optional[tuple[str, str]]:
87+
"""Unregister the shield via client.shields.delete(); return (provider_id, provider_shield_id)."""
88+
return asyncio.run(_unregister_shield_async(identifier))
89+
90+
91+
def register_shield(
92+
shield_id: str = "llama-guard",
93+
provider_id: Optional[str] = None,
94+
provider_shield_id: Optional[str] = None,
95+
) -> None:
96+
"""Re-register the shield via client.shields.register()."""
97+
if not provider_id:
98+
provider_id = os.getenv("E2E_LLAMA_GUARD_PROVIDER_ID", "llama-guard")
99+
if not provider_shield_id:
100+
provider_shield_id = os.getenv(
101+
"E2E_LLAMA_GUARD_PROVIDER_SHIELD_ID",
102+
"openai/gpt-4o-mini",
103+
)
104+
asyncio.run(_register_shield_async(shield_id, provider_id, provider_shield_id))

0 commit comments

Comments
 (0)