|
4 | 4 |
|
5 | 5 | import asyncio |
6 | 6 | import datetime |
7 | | -import json |
8 | 7 | from collections.abc import AsyncIterator |
9 | 8 | from typing import Annotated, Any, Optional, cast |
10 | 9 |
|
|
64 | 63 | from models.api.requests import QueryRequest |
65 | 64 | from models.api.responses.constants import UNAUTHORIZED_OPENAPI_EXAMPLES_WITH_MCP_OAUTH |
66 | 65 | from models.api.responses.error import ( |
67 | | - AbstractErrorResponse, |
68 | 66 | ForbiddenResponse, |
69 | 67 | InternalServerErrorResponse, |
70 | 68 | NotFoundResponse, |
|
78 | 76 | from models.common.responses.contexts import ResponseGeneratorContext |
79 | 77 | from models.common.responses.responses_api_params import ResponsesApiParams |
80 | 78 | from models.common.responses.types import ResponseInput |
81 | | -from models.common.turn_summary import ReferencedDocument, TurnSummary |
| 79 | +from models.common.turn_summary import TurnSummary |
82 | 80 | from models.config import Action |
83 | 81 | from utils.conversation_compaction import ( |
84 | 82 | CompactionResult, |
|
125 | 123 | validate_shield_ids_override, |
126 | 124 | ) |
127 | 125 | from utils.stream_interrupts import get_stream_interrupt_registry |
| 126 | +from utils.streaming_sse import ( |
| 127 | + http_exception_stream_event, |
| 128 | + shield_violation_generator, |
| 129 | + stream_compaction_event, |
| 130 | + stream_end_event, |
| 131 | + stream_event, |
| 132 | + stream_http_error_event, |
| 133 | + stream_interrupted_event, |
| 134 | + stream_start_event, |
| 135 | +) |
128 | 136 | from utils.suid import get_suid, normalize_conversation_id |
129 | | -from utils.token_counter import TokenCounter |
130 | 137 | from utils.vector_search import build_rag_context |
131 | 138 |
|
132 | 139 | logger = get_logger(__name__) |
@@ -620,21 +627,6 @@ async def _on_interrupt() -> None: |
620 | 627 | return guard |
621 | 628 |
|
622 | 629 |
|
623 | | -def _http_exception_stream_event(exc: HTTPException) -> str: |
624 | | - """Render a FastAPI HTTPException as an SSE error event. |
625 | | -
|
626 | | - Used by the compaction-aware streaming path, where the response is created |
627 | | - inside the stream and so create-time errors must be surfaced as SSE events |
628 | | - rather than as an HTTP status response. |
629 | | - """ |
630 | | - detail = ( |
631 | | - exc.detail if isinstance(exc.detail, dict) else {"response": str(exc.detail)} |
632 | | - ) |
633 | | - return format_stream_data( |
634 | | - {"event": "error", "data": {"status_code": exc.status_code, **detail}} |
635 | | - ) |
636 | | - |
637 | | - |
638 | 630 | async def generate_response_with_compaction( |
639 | 631 | context: ResponseGeneratorContext, |
640 | 632 | responses_params: ResponsesApiParams, |
@@ -689,7 +681,7 @@ async def generate_response_with_compaction( |
689 | 681 | endpoint_path=endpoint_path, |
690 | 682 | ) |
691 | 683 | except HTTPException as e: |
692 | | - yield _http_exception_stream_event(e) |
| 684 | + yield http_exception_stream_event(e) |
693 | 685 | return |
694 | 686 | except RuntimeError as e: # library mode wraps 413 into runtime error |
695 | 687 | error_response = ( |
@@ -1102,234 +1094,3 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat |
1102 | 1094 | rag_id_mapping=context.rag_id_mapping, |
1103 | 1095 | ) |
1104 | 1096 | turn_summary.rag_chunks = context.inline_rag_context.rag_chunks + tool_rag_chunks |
1105 | | - |
1106 | | - |
1107 | | -def stream_http_error_event( |
1108 | | - error: AbstractErrorResponse, media_type: Optional[str] = MEDIA_TYPE_JSON |
1109 | | -) -> str: |
1110 | | - """ |
1111 | | - Create an SSE-formatted error response for generic LLM or API errors. |
1112 | | -
|
1113 | | - Args: |
1114 | | - error: An AbstractErrorResponse instance representing the error. |
1115 | | - media_type: The media type for the response format. Defaults to MEDIA_TYPE_JSON if None. |
1116 | | -
|
1117 | | - Returns: |
1118 | | - str: A Server-Sent Events (SSE) formatted error message containing |
1119 | | - the serialized error details. |
1120 | | - """ |
1121 | | - logger.error("Error while obtaining answer for user question") |
1122 | | - media_type = media_type or MEDIA_TYPE_JSON |
1123 | | - if media_type == MEDIA_TYPE_TEXT: |
1124 | | - return f"Status: {error.status_code} - {error.detail.response} - {error.detail.cause}" |
1125 | | - |
1126 | | - return format_stream_data( |
1127 | | - { |
1128 | | - "event": "error", |
1129 | | - "data": { |
1130 | | - "status_code": error.status_code, |
1131 | | - "response": error.detail.response, |
1132 | | - "cause": error.detail.cause, |
1133 | | - }, |
1134 | | - } |
1135 | | - ) |
1136 | | - |
1137 | | - |
1138 | | -def format_stream_data(d: dict) -> str: |
1139 | | - """ |
1140 | | - Format a dictionary as a Server-Sent Events (SSE) data message. |
1141 | | -
|
1142 | | - Parameters: |
1143 | | - ---------- |
1144 | | - d (dict): The data to be formatted as an SSE event. |
1145 | | -
|
1146 | | - Returns: |
1147 | | - ------- |
1148 | | - str: The formatted SSE data string. |
1149 | | - """ |
1150 | | - data = json.dumps(d) |
1151 | | - return f"data: {data}\n\n" |
1152 | | - |
1153 | | - |
1154 | | -def stream_start_event(conversation_id: str, request_id: str) -> str: |
1155 | | - """Format an SSE start event for a streaming response. |
1156 | | -
|
1157 | | - The payload contains both the conversation ID and the request ID |
1158 | | - so the client can correlate the stream with a conversation and |
1159 | | - use the request ID to issue an interrupt if needed. |
1160 | | -
|
1161 | | - Parameters: |
1162 | | - ---------- |
1163 | | - conversation_id (str): Unique identifier for the conversation. |
1164 | | - request_id (str): Unique SUID for this streaming request, |
1165 | | - returned to the client for interrupt support. |
1166 | | -
|
1167 | | - Returns: |
1168 | | - ------- |
1169 | | - str: SSE-formatted string representing the start event. |
1170 | | - """ |
1171 | | - return format_stream_data( |
1172 | | - { |
1173 | | - "event": "start", |
1174 | | - "data": { |
1175 | | - "conversation_id": conversation_id, |
1176 | | - "request_id": request_id, |
1177 | | - }, |
1178 | | - } |
1179 | | - ) |
1180 | | - |
1181 | | - |
1182 | | -def stream_compaction_event(conversation_id: str) -> str: |
1183 | | - """Format an SSE event signalling that conversation compaction has started. |
1184 | | -
|
1185 | | - Emitted before the summarization LLM call (R12) so the client can show a |
1186 | | - progress indicator while older turns are being summarized. |
1187 | | -
|
1188 | | - Parameters: |
1189 | | - ---------- |
1190 | | - conversation_id: The conversation being compacted. |
1191 | | -
|
1192 | | - Returns: |
1193 | | - ------- |
1194 | | - str: SSE-formatted string representing the compaction event. |
1195 | | - """ |
1196 | | - return format_stream_data( |
1197 | | - { |
1198 | | - "event": "compaction", |
1199 | | - "data": { |
1200 | | - "status": "started", |
1201 | | - "conversation_id": conversation_id, |
1202 | | - }, |
1203 | | - } |
1204 | | - ) |
1205 | | - |
1206 | | - |
1207 | | -def stream_interrupted_event(request_id: str) -> str: |
1208 | | - """Format an SSE event indicating the stream was interrupted. |
1209 | | -
|
1210 | | - Emitted to the client just before the generator closes so the |
1211 | | - frontend can distinguish an intentional user-initiated interruption |
1212 | | - from an unexpected connection drop. |
1213 | | -
|
1214 | | - Parameters: |
1215 | | - ---------- |
1216 | | - request_id (str): Unique identifier for the interrupted request. |
1217 | | -
|
1218 | | - Returns: |
1219 | | - ------- |
1220 | | - str: SSE-formatted string representing the interrupted event. |
1221 | | - """ |
1222 | | - return format_stream_data( |
1223 | | - { |
1224 | | - "event": "interrupted", |
1225 | | - "data": { |
1226 | | - "request_id": request_id, |
1227 | | - }, |
1228 | | - } |
1229 | | - ) |
1230 | | - |
1231 | | - |
1232 | | -def stream_end_event( |
1233 | | - token_usage: TokenCounter, |
1234 | | - available_quotas: dict[str, int], |
1235 | | - referenced_documents: list[ReferencedDocument], |
1236 | | - media_type: str = MEDIA_TYPE_JSON, |
1237 | | -) -> str: |
1238 | | - """ |
1239 | | - Return the event that marks the end of the data stream. |
1240 | | -
|
1241 | | - Format and return the end event for a streaming response, |
1242 | | - including referenced document metadata and token usage information. |
1243 | | -
|
1244 | | - Parameters: |
1245 | | - ---------- |
1246 | | - token_usage (TokenCounter): Token usage information. |
1247 | | - available_quotas (dict[str, int]): Available quotas for the user. |
1248 | | - referenced_documents (list[ReferencedDocument]): List of referenced documents. |
1249 | | - media_type (str): The media type for the response format. |
1250 | | -
|
1251 | | - Returns: |
1252 | | - ------- |
1253 | | - str: A Server-Sent Events (SSE) formatted string |
1254 | | - representing the end of the data stream. |
1255 | | - """ |
1256 | | - if media_type == MEDIA_TYPE_TEXT: |
1257 | | - ref_docs_string = "\n".join( |
1258 | | - f"{doc.doc_title}: {doc.doc_url}" |
1259 | | - for doc in referenced_documents |
1260 | | - if doc.doc_url and doc.doc_title |
1261 | | - ) |
1262 | | - return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else "" |
1263 | | - |
1264 | | - referenced_docs_dict = [doc.model_dump(mode="json") for doc in referenced_documents] |
1265 | | - |
1266 | | - return format_stream_data( |
1267 | | - { |
1268 | | - "event": "end", |
1269 | | - "data": { |
1270 | | - "referenced_documents": referenced_docs_dict, |
1271 | | - "truncated": None, |
1272 | | - "input_tokens": token_usage.input_tokens, |
1273 | | - "output_tokens": token_usage.output_tokens, |
1274 | | - }, |
1275 | | - "available_quotas": available_quotas, |
1276 | | - } |
1277 | | - ) |
1278 | | - |
1279 | | - |
1280 | | -def stream_event(data: dict, event_type: str, media_type: str) -> str: |
1281 | | - """Build an item to yield based on media type. |
1282 | | -
|
1283 | | - Args: |
1284 | | - data: Dictionary containing the event data |
1285 | | - event_type: Type of event (token, tool call, etc.) |
1286 | | - media_type: The media type for the response format |
1287 | | -
|
1288 | | - Returns: |
1289 | | - SSE-formatted string representing the event |
1290 | | - """ |
1291 | | - if media_type == MEDIA_TYPE_TEXT: |
1292 | | - if event_type == LLM_TOKEN_EVENT: |
1293 | | - return data.get("token", "") |
1294 | | - if event_type == LLM_TOOL_CALL_EVENT: |
1295 | | - return f"[Tool Call: {data.get('function_name', 'unknown')}]\n" |
1296 | | - if event_type == LLM_TOOL_RESULT_EVENT: |
1297 | | - return "[Tool Result]\n" |
1298 | | - if event_type == LLM_TURN_COMPLETE_EVENT: |
1299 | | - return "" |
1300 | | - return "" |
1301 | | - |
1302 | | - return format_stream_data( |
1303 | | - { |
1304 | | - "event": event_type, |
1305 | | - "data": data, |
1306 | | - } |
1307 | | - ) |
1308 | | - |
1309 | | - |
1310 | | -async def shield_violation_generator( |
1311 | | - violation_message: str, |
1312 | | - media_type: str = MEDIA_TYPE_TEXT, |
1313 | | -) -> AsyncIterator[str]: |
1314 | | - """ |
1315 | | - Create an SSE stream for shield violation responses. |
1316 | | -
|
1317 | | - Yields start, token, and end events immediately for shield violations. |
1318 | | - This function creates a minimal streaming response without going through |
1319 | | - the Llama Stack response format. |
1320 | | -
|
1321 | | - Args: |
1322 | | - violation_message: The violation message to display. |
1323 | | - media_type: The media type for the response format. |
1324 | | -
|
1325 | | - Yields: |
1326 | | - str: SSE-formatted strings for start, token, and end events. |
1327 | | - """ |
1328 | | - yield stream_event( |
1329 | | - { |
1330 | | - "id": 0, |
1331 | | - "token": violation_message, |
1332 | | - }, |
1333 | | - LLM_TOKEN_EVENT, |
1334 | | - media_type, |
1335 | | - ) |
0 commit comments