Skip to content

Commit ca1aba8

Browse files
committed
Wired agent.run_stream_events in streaming_query
1 parent 736f162 commit ca1aba8

7 files changed

Lines changed: 786 additions & 284 deletions

File tree

src/app/endpoints/streaming_query.py

Lines changed: 11 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import asyncio
66
import datetime
7-
import json
87
from collections.abc import AsyncIterator
98
from typing import Annotated, Any, Optional, cast
109

@@ -78,6 +77,10 @@
7877
from models.common.responses.responses_api_params import ResponsesApiParams
7978
from models.common.turn_summary import ReferencedDocument, TurnSummary
8079
from models.config import Action
80+
from utils.agents.streaming import (
81+
generate_agent_response,
82+
retrieve_agent_response_generator,
83+
)
8184
from utils.conversations import append_turn_items_to_conversation
8285
from utils.endpoints import (
8386
check_configuration_loaded,
@@ -115,6 +118,11 @@
115118
validate_shield_ids_override,
116119
)
117120
from utils.stream_interrupts import get_stream_interrupt_registry
121+
from utils.streaming_sse import (
122+
format_stream_data,
123+
shield_violation_generator,
124+
stream_event,
125+
)
118126
from utils.suid import get_suid, normalize_conversation_id
119127
from utils.token_counter import TokenCounter
120128
from utils.vector_search import build_rag_context
@@ -287,7 +295,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
287295
)
288296
recording.record_llm_call(provider_id, model_id, endpoint_path)
289297

290-
generator, turn_summary = await retrieve_response_generator(
298+
generator, turn_summary = await retrieve_agent_response_generator(
291299
responses_params=responses_params,
292300
context=context,
293301
endpoint_path=endpoint_path,
@@ -306,7 +314,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
306314
)
307315

308316
return StreamingResponse(
309-
generate_response(
317+
generate_agent_response(
310318
generator=generator,
311319
context=context,
312320
responses_params=responses_params,
@@ -923,22 +931,6 @@ def stream_http_error_event(
923931
)
924932

925933

926-
def format_stream_data(d: dict) -> str:
    """
    Format a dictionary as a single Server-Sent Events ``data:`` message.

    The dictionary is serialized to JSON and wrapped in the SSE framing:
    a ``data: `` prefix followed by a blank line that terminates the event.

    Parameters:
    ----------
        d (dict): The data to be formatted as an SSE event.

    Returns:
    -------
        str: The formatted SSE data string.
    """
    data = json.dumps(d)
    # The trailing blank line ("\n\n") marks the end of one SSE message.
    return f"data: {data}\n\n"
940-
941-
942934
def stream_start_event(conversation_id: str, request_id: str) -> str:
943935
"""Format an SSE start event for a streaming response.
944936
@@ -1038,61 +1030,3 @@ def stream_end_event(
10381030
"available_quotas": available_quotas,
10391031
}
10401032
)
1041-
1042-
1043-
def stream_event(data: dict, event_type: str, media_type: str) -> str:
    """Build an item to yield based on media type.

    For the plain-text media type the event is rendered as text: token
    events yield the token itself, tool-call and tool-result events yield a
    short bracketed marker, and every other event type yields an empty
    string. For any other media type the event is wrapped as
    ``{"event": ..., "data": ...}`` and passed to ``format_stream_data``.

    Args:
        data: Dictionary containing the event data
        event_type: Type of event (token, tool call, etc.)
        media_type: The media type for the response format

    Returns:
        SSE-formatted string representing the event
    """
    if media_type == MEDIA_TYPE_TEXT:
        if event_type == LLM_TOKEN_EVENT:
            return data.get("token", "")
        if event_type == LLM_TOOL_CALL_EVENT:
            return f"[Tool Call: {data.get('function_name', 'unknown')}]\n"
        if event_type == LLM_TOOL_RESULT_EVENT:
            return "[Tool Result]\n"
        # Turn-complete events and any unrecognized event type contribute
        # nothing to the plain-text stream.
        return ""

    return format_stream_data(
        {
            "event": event_type,
            "data": data,
        }
    )
1071-
1072-
1073-
async def shield_violation_generator(
    violation_message: str,
    media_type: str = MEDIA_TYPE_TEXT,
) -> AsyncIterator[str]:
    """
    Create a minimal stream for shield violation responses.

    Yields a single token event carrying the violation message, formatted
    by ``stream_event`` for the requested media type. This function creates
    a minimal streaming response without going through the Llama Stack
    response format.

    NOTE(review): earlier documentation stated that start and end events are
    yielded here as well; the body only emits the token event, so confirm
    that callers emit the start/end events themselves.

    Args:
        violation_message: The violation message to display.
        media_type: The media type for the response format.

    Yields:
        str: The formatted token event containing the violation message.
    """
    yield stream_event(
        {
            "id": 0,
            "token": violation_message,
        },
        LLM_TOKEN_EVENT,
        media_type,
    )

0 commit comments

Comments
 (0)