diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 59586c0f9..dbb345e73 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -4,7 +4,6 @@ import asyncio import datetime -import json from collections.abc import AsyncIterator from typing import Annotated, Any, Optional, cast @@ -64,7 +63,6 @@ from models.api.requests import QueryRequest from models.api.responses.constants import UNAUTHORIZED_OPENAPI_EXAMPLES_WITH_MCP_OAUTH from models.api.responses.error import ( - AbstractErrorResponse, ForbiddenResponse, InternalServerErrorResponse, NotFoundResponse, @@ -78,7 +76,7 @@ from models.common.responses.contexts import ResponseGeneratorContext from models.common.responses.responses_api_params import ResponsesApiParams from models.common.responses.types import ResponseInput -from models.common.turn_summary import ReferencedDocument, TurnSummary +from models.common.turn_summary import TurnSummary from models.config import Action from utils.conversation_compaction import ( CompactionResult, @@ -125,8 +123,17 @@ validate_shield_ids_override, ) from utils.stream_interrupts import get_stream_interrupt_registry +from utils.streaming_sse import ( + http_exception_stream_event, + shield_violation_generator, + stream_compaction_event, + stream_end_event, + stream_event, + stream_http_error_event, + stream_interrupted_event, + stream_start_event, +) from utils.suid import get_suid, normalize_conversation_id -from utils.token_counter import TokenCounter from utils.vector_search import build_rag_context logger = get_logger(__name__) @@ -620,21 +627,6 @@ async def _on_interrupt() -> None: return guard -def _http_exception_stream_event(exc: HTTPException) -> str: - """Render a FastAPI HTTPException as an SSE error event. - - Used by the compaction-aware streaming path, where the response is created - inside the stream and so create-time errors must be surfaced as SSE events - rather than as an HTTP status response. - """ - detail = ( - exc.detail if isinstance(exc.detail, dict) else {"response": str(exc.detail)} - ) - return format_stream_data( - {"event": "error", "data": {"status_code": exc.status_code, **detail}} - ) - - async def generate_response_with_compaction( context: ResponseGeneratorContext, responses_params: ResponsesApiParams, @@ -689,7 +681,7 @@ async def generate_response_with_compaction( endpoint_path=endpoint_path, ) except HTTPException as e: - yield _http_exception_stream_event(e) + yield http_exception_stream_event(e) return except RuntimeError as e: # library mode wraps 413 into runtime error error_response = ( @@ -1102,234 +1094,3 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat rag_id_mapping=context.rag_id_mapping, ) turn_summary.rag_chunks = context.inline_rag_context.rag_chunks + tool_rag_chunks - - -def stream_http_error_event( - error: AbstractErrorResponse, media_type: Optional[str] = MEDIA_TYPE_JSON -) -> str: - """ - Create an SSE-formatted error response for generic LLM or API errors. - - Args: - error: An AbstractErrorResponse instance representing the error. - media_type: The media type for the response format. Defaults to MEDIA_TYPE_JSON if None. - - Returns: - str: A Server-Sent Events (SSE) formatted error message containing - the serialized error details. - """ - logger.error("Error while obtaining answer for user question") - media_type = media_type or MEDIA_TYPE_JSON - if media_type == MEDIA_TYPE_TEXT: - return f"Status: {error.status_code} - {error.detail.response} - {error.detail.cause}" - - return format_stream_data( - { - "event": "error", - "data": { - "status_code": error.status_code, - "response": error.detail.response, - "cause": error.detail.cause, - }, - } - ) - - -def format_stream_data(d: dict) -> str: - """ - Create a response generator function for Responses API streaming. - - Parameters: - ---------- - d (dict): The data to be formatted as an SSE event. - - Returns: - ------- - str: The formatted SSE data string. - """ - data = json.dumps(d) - return f"data: {data}\n\n" - - -def stream_start_event(conversation_id: str, request_id: str) -> str: - """Format an SSE start event for a streaming response. - - The payload contains both the conversation ID and the request ID - so the client can correlate the stream with a conversation and - use the request ID to issue an interrupt if needed. - - Parameters: - ---------- - conversation_id (str): Unique identifier for the conversation. - request_id (str): Unique SUID for this streaming request, - returned to the client for interrupt support. - - Returns: - ------- - str: SSE-formatted string representing the start event. - """ - return format_stream_data( - { - "event": "start", - "data": { - "conversation_id": conversation_id, - "request_id": request_id, - }, - } - ) - - -def stream_compaction_event(conversation_id: str) -> str: - """Format an SSE event signalling that conversation compaction has started. - - Emitted before the summarization LLM call (R12) so the client can show a - progress indicator while older turns are being summarized. - - Parameters: - ---------- - conversation_id: The conversation being compacted. - - Returns: - ------- - str: SSE-formatted string representing the compaction event. - """ - return format_stream_data( - { - "event": "compaction", - "data": { - "status": "started", - "conversation_id": conversation_id, - }, - } - ) - - -def stream_interrupted_event(request_id: str) -> str: - """Format an SSE event indicating the stream was interrupted. - - Emitted to the client just before the generator closes so the - frontend can distinguish an intentional user-initiated interruption - from an unexpected connection drop. - - Parameters: - ---------- - request_id (str): Unique identifier for the interrupted request. - - Returns: - ------- - str: SSE-formatted string representing the interrupted event. - """ - return format_stream_data( - { - "event": "interrupted", - "data": { - "request_id": request_id, - }, - } - ) - - -def stream_end_event( - token_usage: TokenCounter, - available_quotas: dict[str, int], - referenced_documents: list[ReferencedDocument], - media_type: str = MEDIA_TYPE_JSON, -) -> str: - """ - Yield the end of the data stream. - - Format and return the end event for a streaming response, - including referenced document metadata and token usage information. - - Parameters: - ---------- - token_usage (TokenCounter): Token usage information. - available_quotas (dict[str, int]): Available quotas for the user. - referenced_documents (list[ReferencedDocument]): List of referenced documents. - media_type (str): The media type for the response format. - - Returns: - ------- - str: A Server-Sent Events (SSE) formatted string - representing the end of the data stream. - """ - if media_type == MEDIA_TYPE_TEXT: - ref_docs_string = "\n".join( - f"{doc.doc_title}: {doc.doc_url}" - for doc in referenced_documents - if doc.doc_url and doc.doc_title - ) - return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else "" - - referenced_docs_dict = [doc.model_dump(mode="json") for doc in referenced_documents] - - return format_stream_data( - { - "event": "end", - "data": { - "referenced_documents": referenced_docs_dict, - "truncated": None, - "input_tokens": token_usage.input_tokens, - "output_tokens": token_usage.output_tokens, - }, - "available_quotas": available_quotas, - } - ) - - -def stream_event(data: dict, event_type: str, media_type: str) -> str: - """Build an item to yield based on media type. - - Args: - data: Dictionary containing the event data - event_type: Type of event (token, tool call, etc.) - media_type: The media type for the response format - - Returns: - SSE-formatted string representing the event - """ - if media_type == MEDIA_TYPE_TEXT: - if event_type == LLM_TOKEN_EVENT: - return data.get("token", "") - if event_type == LLM_TOOL_CALL_EVENT: - return f"[Tool Call: {data.get('function_name', 'unknown')}]\n" - if event_type == LLM_TOOL_RESULT_EVENT: - return "[Tool Result]\n" - if event_type == LLM_TURN_COMPLETE_EVENT: - return "" - return "" - - return format_stream_data( - { - "event": event_type, - "data": data, - } - ) - - -async def shield_violation_generator( - violation_message: str, - media_type: str = MEDIA_TYPE_TEXT, -) -> AsyncIterator[str]: - """ - Create an SSE stream for shield violation responses. - - Yields start, token, and end events immediately for shield violations. - This function creates a minimal streaming response without going through - the Llama Stack response format. - - Args: - violation_message: The violation message to display. - media_type: The media type for the response format. - - Yields: - str: SSE-formatted strings for start, token, and end events. - """ - yield stream_event( - { - "id": 0, - "token": violation_message, - }, - LLM_TOKEN_EVENT, - media_type, - ) diff --git a/src/utils/streaming_sse.py b/src/utils/streaming_sse.py new file mode 100644 index 000000000..27eca7792 --- /dev/null +++ b/src/utils/streaming_sse.py @@ -0,0 +1,259 @@ +"""SSE formatting helpers for streaming query responses.""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from typing import Optional + +from fastapi import HTTPException + +from constants import ( + LLM_TOKEN_EVENT, + LLM_TOOL_CALL_EVENT, + LLM_TOOL_RESULT_EVENT, + LLM_TURN_COMPLETE_EVENT, + MEDIA_TYPE_JSON, + MEDIA_TYPE_TEXT, +) +from log import get_logger +from models.api.responses.error import AbstractErrorResponse +from models.common.turn_summary import ReferencedDocument +from utils.token_counter import TokenCounter + +logger = get_logger(__name__) + + +def stream_http_error_event( + error: AbstractErrorResponse, media_type: Optional[str] = MEDIA_TYPE_JSON +) -> str: + """Create an SSE-formatted error response for generic LLM or API errors. + + Args: + error: An AbstractErrorResponse instance representing the error. + media_type: The media type for the response format. Defaults to MEDIA_TYPE_JSON. + + Returns: + A Server-Sent Events (SSE) formatted error message containing + the serialized error details. + """ + logger.error("Error while obtaining answer for user question") + media_type = media_type or MEDIA_TYPE_JSON + if media_type == MEDIA_TYPE_TEXT: + return f"Status: {error.status_code} - {error.detail.response} - {error.detail.cause}" + + return format_stream_data( + { + "event": "error", + "data": { + "status_code": error.status_code, + "response": error.detail.response, + "cause": error.detail.cause, + }, + } + ) + + +def format_stream_data(d: dict) -> str: + """Format a dictionary as an SSE data event string. + + Args: + d: The data to be formatted as an SSE event. + + Returns: + The formatted SSE data string. + """ + data = json.dumps(d) + return f"data: {data}\n\n" + + +def stream_start_event(conversation_id: str, request_id: str) -> str: + """Format an SSE start event for a streaming response. + + The payload contains both the conversation ID and the request ID + so the client can correlate the stream with a conversation and + use the request ID to issue an interrupt if needed. + + Args: + conversation_id: Unique identifier for the conversation. + request_id: Unique SUID for this streaming request, + returned to the client for interrupt support. + + Returns: + SSE-formatted string representing the start event. + """ + return format_stream_data( + { + "event": "start", + "data": { + "conversation_id": conversation_id, + "request_id": request_id, + }, + } + ) + + +def stream_compaction_event(conversation_id: str) -> str: + """Format an SSE event signalling that conversation compaction has started. + + Emitted before the summarization LLM call (R12) so the client can show a + progress indicator while older turns are being summarized. + + Args: + conversation_id: The conversation being compacted. + + Returns: + SSE-formatted string representing the compaction event. + """ + return format_stream_data( + { + "event": "compaction", + "data": { + "status": "started", + "conversation_id": conversation_id, + }, + } + ) + + +def stream_interrupted_event(request_id: str) -> str: + """Format an SSE event indicating the stream was interrupted. + + Emitted to the client just before the generator closes so the + frontend can distinguish an intentional user-initiated interruption + from an unexpected connection drop. + + Args: + request_id: Unique identifier for the interrupted request. + + Returns: + SSE-formatted string representing the interrupted event. + """ + return format_stream_data( + { + "event": "interrupted", + "data": { + "request_id": request_id, + }, + } + ) + + +def stream_end_event( + token_usage: TokenCounter, + available_quotas: dict[str, int], + referenced_documents: list[ReferencedDocument], + media_type: str = MEDIA_TYPE_JSON, +) -> str: + """Format the end event for a streaming response. + + Includes referenced document metadata and token usage information. + + Args: + token_usage: Token usage information. + available_quotas: Available quotas for the user. + referenced_documents: List of referenced documents. + media_type: The media type for the response format. + + Returns: + A Server-Sent Events (SSE) formatted string representing the end event. + """ + if media_type == MEDIA_TYPE_TEXT: + ref_docs_string = "\n".join( + f"{doc.doc_title}: {doc.doc_url}" + for doc in referenced_documents + if doc.doc_url and doc.doc_title + ) + return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else "" + + referenced_docs_dict = [doc.model_dump(mode="json") for doc in referenced_documents] + + return format_stream_data( + { + "event": "end", + "data": { + "referenced_documents": referenced_docs_dict, + "truncated": None, + "input_tokens": token_usage.input_tokens, + "output_tokens": token_usage.output_tokens, + }, + "available_quotas": available_quotas, + } + ) + + +def stream_event(data: dict, event_type: str, media_type: str) -> str: + """Build an SSE event string based on media type. + + Args: + data: Dictionary containing the event data. + event_type: Type of event (token, tool call, etc.). + media_type: The media type for the response format. + + Returns: + SSE-formatted string representing the event. + """ + if media_type == MEDIA_TYPE_TEXT: + if event_type == LLM_TOKEN_EVENT: + return data.get("token", "") + if event_type == LLM_TOOL_CALL_EVENT: + return f"[Tool Call: {data.get('function_name', 'unknown')}]\n" + if event_type == LLM_TOOL_RESULT_EVENT: + return "[Tool Result]\n" + if event_type == LLM_TURN_COMPLETE_EVENT: + return "" + return "" + + return format_stream_data( + { + "event": event_type, + "data": data, + } + ) + + +def http_exception_stream_event(exc: HTTPException) -> str: + """Render a FastAPI HTTPException as an SSE error event. + + Used by the compaction-aware streaming path, where the response is created + inside the stream and so create-time errors must be surfaced as SSE events + rather than as an HTTP status response. + + Args: + exc: HTTP exception raised during in-stream response creation. + + Returns: + SSE-formatted error event string. + """ + detail = ( + exc.detail if isinstance(exc.detail, dict) else {"response": str(exc.detail)} + ) + return format_stream_data( + {"event": "error", "data": {"status_code": exc.status_code, **detail}} + ) + + +async def shield_violation_generator( + violation_message: str, + media_type: str = MEDIA_TYPE_TEXT, +) -> AsyncIterator[str]: + """Create an SSE token stream for shield violation responses. + + Yields a single token event for shield violations. Callers should wrap + this generator to emit start/end events and persist the blocked turn. + + Args: + violation_message: The violation message to display. + media_type: The media type for the response format. + + Yields: + SSE-formatted token event string. + """ + yield stream_event( + { + "id": 0, + "token": violation_message, + }, + LLM_TOKEN_EVENT, + media_type, + ) diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index b147822b9..68e34de41 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -3,7 +3,6 @@ # pylint: disable=too-many-lines,too-many-function-args import asyncio -import json from collections.abc import AsyncIterator from typing import Any @@ -42,7 +41,6 @@ OpenAIResponseOutputMessageMCPCall as MCPCall, ) from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient -from pydantic import AnyUrl from pytest_mock import MockerFixture from app.endpoints.streaming_query import ( @@ -50,18 +48,10 @@ generate_response, response_generator, retrieve_response_generator, - shield_violation_generator, - stream_end_event, - stream_event, - stream_http_error_event, - stream_start_event, streaming_query_endpoint_handler, ) from configuration import AppConfig from constants import ( - LLM_TOKEN_EVENT, - LLM_TOOL_CALL_EVENT, - LLM_TOOL_RESULT_EVENT, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT, ) @@ -120,165 +110,7 @@ def setup_configuration_fixture() -> AppConfig: return cfg -# ============================================================================ -# OLS Compatibility Tests -# ============================================================================ - - -class TestOLSStreamEventFormatting: - """Test the stream_event function for both media types (OLS compatibility).""" - - def test_stream_event_json_token(self) -> None: - """Test token event formatting for JSON media type.""" - data = {"id": 0, "token": "Hello"} - result = stream_event(data, LLM_TOKEN_EVENT, MEDIA_TYPE_JSON) - - expected = 'data: {"event": "token", "data": {"id": 0, "token": "Hello"}}\n\n' - assert result == expected - - def test_stream_event_text_token(self) -> None: - """Test token event formatting for text media type.""" - data = {"id": 0, "token": "Hello"} - result = stream_event(data, LLM_TOKEN_EVENT, MEDIA_TYPE_TEXT) - - assert result == "Hello" - - def test_stream_event_json_tool_call(self) -> None: - """Test tool call event formatting for JSON media type.""" - data = { - "id": 0, - "token": {"tool_name": "search", "arguments": {"query": "test"}}, - } - result = stream_event(data, LLM_TOOL_CALL_EVENT, MEDIA_TYPE_JSON) - - expected = ( - 'data: {"event": "tool_call", "data": {"id": 0, "token": ' - '{"tool_name": "search", "arguments": {"query": "test"}}}}\n\n' - ) - assert result == expected - - def test_stream_event_text_tool_call(self) -> None: - """Test tool call event formatting for text media type.""" - data = { - "id": 0, - "function_name": "search", - "arguments": {"query": "test"}, - } - result = stream_event(data, LLM_TOOL_CALL_EVENT, MEDIA_TYPE_TEXT) - - expected = "[Tool Call: search]\n" - assert result == expected - - def test_stream_event_json_tool_result(self) -> None: - """Test tool result event formatting for JSON media type.""" - data = { - "id": 0, - "token": {"tool_name": "search", "response": "Found results"}, - } - result = stream_event(data, LLM_TOOL_RESULT_EVENT, MEDIA_TYPE_JSON) - - expected = ( - 'data: {"event": "tool_result", "data": {"id": 0, "token": ' - '{"tool_name": "search", "response": "Found results"}}}\n\n' - ) - assert result == expected - - def test_stream_event_text_tool_result(self) -> None: - """Test tool result event formatting for text media type.""" - data = { - "id": 0, - "tool_name": "search", - "response": "Found results", - } - result = stream_event(data, LLM_TOOL_RESULT_EVENT, MEDIA_TYPE_TEXT) - - expected = "[Tool Result]\n" - assert result == expected - - def test_stream_event_unknown_type(self) -> None: - """Test handling of unknown event types.""" - data = {"id": 0, "token": "test"} - result = stream_event(data, "unknown_event", MEDIA_TYPE_TEXT) - - assert result == "" - - -class TestOLSStreamEndEvent: - """Test the stream_end_event function for both media types (OLS compatibility).""" - - def test_stream_end_event_json(self) -> None: - """Test end event formatting for JSON media type.""" - token_usage = TokenCounter(input_tokens=100, output_tokens=50) - available_quotas: dict[str, int] = {} - referenced_documents = [ - ReferencedDocument( - doc_url=AnyUrl("https://example.com/doc1"), doc_title="Test Doc 1" - ), - ReferencedDocument( - doc_url=AnyUrl("https://example.com/doc2"), doc_title="Test Doc 2" - ), - ] - result = stream_end_event( - token_usage, - available_quotas, - referenced_documents, - MEDIA_TYPE_JSON, - ) - - data_part = result.replace("data: ", "").strip() - parsed = json.loads(data_part) - - assert parsed["event"] == "end" - assert "referenced_documents" in parsed["data"] - assert len(parsed["data"]["referenced_documents"]) == 2 - assert parsed["data"]["referenced_documents"][0]["doc_title"] == "Test Doc 1" - assert ( - parsed["data"]["referenced_documents"][0]["doc_url"] - == "https://example.com/doc1" - ) - assert "available_quotas" in parsed - - def test_stream_end_event_text(self) -> None: - """Test end event formatting for text media type.""" - token_usage = TokenCounter(input_tokens=100, output_tokens=50) - available_quotas: dict[str, int] = {} - referenced_documents = [ - ReferencedDocument( - doc_url=AnyUrl("https://example.com/doc1"), doc_title="Test Doc 1" - ), - ReferencedDocument( - doc_url=AnyUrl("https://example.com/doc2"), doc_title="Test Doc 2" - ), - ] - result = stream_end_event( - token_usage, - available_quotas, - referenced_documents, - MEDIA_TYPE_TEXT, - ) - - expected = ( - "\n\n---\n\nTest Doc 1: https://example.com/doc1\n" - "Test Doc 2: https://example.com/doc2" - ) - assert result == expected - - def test_stream_end_event_text_no_docs(self) -> None: - """Test end event formatting for text media type with no documents.""" - token_usage = TokenCounter(input_tokens=100, output_tokens=50) - available_quotas: dict[str, int] = {} - referenced_documents: list[ReferencedDocument] = [] - result = stream_end_event( - token_usage, - available_quotas, - referenced_documents, - MEDIA_TYPE_TEXT, - ) - - assert result == "" - - -class TestOLSCompatibilityIntegration: +class TestOLSCompatibilityIntegration: # pylint: disable=too-few-public-methods """Integration tests for OLS compatibility.""" def test_media_type_validation(self) -> None: @@ -298,31 +130,6 @@ def test_media_type_validation(self) -> None: query="test", media_type="invalid/type" ) # pyright: ignore[reportCallIssue] - def test_ols_end_event_structure(self) -> None: - """Test that end event follows OLS structure.""" - token_usage = TokenCounter(input_tokens=100, output_tokens=50) - available_quotas: dict[str, int] = {} - referenced_documents = [ - ReferencedDocument( - doc_url=AnyUrl("https://example.com/doc"), doc_title="Test Doc" - ), - ] - end_event = stream_end_event( - token_usage, - available_quotas, - referenced_documents, - MEDIA_TYPE_JSON, - ) - data_part = end_event.replace("data: ", "").strip() - parsed = json.loads(data_part) - - assert parsed["event"] == "end" - assert "referenced_documents" in parsed["data"] - assert "truncated" in parsed["data"] - assert "input_tokens" in parsed["data"] - assert "output_tokens" in parsed["data"] - assert "available_quotas" in parsed - # ============================================================================ # Endpoint Handler Tests @@ -1516,7 +1323,6 @@ async def mock_generator() -> AsyncIterator[str]: assert len(result) > 0 assert any("error" in item for item in result) - @pytest.mark.asyncio async def test_generate_response_cancelled_persists_interrupted_turn( self, mocker: MockerFixture, @@ -2579,80 +2385,6 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: assert mock_turn_summary.referenced_documents[1].doc_title == "Tool Doc" -class TestStreamHttpErrorEvent: - """Tests for stream_http_error_event function.""" - - def test_stream_http_error_event_json(self, mocker: MockerFixture) -> None: - """Test HTTP error event formatting for JSON media type.""" - error = InternalServerErrorResponse.query_failed("Test error") - mocker.patch("app.endpoints.streaming_query.logger") - - result = stream_http_error_event(error, MEDIA_TYPE_JSON) - - assert "error" in result - assert "Test error" in result - - def test_stream_http_error_event_text(self, mocker: MockerFixture) -> None: - """Test HTTP error event formatting for text media type.""" - error = InternalServerErrorResponse.query_failed("Test error") - mocker.patch("app.endpoints.streaming_query.logger") - - result = stream_http_error_event(error, MEDIA_TYPE_TEXT) - - assert "Status:" in result - assert "500" in result - assert "Test error" in result - - def test_stream_http_error_event_default(self, mocker: MockerFixture) -> None: - """Test HTTP error event formatting with default media type.""" - error = InternalServerErrorResponse.query_failed("Test error") - mocker.patch("app.endpoints.streaming_query.logger") - - result = stream_http_error_event(error) - - assert "error" in result - assert "500" in result or "status_code" in result - - -class TestStreamStartEvent: # pylint: disable=too-few-public-methods - """Tests for stream_start_event function.""" - - def test_stream_start_event(self) -> None: - """Test start event formatting.""" - result = stream_start_event("conv_123", "123e4567-e89b-12d3-a456-426614174000") - - assert "start" in result - assert "conv_123" in result - assert "123e4567-e89b-12d3-a456-426614174000" in result - - -class TestShieldViolationGenerator: - """Tests for shield_violation_generator function.""" - - @pytest.mark.asyncio - async def test_shield_violation_generator_json(self) -> None: - """Test shield violation generator for JSON media type.""" - result = [] - async for item in shield_violation_generator( - "Violation message", MEDIA_TYPE_JSON - ): - result.append(item) - - assert len(result) > 0 - assert any("Violation message" in item for item in result) - - @pytest.mark.asyncio - async def test_shield_violation_generator_text(self) -> None: - """Test shield violation generator for text media type.""" - result = [] - async for item in shield_violation_generator( - "Violation message", MEDIA_TYPE_TEXT - ): - result.append(item) - - assert len(result) > 0 - - class TestResponseGeneratorMCPCalls: """Tests for MCP call specific event handling in response_generator.""" diff --git a/tests/unit/utils/test_streaming_sse.py b/tests/unit/utils/test_streaming_sse.py new file mode 100644 index 000000000..60b4080c1 --- /dev/null +++ b/tests/unit/utils/test_streaming_sse.py @@ -0,0 +1,277 @@ +"""Unit tests for utils/streaming_sse.py.""" + +import json + +import pytest +from pydantic import AnyUrl +from pytest_mock import MockerFixture + +from constants import ( + LLM_TOKEN_EVENT, + LLM_TOOL_CALL_EVENT, + LLM_TOOL_RESULT_EVENT, + MEDIA_TYPE_JSON, + MEDIA_TYPE_TEXT, +) +from models.api.responses.error import InternalServerErrorResponse +from models.common.turn_summary import ReferencedDocument +from utils.streaming_sse import ( + shield_violation_generator, + stream_end_event, + stream_event, + stream_http_error_event, + stream_start_event, +) +from utils.token_counter import TokenCounter + + +class TestOLSStreamEventFormatting: + """Test the stream_event function for both media types (OLS compatibility).""" + + def test_stream_event_json_token(self) -> None: + """Test token event formatting for JSON media type.""" + data = {"id": 0, "token": "Hello"} + result = stream_event(data, LLM_TOKEN_EVENT, MEDIA_TYPE_JSON) + + expected = 'data: {"event": "token", "data": {"id": 0, "token": "Hello"}}\n\n' + assert result == expected + + def test_stream_event_text_token(self) -> None: + """Test token event formatting for text media type.""" + data = {"id": 0, "token": "Hello"} + result = stream_event(data, LLM_TOKEN_EVENT, MEDIA_TYPE_TEXT) + + assert result == "Hello" + + def test_stream_event_json_tool_call(self) -> None: + """Test tool call event formatting for JSON media type.""" + data = { + "id": 0, + "token": {"tool_name": "search", "arguments": {"query": "test"}}, + } + result = stream_event(data, LLM_TOOL_CALL_EVENT, MEDIA_TYPE_JSON) + + expected = ( + 'data: {"event": "tool_call", "data": {"id": 0, "token": ' + '{"tool_name": "search", "arguments": {"query": "test"}}}}\n\n' + ) + assert result == expected + + def test_stream_event_text_tool_call(self) -> None: + """Test tool call event formatting for text media type.""" + data = { + "id": 0, + "function_name": "search", + "arguments": {"query": "test"}, + } + result = stream_event(data, LLM_TOOL_CALL_EVENT, MEDIA_TYPE_TEXT) + + expected = "[Tool Call: search]\n" + assert result == expected + + def test_stream_event_json_tool_result(self) -> None: + """Test tool result event formatting for JSON media type.""" + data = { + "id": 0, + "token": {"tool_name": "search", "response": "Found results"}, + } + result = stream_event(data, LLM_TOOL_RESULT_EVENT, MEDIA_TYPE_JSON) + + expected = ( + 'data: {"event": "tool_result", "data": {"id": 0, "token": ' + '{"tool_name": "search", "response": "Found results"}}}\n\n' + ) + assert result == expected + + def test_stream_event_text_tool_result(self) -> None: + """Test tool result event formatting for text media type.""" + data = { + "id": 0, + "tool_name": "search", + "response": "Found results", + } + result = stream_event(data, LLM_TOOL_RESULT_EVENT, MEDIA_TYPE_TEXT) + + expected = "[Tool Result]\n" + assert result == expected + + def test_stream_event_unknown_type(self) -> None: + """Test handling of unknown event types.""" + data = {"id": 0, "token": "test"} + result = stream_event(data, "unknown_event", MEDIA_TYPE_TEXT) + + assert result == "" + + +class TestOLSStreamEndEvent: + """Test the stream_end_event function for both media types (OLS compatibility).""" + + def test_stream_end_event_json(self) -> None: + """Test end event formatting for JSON media type.""" + token_usage = TokenCounter(input_tokens=100, output_tokens=50) + available_quotas: dict[str, int] = {} + referenced_documents = [ + ReferencedDocument( + doc_url=AnyUrl("https://example.com/doc1"), doc_title="Test Doc 1" + ), + ReferencedDocument( + doc_url=AnyUrl("https://example.com/doc2"), doc_title="Test Doc 2" + ), + ] + result = stream_end_event( + token_usage, + available_quotas, + referenced_documents, + MEDIA_TYPE_JSON, + ) + + data_part = result.replace("data: ", "").strip() + parsed = json.loads(data_part) + + assert parsed["event"] == "end" + assert "referenced_documents" in parsed["data"] + assert len(parsed["data"]["referenced_documents"]) == 2 + assert parsed["data"]["referenced_documents"][0]["doc_title"] == "Test Doc 1" + assert ( + parsed["data"]["referenced_documents"][0]["doc_url"] + == "https://example.com/doc1" + ) + assert "available_quotas" in parsed + + def test_stream_end_event_text(self) -> None: + """Test end event formatting for text media type.""" + token_usage = TokenCounter(input_tokens=100, output_tokens=50) + available_quotas: dict[str, int] = {} + referenced_documents = [ + ReferencedDocument( + doc_url=AnyUrl("https://example.com/doc1"), doc_title="Test Doc 1" + ), + ReferencedDocument( + doc_url=AnyUrl("https://example.com/doc2"), doc_title="Test Doc 2" + ), + ] + result = stream_end_event( + token_usage, + available_quotas, + referenced_documents, + MEDIA_TYPE_TEXT, + ) + + expected = ( + "\n\n---\n\nTest Doc 1: https://example.com/doc1\n" + "Test Doc 2: https://example.com/doc2" + ) + assert result == expected + + def test_stream_end_event_text_no_docs(self) -> None: + """Test end event formatting for text media type with no documents.""" + token_usage = TokenCounter(input_tokens=100, output_tokens=50) + available_quotas: dict[str, int] = {} + referenced_documents: list[ReferencedDocument] = [] + result = stream_end_event( + token_usage, + available_quotas, + referenced_documents, + MEDIA_TYPE_TEXT, + ) + + assert result == "" + + def test_ols_end_event_structure(self) -> None: + """Test that end event follows OLS structure.""" + token_usage = TokenCounter(input_tokens=100, output_tokens=50) + available_quotas: dict[str, int] = {} + referenced_documents = [ + ReferencedDocument( + doc_url=AnyUrl("https://example.com/doc"), doc_title="Test Doc" + ), + ] + end_event = stream_end_event( + token_usage, + available_quotas, + referenced_documents, + MEDIA_TYPE_JSON, + ) + data_part = end_event.replace("data: ", "").strip() + parsed = json.loads(data_part) + + assert parsed["event"] == "end" + assert "referenced_documents" in parsed["data"] + assert "truncated" in parsed["data"] + assert "input_tokens" in parsed["data"] + assert "output_tokens" in parsed["data"] + assert "available_quotas" in parsed + + +class TestStreamHttpErrorEvent: + """Tests for stream_http_error_event function.""" + + def test_stream_http_error_event_json(self, mocker: MockerFixture) -> None: + """Test HTTP error event formatting for JSON media type.""" + error = InternalServerErrorResponse.query_failed("Test error") + mocker.patch("utils.streaming_sse.logger") + + result = stream_http_error_event(error, MEDIA_TYPE_JSON) + + assert "error" in result + assert "Test error" in result + + def test_stream_http_error_event_text(self, mocker: MockerFixture) -> None: + """Test HTTP error event formatting for text media type.""" + error = InternalServerErrorResponse.query_failed("Test error") + mocker.patch("utils.streaming_sse.logger") + + result = stream_http_error_event(error, MEDIA_TYPE_TEXT) + + assert "Status:" in result + assert "500" in result + assert "Test error" in result + + def test_stream_http_error_event_default(self, mocker: MockerFixture) -> None: + """Test HTTP error event formatting with default media type.""" + error = InternalServerErrorResponse.query_failed("Test error") + mocker.patch("utils.streaming_sse.logger") + + result = stream_http_error_event(error) + + assert "error" in result + assert "500" in result or "status_code" in result + + +class TestStreamStartEvent: # pylint: disable=too-few-public-methods + """Tests for stream_start_event function.""" + + def test_stream_start_event(self) -> None: + """Test start event formatting.""" + result = stream_start_event("conv_123", "123e4567-e89b-12d3-a456-426614174000") + + assert "start" in result + assert "conv_123" in result + assert "123e4567-e89b-12d3-a456-426614174000" in result + + +class TestShieldViolationGenerator: + """Tests for shield_violation_generator function.""" + + @pytest.mark.asyncio + async def test_shield_violation_generator_json(self) -> None: + """Test shield violation generator for JSON media type.""" + result = [] + async for item in shield_violation_generator( + "Violation message", MEDIA_TYPE_JSON + ): + result.append(item) + + assert len(result) > 0 + assert any("Violation message" in item for item in result) + + @pytest.mark.asyncio + async def test_shield_violation_generator_text(self) -> None: + """Test shield violation generator for text media type.""" + result = [] + async for item in shield_violation_generator( + "Violation message", MEDIA_TYPE_TEXT + ): + result.append(item) + + assert len(result) > 0