Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 12 additions & 251 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import asyncio
import datetime
import json
from collections.abc import AsyncIterator
from typing import Annotated, Any, Optional, cast

Expand Down Expand Up @@ -64,7 +63,6 @@
from models.api.requests import QueryRequest
from models.api.responses.constants import UNAUTHORIZED_OPENAPI_EXAMPLES_WITH_MCP_OAUTH
from models.api.responses.error import (
AbstractErrorResponse,
ForbiddenResponse,
InternalServerErrorResponse,
NotFoundResponse,
Expand All @@ -78,7 +76,7 @@
from models.common.responses.contexts import ResponseGeneratorContext
from models.common.responses.responses_api_params import ResponsesApiParams
from models.common.responses.types import ResponseInput
from models.common.turn_summary import ReferencedDocument, TurnSummary
from models.common.turn_summary import TurnSummary
from models.config import Action
from utils.conversation_compaction import (
CompactionResult,
Expand Down Expand Up @@ -125,8 +123,17 @@
validate_shield_ids_override,
)
from utils.stream_interrupts import get_stream_interrupt_registry
from utils.streaming_sse import (
http_exception_stream_event,
shield_violation_generator,
stream_compaction_event,
stream_end_event,
stream_event,
stream_http_error_event,
stream_interrupted_event,
stream_start_event,
)
from utils.suid import get_suid, normalize_conversation_id
from utils.token_counter import TokenCounter
from utils.vector_search import build_rag_context

logger = get_logger(__name__)
Expand Down Expand Up @@ -620,21 +627,6 @@ async def _on_interrupt() -> None:
return guard


def _http_exception_stream_event(exc: HTTPException) -> str:
    """Convert a FastAPI HTTPException into an SSE ``error`` event.

    The compaction-aware streaming path creates the response inside the
    stream, so a failure at create time has to reach the client as an SSE
    event instead of an HTTP status response.

    Parameters:
    ----------
        exc (HTTPException): The exception to render.

    Returns:
    -------
        str: SSE-formatted error event containing the status code merged
        with the exception detail.
    """
    if isinstance(exc.detail, dict):
        detail = exc.detail
    else:
        # Non-dict details are stringified and exposed under "response".
        detail = {"response": str(exc.detail)}

    # Detail keys are merged after "status_code", so a "status_code" key
    # inside the detail takes precedence over exc.status_code.
    payload = {"status_code": exc.status_code, **detail}
    return format_stream_data({"event": "error", "data": payload})


async def generate_response_with_compaction(
context: ResponseGeneratorContext,
responses_params: ResponsesApiParams,
Expand Down Expand Up @@ -689,7 +681,7 @@ async def generate_response_with_compaction(
endpoint_path=endpoint_path,
)
except HTTPException as e:
yield _http_exception_stream_event(e)
yield http_exception_stream_event(e)
return
Comment thread
asimurka marked this conversation as resolved.
except RuntimeError as e: # library mode wraps 413 into runtime error
error_response = (
Expand Down Expand Up @@ -1102,234 +1094,3 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
rag_id_mapping=context.rag_id_mapping,
)
turn_summary.rag_chunks = context.inline_rag_context.rag_chunks + tool_rag_chunks


def stream_http_error_event(
    error: AbstractErrorResponse, media_type: Optional[str] = MEDIA_TYPE_JSON
) -> str:
    """
    Render an error response as a chunk of the outgoing stream.

    An error is logged on every call. For the plain-text media type the
    result is a single human-readable status line; for any other media type
    it is an SSE ``error`` event carrying the status code, response and cause.

    Args:
        error: An AbstractErrorResponse instance representing the error.
        media_type: The media type for the response format. A falsy value
            (such as None) is treated as MEDIA_TYPE_JSON.

    Returns:
        str: The plain-text status line, or a Server-Sent Events (SSE)
        formatted error message containing the serialized error details.
    """
    logger.error("Error while obtaining answer for user question")

    status_code = error.status_code
    detail = error.detail

    if (media_type or MEDIA_TYPE_JSON) == MEDIA_TYPE_TEXT:
        return f"Status: {status_code} - {detail.response} - {detail.cause}"

    payload = {
        "status_code": status_code,
        "response": detail.response,
        "cause": detail.cause,
    }
    return format_stream_data({"event": "error", "data": payload})


def format_stream_data(d: dict) -> str:
    """
    Serialize a dictionary into a single SSE ``data:`` frame.

    Parameters:
    ----------
        d (dict): The data to be formatted as an SSE event.

    Returns:
    -------
        str: The JSON-encoded payload prefixed with ``data: `` and
        terminated by a blank line.
    """
    return "data: " + json.dumps(d) + "\n\n"


def stream_start_event(conversation_id: str, request_id: str) -> str:
    """Build the SSE ``start`` event that opens a streaming response.

    Both identifiers are included in the payload: the conversation ID lets
    the client tie the stream to a conversation, and the request ID is what
    the client passes back if it wants to interrupt the stream.

    Parameters:
    ----------
        conversation_id (str): Unique identifier for the conversation.
        request_id (str): Unique SUID for this streaming request,
            returned to the client for interrupt support.

    Returns:
    -------
        str: SSE-formatted string representing the start event.
    """
    identifiers = {
        "conversation_id": conversation_id,
        "request_id": request_id,
    }
    return format_stream_data({"event": "start", "data": identifiers})


def stream_compaction_event(conversation_id: str) -> str:
    """Build the SSE ``compaction`` event announcing that compaction started.

    Meant to be sent ahead of the summarization LLM call (R12), giving the
    client a chance to display a progress indicator while older turns are
    summarized. The status in the payload is always ``"started"``.

    Parameters:
    ----------
        conversation_id: The conversation being compacted.

    Returns:
    -------
        str: SSE-formatted string representing the compaction event.
    """
    progress = {
        "status": "started",
        "conversation_id": conversation_id,
    }
    return format_stream_data({"event": "compaction", "data": progress})


def stream_interrupted_event(request_id: str) -> str:
    """Build the SSE ``interrupted`` event for a stopped stream.

    Intended to be sent right before the generator closes, so the frontend
    can tell a deliberate, user-initiated interruption apart from an
    unexpected connection drop.

    Parameters:
    ----------
        request_id (str): Unique identifier for the interrupted request.

    Returns:
    -------
        str: SSE-formatted string representing the interrupted event.
    """
    interrupted = {"request_id": request_id}
    return format_stream_data({"event": "interrupted", "data": interrupted})


def stream_end_event(
    token_usage: TokenCounter,
    available_quotas: dict[str, int],
    referenced_documents: list[ReferencedDocument],
    media_type: str = MEDIA_TYPE_JSON,
) -> str:
    """
    Build the final chunk of a streaming response.

    For the plain-text media type the result is a ``---`` separated footer
    listing ``title: url`` for every referenced document that has both a
    title and a URL, or an empty string when there is none. For any other
    media type the result is an SSE ``end`` event with the serialized
    referenced documents and token counts under ``data`` and the quotas as a
    sibling top-level ``available_quotas`` key.

    Parameters:
    ----------
        token_usage (TokenCounter): Token usage information.
        available_quotas (dict[str, int]): Available quotas for the user.
        referenced_documents (list[ReferencedDocument]): List of referenced documents.
        media_type (str): The media type for the response format.

    Returns:
    -------
        str: The plain-text footer, or a Server-Sent Events (SSE) formatted
        string representing the end of the data stream.
    """
    if media_type == MEDIA_TYPE_TEXT:
        citations = [
            f"{doc.doc_title}: {doc.doc_url}"
            for doc in referenced_documents
            if doc.doc_url and doc.doc_title
        ]
        listing = "\n".join(citations)
        if not listing:
            return ""
        return "\n\n---\n\n" + listing

    end_data = {
        "referenced_documents": [
            doc.model_dump(mode="json") for doc in referenced_documents
        ],
        "truncated": None,
        "input_tokens": token_usage.input_tokens,
        "output_tokens": token_usage.output_tokens,
    }
    return format_stream_data(
        {
            "event": "end",
            "data": end_data,
            "available_quotas": available_quotas,
        }
    )


def stream_event(data: dict, event_type: str, media_type: str) -> str:
    """Format one streamed item according to the requested media type.

    Any media type other than plain text produces an SSE frame wrapping the
    event type and data. Plain text produces the raw token for token events,
    a short bracketed marker for tool calls and tool results, and an empty
    string for turn-complete and all other event types.

    Args:
        data: Dictionary containing the event data
        event_type: Type of event (token, tool call, etc.)
        media_type: The media type for the response format

    Returns:
        SSE-formatted string representing the event, or its plain-text form
    """
    if media_type != MEDIA_TYPE_TEXT:
        return format_stream_data({"event": event_type, "data": data})

    if event_type == LLM_TOKEN_EVENT:
        return data.get("token", "")

    if event_type == LLM_TOOL_CALL_EVENT:
        function_name = data.get("function_name", "unknown")
        return f"[Tool Call: {function_name}]\n"

    if event_type == LLM_TOOL_RESULT_EVENT:
        return "[Tool Result]\n"

    if event_type == LLM_TURN_COMPLETE_EVENT:
        # Turn completion carries no text of its own.
        return ""

    # Unrecognized event types are silently dropped in plain-text mode.
    return ""


async def shield_violation_generator(
    violation_message: str,
    media_type: str = MEDIA_TYPE_TEXT,
) -> AsyncIterator[str]:
    """
    Stream a shield violation message as a single token event.

    The violation message is wrapped in one token event with id 0 and
    formatted via ``stream_event`` for the requested media type. This
    generator itself emits nothing else.

    Args:
        violation_message: The violation message to display.
        media_type: The media type for the response format.

    Yields:
        str: The formatted token event containing the violation message.
    """
    token_payload = {"id": 0, "token": violation_message}
    yield stream_event(token_payload, LLM_TOKEN_EVENT, media_type)
Loading
Loading