Skip to content

Commit e2c3886

Browse files
jrobertboosasimurka
authored and committed
fix
1 parent e0a53c3 commit e2c3886

3 files changed

Lines changed: 66 additions & 56 deletions

File tree

src/utils/agents/streaming.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -243,6 +243,7 @@ async def agent_response_generator(
243243
logger.debug("Starting agent streaming response processing")
244244
async with agent.run_stream_events(prompt) as stream:
245245
async for event in stream:
246+
print(f"event: {event.event_kind}")
246247
if payload := dispatch_stream_event(event, dispatch_state):
247248
yield serialize_event(payload, media_type)
248249

@@ -347,12 +348,13 @@ def _(
347348
)
348349
else:
349350
final_text = state.run_result.response.text or "".join(state.text_parts)
350-
state.chunk_id += 1
351351

352-
return TurnCompleteStreamPayload.create(
352+
payload = TurnCompleteStreamPayload.create(
353353
chunk_id=state.chunk_id,
354354
token=final_text,
355355
)
356+
state.chunk_id += 1
357+
return payload
356358

357359

358360
@dispatch_stream_event.register

tests/e2e/features/steps/llm_query_response.py

Lines changed: 60 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -93,23 +93,9 @@ def wait_for_complete_response(context: Context) -> None:
9393
"""Wait for the response to be complete."""
9494
context.response_data = _parse_streaming_response(context.response.text)
9595
context.response.raise_for_status()
96-
assert context.response_data["finished"] is True, f"Response is not finished: {context.response_data}"
97-
98-
99-
@step('I use "{endpoint}" to ask question')
100-
def ask_question(context: Context, endpoint: str) -> None:
101-
"""Call the service REST API endpoint with question."""
102-
base = f"http://{context.hostname}:{context.port}"
103-
path = f"{context.api_prefix}/{endpoint}".replace("//", "/")
104-
url = base + path
105-
106-
# Replace {MODEL} and {PROVIDER} placeholders with actual values
107-
json_str = replace_placeholders(context, context.text or "{}")
108-
109-
data = json.loads(json_str)
110-
context.response = request_with_transient_retry(
111-
method="POST", url=url, json=data, timeout=DEFAULT_LLM_TIMEOUT
112-
)
96+
assert (
97+
context.response_data["finished"] is True
98+
), f"Response is not finished: {context.response_data}"
11399

114100

115101
def _read_streamed_response(response: requests.Response) -> str:
@@ -124,41 +110,72 @@ def _read_streamed_response(response: requests.Response) -> str:
124110
return "".join(chunks)
125111

126112

127-
@step('I use "{endpoint}" to ask question with authorization header')
128-
def ask_question_authorized(context: Context, endpoint: str) -> None:
129-
"""Call the service REST API endpoint with question."""
113+
def _uses_sse(endpoint: str, data: dict[str, Any]) -> bool:
114+
"""Return whether the endpoint delivers an SSE stream for the given payload."""
115+
return endpoint == "streaming_query" or (
116+
endpoint == "responses" and bool(data.get("stream"))
117+
)
118+
119+
120+
def _post_question(
121+
context: Context,
122+
endpoint: str,
123+
headers: dict[str, str] | None = None,
124+
extra_data: dict[str, Any] | None = None,
125+
) -> requests.Response:
126+
"""POST a question to the service REST API endpoint.
127+
128+
Parameters:
129+
context: Behave context with hostname, port, and request body text.
130+
endpoint: API endpoint name (e.g. ``query``, ``streaming_query``).
131+
headers: Optional HTTP headers (e.g. authorization).
132+
extra_data: Optional fields merged into the JSON request body.
133+
134+
Returns:
135+
The HTTP response, with streamed bodies fully consumed when applicable.
136+
"""
130137
base = f"http://{context.hostname}:{context.port}"
131138
path = f"{context.api_prefix}/{endpoint}".replace("//", "/")
132139
url = base + path
133140

134-
# Replace {MODEL} and {PROVIDER} placeholders with actual values
135141
json_str = replace_placeholders(context, context.text or "{}")
136-
137142
data = json.loads(json_str)
138-
use_sse = endpoint == "streaming_query" or (
139-
endpoint == "responses" and bool(data.get("stream"))
140-
)
141-
if use_sse:
143+
if extra_data:
144+
data.update(extra_data)
145+
146+
if _uses_sse(endpoint, data):
142147
resp = request_with_transient_retry(
143148
method="POST",
144149
url=url,
145150
json=data,
146-
headers=context.auth_headers,
151+
headers=headers,
147152
timeout=DEFAULT_LLM_TIMEOUT,
148153
stream=True,
149154
)
150155
# Consume stream so server close after error event does not raise
151156
body = _read_streamed_response(resp)
152157
resp._content = body.encode(resp.encoding or "utf-8")
153-
context.response = resp
154-
else:
155-
context.response = request_with_transient_retry(
156-
method="POST",
157-
url=url,
158-
json=data,
159-
headers=context.auth_headers,
160-
timeout=DEFAULT_LLM_TIMEOUT,
161-
)
158+
return resp
159+
160+
return request_with_transient_retry(
161+
method="POST",
162+
url=url,
163+
json=data,
164+
headers=headers,
165+
timeout=DEFAULT_LLM_TIMEOUT,
166+
)
167+
168+
169+
@step('I use "{endpoint}" to ask question')
170+
def ask_question(context: Context, endpoint: str) -> None:
171+
"""Call the service REST API endpoint with question."""
172+
context.response = _post_question(context, endpoint)
173+
174+
175+
@step('I use "{endpoint}" to ask question with authorization header')
176+
def ask_question_authorized(context: Context, endpoint: str) -> None:
177+
"""Call the service REST API endpoint with question."""
178+
context.response = _post_question(context, endpoint, headers=context.auth_headers)
162179

163180

164181
# Query length chosen to exceed typical model context windows (e.g. 128k tokens)
@@ -188,19 +205,12 @@ def store_conversation_details(context: Context) -> None:
188205
@step('I use "{endpoint}" to ask question with same conversation_id')
189206
def ask_question_in_same_conversation(context: Context, endpoint: str) -> None:
190207
"""Call the service REST API endpoint with question, but use the existing conversation id."""
191-
base = f"http://{context.hostname}:{context.port}"
192-
path = f"{context.api_prefix}/{endpoint}".replace("//", "/")
193-
url = base + path
194-
195-
# Replace {MODEL} and {PROVIDER} placeholders with actual values
196-
json_str = replace_placeholders(context, context.text or "{}")
197-
198-
data = json.loads(json_str)
199-
headers = context.auth_headers if hasattr(context, "auth_headers") else {}
200-
data["conversation_id"] = context.response_data["conversation_id"]
201-
202-
context.response = request_with_transient_retry(
203-
method="POST", url=url, json=data, headers=headers, timeout=DEFAULT_LLM_TIMEOUT
208+
headers = context.auth_headers if hasattr(context, "auth_headers") else None
209+
context.response = _post_question(
210+
context,
211+
endpoint,
212+
headers=headers,
213+
extra_data={"conversation_id": context.response_data["conversation_id"]},
204214
)
205215

206216

@@ -366,12 +376,12 @@ def _parse_streaming_response(response_text: str) -> dict:
366376
full_response = ""
367377
full_response_split = []
368378
finished = False
369-
first_token = True
370379
stream_error = (
371380
None # {"status_code": int, "response": str, "cause": str} if event "error"
372381
)
373382

374383
for line in lines:
384+
print(f"line: {line}")
375385
if line.startswith("data: "):
376386
try:
377387
data = json.loads(line[6:]) # Remove 'data: ' prefix
@@ -380,10 +390,6 @@ def _parse_streaming_response(response_text: str) -> dict:
380390
if event == "start":
381391
conversation_id = data["data"]["conversation_id"]
382392
elif event == "token":
383-
# Skip the first token (shield status message)
384-
if first_token:
385-
first_token = False
386-
continue
387393
full_response_split.append(data["data"]["token"])
388394
elif event == "turn_complete":
389395
full_response = data["data"]["token"]

tests/e2e/test_list.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1+
features/inline_rag.feature
12
features/streaming_query.feature
3+
features/mcp.feature

0 commit comments

Comments
 (0)