From 70185c5a8c1ba3e7805751c266628141150e142f Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 31 Oct 2025 11:37:06 +0100 Subject: [PATCH 01/16] refactor: Update bidirectional event types --- .../bidirectional_streaming/__init__.py | 44 +- .../bidirectional_streaming/agent/agent.py | 13 +- .../models/bidirectional_model.py | 23 +- .../models/gemini_live.py | 149 +++-- .../models/novasonic.py | 54 +- .../bidirectional_streaming/models/openai.py | 57 +- .../bidirectional_streaming/types/__init__.py | 49 +- .../types/bidirectional_streaming.py | 564 +++++++++++++----- .../models/test_gemini_live.py | 66 +- .../models/test_novasonic.py | 54 +- .../models/test_openai_realtime.py | 30 +- 11 files changed, 706 insertions(+), 397 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index d855ba0388..0413593145 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -14,36 +14,46 @@ # Event types - For type hints and event handling from .types.bidirectional_streaming import ( AudioInputEvent, - AudioOutputEvent, - BidirectionalStreamEvent, + AudioStreamEvent, + ErrorEvent, ImageInputEvent, - InterruptionDetectedEvent, + InputEvent, + InterruptionEvent, + ModalityUsage, + MultimodalUsage, + OutputEvent, + SessionEndEvent, + SessionStartEvent, TextInputEvent, - TextOutputEvent, - UsageMetricsEvent, - VoiceActivityEvent, + TranscriptStreamEvent, + TurnCompleteEvent, + TurnStartEvent, ) __all__ = [ # Main interface "BidirectionalAgent", - # Model providers "GeminiLiveBidirectionalModel", "NovaSonicBidirectionalModel", "OpenAIRealtimeBidirectionalModel", - - # Event types + # Input Event types + "TextInputEvent", "AudioInputEvent", - "AudioOutputEvent", "ImageInputEvent", - "TextInputEvent", - "TextOutputEvent", - "InterruptionDetectedEvent", - "BidirectionalStreamEvent", - 
"VoiceActivityEvent", - "UsageMetricsEvent", - + "InputEvent", + # Output Event types + "SessionStartEvent", + "TurnStartEvent", + "AudioStreamEvent", + "TranscriptStreamEvent", + "InterruptionEvent", + "TurnCompleteEvent", + "MultimodalUsage", + "ModalityUsage", + "SessionEndEvent", + "ErrorEvent", + "OutputEvent", # Model interface "BidirectionalModel", ] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index c9d7292b88..d74860222d 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -31,7 +31,7 @@ from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent +from ..types.bidirectional_streaming import AudioInputEvent, ImageInputEvent, OutputEvent logger = logging.getLogger(__name__) @@ -395,19 +395,24 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> Non "(dict with imageData, mimeType, encoding)" ) - async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + async def receive(self) -> AsyncIterable[dict[str, Any]]: """Receive events from the model including audio, text, and tool calls. Yields model output events processed by background tasks including audio output, text responses, tool calls, and session updates. Yields: - BidirectionalStreamEvent: Events from the model session. + dict: Event dictionaries from the model session. Each event is a TypedEvent + converted to a dictionary for consistency with the standard Agent API. 
""" while self._session and self._session.active: try: event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1) - yield event + # Convert TypedEvent to dict for consistency with Agent.stream_async + if hasattr(event, 'as_dict'): + yield event.as_dict() + else: + yield event except asyncio.TimeoutError: continue diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 5b7091dcdc..28a6f77cef 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -16,12 +16,13 @@ import logging from typing import AsyncIterable, Union +from ....types._events import ToolResultEvent from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec from ..types.bidirectional_streaming import ( AudioInputEvent, - BidirectionalStreamEvent, ImageInputEvent, + OutputEvent, TextInputEvent, ) @@ -69,7 +70,7 @@ async def close(self) -> None: raise NotImplementedError @abc.abstractmethod - async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + async def receive(self) -> AsyncIterable[OutputEvent]: """Receive streaming events from the model. Continuously yields events from the model as they arrive over the connection. @@ -79,13 +80,16 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: The stream continues until the connection is closed or an error occurs. Yields: - BidirectionalStreamEvent: Standardized event dictionaries containing - audio output, text responses, tool calls, or control signals. + OutputEvent: Standardized event objects containing audio output, + transcripts, tool calls, or control signals. 
""" raise NotImplementedError @abc.abstractmethod - async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + async def send( + self, + content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + ) -> None: """Send content to the model over the active connection. Transmits user input or tool results to the model during an active streaming @@ -95,13 +99,14 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE Args: content: The content to send. Must be one of: - TextInputEvent: Text message from the user - - ImageInputEvent: Image data for visual understanding - AudioInputEvent: Audio data for speech input - - ToolResult: Result from a tool execution + - ImageInputEvent: Image data for visual understanding + - ToolResultEvent: Result from a tool execution Example: await model.send(TextInputEvent(text="Hello", role="user")) - await model.send(AudioInputEvent(audioData=bytes, format="pcm", ...)) - await model.send(ToolResult(toolUseId="123", status="success", ...)) + await model.send(AudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) + await model.send(ImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) + await model.send(ToolResultEvent(tool_result)) """ raise NotImplementedError diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 639328c64b..fe495f426e 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -23,16 +23,19 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse +from ....types._events import ToolResultEvent from ..types.bidirectional_streaming import ( AudioInputEvent, - AudioOutputEvent, - BidirectionalConnectionEndEvent, - 
BidirectionalConnectionStartEvent, + AudioStreamEvent, + ErrorEvent, ImageInputEvent, - InterruptionDetectedEvent, + InterruptionEvent, + SessionEndEvent, + SessionStartEvent, TextInputEvent, - TextOutputEvent, - TranscriptEvent, + TranscriptStreamEvent, + TurnCompleteEvent, + TurnStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -158,12 +161,12 @@ async def _send_message_history(self, messages: Messages) -> None: async def receive(self) -> AsyncIterable[Dict[str, Any]]: """Receive Gemini Live API events and convert to provider-agnostic format.""" - # Emit connection start event - connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.session_id, - "metadata": {"provider": "gemini_live", "model_id": self.model_id} - } - yield {"BidirectionalConnectionStart": connection_start} + # Emit session start event + yield SessionStartEvent( + session_id=self.session_id, + model=self.model_id, + capabilities=["audio", "tools", "images"] + ) try: # Wrap in while loop to restart after turn_complete (SDK limitation workaround) @@ -189,30 +192,23 @@ async def receive(self) -> AsyncIterable[Dict[str, Any]]: except Exception as e: logger.error("Fatal error in receive loop: %s", e) + yield ErrorEvent(error=e) finally: - # Emit connection end event when exiting - connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.session_id, - "reason": "connection_complete", - "metadata": {"provider": "gemini_live"} - } - yield {"BidirectionalConnectionEnd": connection_end} + # Emit session end event when exiting + yield SessionEndEvent(reason="complete") def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dict[str, Any]]: """Convert Gemini Live API events to provider-agnostic format. 
- Handles different types of text output: - - inputTranscription: User's speech transcribed to text (emitted as transcript event) - - outputTranscription: Model's audio transcribed to text (emitted as transcript event) - - modelTurn text: Actual text response from the model (emitted as textOutput) + Handles different types of content: + - inputTranscription: User's speech transcribed to text + - outputTranscription: Model's audio transcribed to text + - modelTurn text: Text response from the model """ try: # Handle interruption first (from server_content) if message.server_content and message.server_content.interrupted: - interruption: InterruptionDetectedEvent = { - "reason": "user_input" - } - return {"interruptionDetected": interruption} + return InterruptionEvent(reason="user_speech", turn_id=None) # Handle input transcription (user's speech) - emit as transcript event if message.server_content and message.server_content.input_transcription: @@ -221,12 +217,11 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic if hasattr(input_transcript, 'text') and input_transcript.text: transcription_text = input_transcript.text logger.debug(f"Input transcription detected: {transcription_text}") - transcript: TranscriptEvent = { - "text": transcription_text, - "role": "user", - "type": "input" - } - return {"transcript": transcript} + return TranscriptStreamEvent( + text=transcription_text, + source="user", + is_final=True + ) # Handle output transcription (model's audio) - emit as transcript event if message.server_content and message.server_content.output_transcription: @@ -235,32 +230,29 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic if hasattr(output_transcript, 'text') and output_transcript.text: transcription_text = output_transcript.text logger.debug(f"Output transcription detected: {transcription_text}") - transcript: TranscriptEvent = { - "text": transcription_text, - "role": "assistant", - 
"type": "output" - } - return {"transcript": transcript} + return TranscriptStreamEvent( + text=transcription_text, + source="assistant", + is_final=True + ) - # Handle actual text output from model (not transcription) - # The SDK's message.text property accesses modelTurn.parts[].text + # Handle text output from model if message.text: - text_output: TextOutputEvent = { - "text": message.text, - "role": "assistant" - } - return {"textOutput": text_output} + logger.debug(f"Text output as transcript: {message.text}") + return TranscriptStreamEvent( + text=message.text, + source="assistant", + is_final=True + ) # Handle audio output using SDK's built-in data property if message.data: - audio_output: AudioOutputEvent = { - "audioData": message.data, - "format": "pcm", - "sampleRate": GEMINI_OUTPUT_SAMPLE_RATE, - "channels": GEMINI_CHANNELS, - "encoding": "raw" - } - return {"audioOutput": audio_output} + return AudioStreamEvent( + audio=message.data, + format="pcm", + sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, + channels=GEMINI_CHANNELS + ) # Handle tool calls if message.tool_call and message.tool_call.function_calls: @@ -281,34 +273,33 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) return None - async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + async def send( + self, + content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + ) -> None: """Unified send method for all content types. Sends the given inputs to Google Live API Dispatches to appropriate internal handler based on content type. Args: - content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). + content: Typed event (TextInputEvent, AudioInputEvent, ImageInputEvent, or ToolResultEvent). 
""" if not self._active: return try: - if isinstance(content, dict): - # Dispatch based on content structure - if "text" in content and "role" in content: - # TextInputEvent - await self._send_text_content(content["text"]) - elif "audioData" in content: - # AudioInputEvent - await self._send_audio_content(content) - elif "imageData" in content or "image_url" in content: - # ImageInputEvent - await self._send_image_content(content) - elif "toolUseId" in content and "status" in content: - # ToolResult - await self._send_tool_result(content) - else: - logger.warning(f"Unknown content type with keys: {content.keys()}") + if isinstance(content, TextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, AudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ImageInputEvent): + await self._send_image_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + logger.warning(f"Unknown content type: {type(content)}") except Exception as e: logger.error(f"Error sending content: {e}") @@ -321,7 +312,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: try: # Create audio blob for the SDK audio_blob = genai_types.Blob( - data=audio_input["audioData"], + data=audio_input.audio, mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}" ) @@ -339,19 +330,19 @@ async def _send_image_content(self, image_input: ImageInputEvent) -> None: """ try: # Prepare the message based on encoding - if image_input.get("encoding") == "base64": + if image_input.encoding == "base64": # Data is already base64 encoded - if isinstance(image_input["imageData"], bytes): - data_str = image_input["imageData"].decode() + if isinstance(image_input.image, bytes): + data_str = image_input.image.decode() else: - data_str = image_input["imageData"] + data_str = image_input.image else: # Raw bytes - need to 
base64 encode - data_str = base64.b64encode(image_input["imageData"]).decode() + data_str = base64.b64encode(image_input.image).decode() # Create the message in the format expected by Gemini Live msg = { - "mime_type": image_input["mimeType"], + "mime_type": image_input.mime_type, "data": data_str } diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index b9c5060ba7..f66a713772 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -32,16 +32,20 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse +from ....types._events import ToolResultEvent from ..types.bidirectional_streaming import ( AudioInputEvent, - AudioOutputEvent, - BidirectionalConnectionEndEvent, - BidirectionalConnectionStartEvent, + AudioStreamEvent, + ErrorEvent, ImageInputEvent, - InterruptionDetectedEvent, + InterruptionEvent, + MultimodalUsage, + SessionEndEvent, + SessionStartEvent, TextInputEvent, - TextOutputEvent, - UsageMetricsEvent, + TranscriptStreamEvent, + TurnCompleteEvent, + TurnStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -302,34 +306,34 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: } yield {"BidirectionalConnectionEnd": connection_end} - async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + async def send( + self, + content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + ) -> None: """Unified send method for all content types. Sends the given content to Nova Sonic. Dispatches to appropriate internal handler based on content type. Args: - content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). 
+ content: Typed event (TextInputEvent, AudioInputEvent, ImageInputEvent, or ToolResultEvent). """ if not self._active: return try: - if isinstance(content, dict): - # Dispatch based on content structure - if "text" in content and "role" in content: - # TextInputEvent - await self._send_text_content(content["text"]) - elif "audioData" in content: - # AudioInputEvent - await self._send_audio_content(content) - elif "imageData" in content or "image_url" in content: - # ImageInputEvent - not supported by Nova Sonic - logger.warning("Image input not supported by Nova Sonic") - elif "toolUseId" in content and "status" in content: - # ToolResult - await self._send_tool_result(content) - else: - logger.warning(f"Unknown content type with keys: {content.keys()}") + if isinstance(content, TextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, AudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ImageInputEvent): + # ImageInputEvent - not supported by Nova Sonic + logger.warning("Image input not supported by Nova Sonic") + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + logger.warning(f"Unknown content type: {type(content)}") except Exception as e: logger.error(f"Error sending content: {e}") @@ -370,7 +374,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: self.silence_task.cancel() # Convert audio to Nova Sonic base64 format - nova_audio_data = base64.b64encode(audio_input["audioData"]).decode("utf-8") + nova_audio_data = base64.b64encode(audio_input.audio).decode("utf-8") # Send audio input event audio_event = json.dumps( diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index a542ec894e..ae7de4c835 100644 --- 
a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -18,16 +18,21 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse +from ....types._events import ToolResultEvent from ..types.bidirectional_streaming import ( AudioInputEvent, - AudioOutputEvent, - BidirectionalConnectionEndEvent, - BidirectionalConnectionStartEvent, - BidirectionalStreamEvent, + AudioStreamEvent, + ErrorEvent, ImageInputEvent, + InterruptionEvent, + MultimodalUsage, + OutputEvent, + SessionEndEvent, + SessionStartEvent, TextInputEvent, - TextOutputEvent, - VoiceActivityEvent, + TranscriptStreamEvent, + TurnCompleteEvent, + TurnStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -266,7 +271,7 @@ async def _process_responses(self) -> None: self._active = False logger.debug("OpenAI Realtime response processor stopped") - async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + async def receive(self) -> AsyncIterable[OutputEvent]: """Receive OpenAI events and convert to Strands format.""" connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.session_id, @@ -431,40 +436,40 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] logger.debug("Unhandled OpenAI event type: %s", event_type) return None - async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + async def send( + self, + content: Union[TextInputEvent, AudioInputEvent, ImageInputEvent, ToolResultEvent], + ) -> None: """Unified send method for all content types. Sends the given content to OpenAI. Dispatches to appropriate internal handler based on content type. Args: - content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). + content: Typed event (TextInputEvent, AudioInputEvent, ImageInputEvent, or ToolResultEvent). 
""" if not self._require_active(): return try: - if isinstance(content, dict): - # Dispatch based on content structure - if "text" in content and "role" in content: - # TextInputEvent - await self._send_text_content(content["text"]) - elif "audioData" in content: - # AudioInputEvent - await self._send_audio_content(content) - elif "imageData" in content or "image_url" in content: - # ImageInputEvent - not supported by OpenAI Realtime yet - logger.warning("Image input not supported by OpenAI Realtime API") - elif "toolUseId" in content and "status" in content: - # ToolResult - await self._send_tool_result(content) - else: - logger.warning(f"Unknown content type with keys: {content.keys()}") + if isinstance(content, TextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, AudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ImageInputEvent): + # ImageInputEvent - not supported by OpenAI Realtime yet + logger.warning("Image input not supported by OpenAI Realtime API") + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + logger.warning(f"Unknown content type: {type(content)}") except Exception as e: logger.error(f"Error sending content: {e}") async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: """Internal: Send audio content to OpenAI for processing.""" - audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") + audio_base64 = base64.b64encode(audio_input.audio).decode("utf-8") await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) async def _send_text_content(self, text: str) -> None: diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index d040ee436f..52034db1b3 100644 --- 
a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -2,38 +2,51 @@ from .bidirectional_streaming import ( DEFAULT_CHANNELS, + DEFAULT_FORMAT, DEFAULT_SAMPLE_RATE, SUPPORTED_AUDIO_FORMATS, SUPPORTED_CHANNELS, SUPPORTED_SAMPLE_RATES, AudioInputEvent, - AudioOutputEvent, - BidirectionalConnectionEndEvent, - BidirectionalConnectionStartEvent, - BidirectionalStreamEvent, + AudioStreamEvent, + ErrorEvent, ImageInputEvent, - InterruptionDetectedEvent, - TextOutputEvent, - TranscriptEvent, - UsageMetricsEvent, - VoiceActivityEvent, + InputEvent, + InterruptionEvent, + ModalityUsage, + MultimodalUsage, + OutputEvent, + SessionEndEvent, + SessionStartEvent, + TextInputEvent, + TranscriptStreamEvent, + TurnCompleteEvent, + TurnStartEvent, ) __all__ = [ + # Input Events + "TextInputEvent", "AudioInputEvent", - "AudioOutputEvent", - "BidirectionalConnectionEndEvent", - "BidirectionalConnectionStartEvent", - "BidirectionalStreamEvent", "ImageInputEvent", - "InterruptionDetectedEvent", - "TextOutputEvent", - "TranscriptEvent", - "UsageMetricsEvent", - "VoiceActivityEvent", + "InputEvent", + # Output Events + "SessionStartEvent", + "TurnStartEvent", + "AudioStreamEvent", + "TranscriptStreamEvent", + "InterruptionEvent", + "TurnCompleteEvent", + "MultimodalUsage", + "ModalityUsage", + "SessionEndEvent", + "ErrorEvent", + "OutputEvent", + # Constants "SUPPORTED_AUDIO_FORMATS", "SUPPORTED_SAMPLE_RATES", "SUPPORTED_CHANNELS", "DEFAULT_SAMPLE_RATE", "DEFAULT_CHANNELS", + "DEFAULT_FORMAT", ] diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 145710c3cb..e7af3ad433 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -6,9 +6,9 @@ Key 
features: - Audio input/output events with standardized formats - Interruption detection and handling -- connection lifecycle management +- Session lifecycle management - Provider-agnostic event types -- Backwards compatibility with existing StreamEvent types +- Type-safe discriminated unions with TypedEvent Audio format normalization: - Supports PCM, WAV, Opus, and MP3 formats @@ -17,12 +17,9 @@ - Abstracts provider-specific encodings """ -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union, cast -from typing_extensions import TypedDict - -from ....types.content import Role -from ....types.streaming import StreamEvent +from ....types._events import TypedEvent # Audio format constants SUPPORTED_AUDIO_FORMATS = ["pcm", "wav", "opus", "mp3"] @@ -30,221 +27,470 @@ SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo DEFAULT_SAMPLE_RATE = 16000 DEFAULT_CHANNELS = 1 +DEFAULT_FORMAT = "pcm" -class AudioOutputEvent(TypedDict): - """Audio output event from the model. +# ============================================================================ +# Input Events (sent via session.send()) +# ============================================================================ - Provides standardized audio output format across different providers using - raw bytes instead of provider-specific encodings. - Attributes: - audioData: Raw audio bytes (not base64 or hex encoded). - format: Audio format from SUPPORTED_AUDIO_FORMATS. - sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. - channels: Channel count from SUPPORTED_CHANNELS. - encoding: Original provider encoding for debugging purposes. +class TextInputEvent(TypedEvent): + """Text input event for sending text to the model. + + Used for sending text content through the send() method. + + Parameters: + text: The text content to send to the model. + role: The role of the message sender (typically "user"). 
""" - audioData: bytes - format: Literal["pcm", "wav", "opus", "mp3"] - sampleRate: Literal[16000, 24000, 48000] - channels: Literal[1, 2] - encoding: Optional[str] + def __init__(self, text: str, role: str): + super().__init__( + { + "type": "bidirectional_text_input", + "text": text, + "role": role, + } + ) + + @property + def text(self) -> str: + return cast(str, self.get("text")) + @property + def role(self) -> str: + return cast(str, self.get("role")) -class AudioInputEvent(TypedDict): + +class AudioInputEvent(TypedEvent): """Audio input event for sending audio to the model. Used for sending audio data through the send() method. - Attributes: - audioData: Raw audio bytes to send to model. + Parameters: + audio: Raw audio bytes to send to model (not base64 encoded). format: Audio format from SUPPORTED_AUDIO_FORMATS. - sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. + sample_rate: Sample rate from SUPPORTED_SAMPLE_RATES. channels: Channel count from SUPPORTED_CHANNELS. """ - audioData: bytes - format: Literal["pcm", "wav", "opus", "mp3"] - sampleRate: Literal[16000, 24000, 48000] - channels: Literal[1, 2] - - -class ImageInputEvent(TypedDict): + def __init__( + self, + audio: bytes, + format: Literal["pcm", "wav", "opus", "mp3"], + sample_rate: Literal[16000, 24000, 48000], + channels: Literal[1, 2], + ): + super().__init__( + { + "type": "bidirectional_audio_input", + "audio": audio, + "format": format, + "sample_rate": sample_rate, + "channels": channels, + } + ) + + @property + def audio(self) -> bytes: + return cast(bytes, self.get("audio")) + + @property + def format(self) -> str: + return cast(str, self.get("format")) + + @property + def sample_rate(self) -> int: + return cast(int, self.get("sample_rate")) + + @property + def channels(self) -> int: + return cast(int, self.get("channels")) + + +class ImageInputEvent(TypedEvent): """Image input event for sending images/video frames to the model. - + Used for sending image data through the send() method. 
Supports both raw image bytes and base64-encoded data. - - Attributes: - imageData: Image bytes (raw or base64-encoded string). - mimeType: MIME type (e.g., "image/jpeg", "image/png"). - encoding: How the imageData is encoded. + + Parameters: + image: Image bytes (raw or base64-encoded string). + mime_type: MIME type (e.g., "image/jpeg", "image/png"). + encoding: How the image data is encoded. """ - - imageData: bytes | str - mimeType: str - encoding: Literal["base64", "raw"] + def __init__( + self, + image: Union[bytes, str], + mime_type: str, + encoding: Literal["base64", "raw"], + ): + super().__init__( + { + "type": "bidirectional_image_input", + "image": image, + "mime_type": mime_type, + "encoding": encoding, + } + ) + + @property + def image(self) -> Union[bytes, str]: + return cast(Union[bytes, str], self.get("image")) + + @property + def mime_type(self) -> str: + return cast(str, self.get("mime_type")) + + @property + def encoding(self) -> str: + return cast(str, self.get("encoding")) + + +# ============================================================================ +# Output Events (received via session.receive_events()) +# ============================================================================ + + +class SessionStartEvent(TypedEvent): + """Session established and ready for interaction. + + Parameters: + session_id: Unique identifier for this session. + model: Model identifier (e.g., "gpt-realtime", "gemini-2.0-flash-live"). + capabilities: List of supported features (e.g., ["audio", "tools", "images"]). + """ -class TextInputEvent(TypedDict): - """Text input event for sending text to the model. + def __init__(self, session_id: str, model: str, capabilities: List[str]): + super().__init__( + { + "type": "bidirectional_session_start", + "session_id": session_id, + "model": model, + "capabilities": capabilities, + } + ) - Used for sending text content through the send() method. 
+ @property + def session_id(self) -> str: + return cast(str, self.get("session_id")) - Attributes: - text: The text content to send to the model. - role: The role of the message sender (typically "user"). - """ + @property + def model(self) -> str: + return cast(str, self.get("model")) - text: str - role: Role + @property + def capabilities(self) -> List[str]: + return cast(List[str], self.get("capabilities")) -class TextOutputEvent(TypedDict): - """Text output event from the model during bidirectional streaming. +class TurnStartEvent(TypedEvent): + """Model starts generating a response. - Attributes: - text: The text content from the model. - role: The role of the message sender. + Parameters: + turn_id: Unique identifier for this turn (used in turn.complete). """ - text: str - role: Role + def __init__(self, turn_id: str): + super().__init__({"type": "bidirectional_turn_start", "turn_id": turn_id}) + @property + def turn_id(self) -> str: + return cast(str, self.get("turn_id")) -class TranscriptEvent(TypedDict): - """Transcript event for audio transcriptions. - - Used for both input transcriptions (user speech) and output transcriptions - (model audio). These are informational and separate from actual text responses. - - Attributes: - text: The transcribed text. - role: The role of the speaker ("user" or "assistant"). - type: Type of transcription ("input" or "output"). - """ - - text: str - role: Role - type: Literal["input", "output"] +class AudioStreamEvent(TypedEvent): + """Streaming audio output from the model. -class InterruptionDetectedEvent(TypedDict): - """Interruption detection event. + Parameters: + audio: Raw audio data as bytes (not base64 encoded). + format: Audio encoding format. + sample_rate: Number of audio samples per second in Hz. + channels: Number of audio channels (1=mono, 2=stereo). + """ - Signals when user interruption is detected during model generation. 
+ def __init__( + self, + audio: bytes, + format: Literal["pcm", "wav", "opus", "mp3"], + sample_rate: Literal[16000, 24000, 48000], + channels: Literal[1, 2], + ): + super().__init__( + { + "type": "bidirectional_audio_stream", + "audio": audio, + "format": format, + "sample_rate": sample_rate, + "channels": channels, + } + ) + + @property + def audio(self) -> bytes: + return cast(bytes, self.get("audio")) + + @property + def format(self) -> str: + return cast(str, self.get("format")) + + @property + def sample_rate(self) -> int: + return cast(int, self.get("sample_rate")) + + @property + def channels(self) -> int: + return cast(int, self.get("channels")) + + +class TranscriptStreamEvent(TypedEvent): + """Audio transcription of speech (user or assistant). + + Parameters: + text: Transcribed text from audio. + source: Who is speaking ("user" or "assistant"). + is_final: Whether this is the final/complete transcript. + """ - Attributes: - reason: Interruption reason from predefined set. + def __init__( + self, text: str, source: Literal["user", "assistant"], is_final: bool + ): + super().__init__( + { + "type": "bidirectional_transcript_stream", + "text": text, + "source": source, + "is_final": is_final, + } + ) + + @property + def text(self) -> str: + return cast(str, self.get("text")) + + @property + def source(self) -> str: + return cast(str, self.get("source")) + + @property + def is_final(self) -> bool: + return cast(bool, self.get("is_final")) + + +class InterruptionEvent(TypedEvent): + """Model generation was interrupted. + + Parameters: + reason: Why the interruption occurred. + turn_id: ID of the turn that was interrupted (may be None). 
""" - reason: Literal["user_input", "vad_detected", "manual"] + def __init__( + self, reason: Literal["user_speech", "error"], turn_id: Optional[str] = None + ): + super().__init__( + { + "type": "bidirectional_interruption", + "reason": reason, + "turn_id": turn_id, + } + ) + @property + def reason(self) -> str: + return cast(str, self.get("reason")) -class BidirectionalConnectionStartEvent(TypedDict, total=False): - """connection start event for bidirectional streaming. + @property + def turn_id(self) -> Optional[str]: + return cast(Optional[str], self.get("turn_id")) - Attributes: - connectionId: Unique connection identifier. - metadata: Provider-specific connection metadata. - """ - connectionId: Optional[str] - metadata: Optional[Dict[str, Any]] +class TurnCompleteEvent(TypedEvent): + """Model finished generating response. + Parameters: + turn_id: ID of the turn that completed (matches turn.start). + stop_reason: Why the turn ended. + """ -class BidirectionalConnectionEndEvent(TypedDict): - """connection end event for bidirectional streaming. + def __init__( + self, + turn_id: str, + stop_reason: Literal["complete", "interrupted", "tool_use", "error"], + ): + super().__init__( + { + "type": "bidirectional_turn_complete", + "turn_id": turn_id, + "stop_reason": stop_reason, + } + ) - Attributes: - reason: Reason for connection end from predefined set. - connectionId: Unique connection identifier. - metadata: Provider-specific connection metadata. - """ + @property + def turn_id(self) -> str: + return cast(str, self.get("turn_id")) - reason: Literal["user_request", "timeout", "error", "connection_complete"] - connectionId: Optional[str] - metadata: Optional[Dict[str, Any]] + @property + def stop_reason(self) -> str: + return cast(str, self.get("stop_reason")) -class UsageMetricsEvent(TypedDict): - """Token usage and performance tracking. - Provides standardized usage metrics across providers for cost monitoring - and performance optimization. 
+class ModalityUsage(dict): + """Token usage for a specific modality. Attributes: - totalTokens: Total tokens used in the interaction. - inputTokens: Tokens used for input processing. - outputTokens: Tokens used for output generation. - audioTokens: Tokens used specifically for audio processing. + modality: Type of content. + input_tokens: Tokens used for this modality's input. + output_tokens: Tokens used for this modality's output. """ - totalTokens: Optional[int] - inputTokens: Optional[int] - outputTokens: Optional[int] - audioTokens: Optional[int] + modality: Literal["text", "audio", "image", "cached"] + input_tokens: int + output_tokens: int -class VoiceActivityEvent(TypedDict): - """Voice activity detection event for speech monitoring. +class MultimodalUsage(TypedEvent): + """Token usage event with modality breakdown for multimodal streaming. - Provides standardized voice activity detection events across providers - to enable speech-aware applications and better conversation flow. + Combines TypedEvent behavior with Usage fields for a unified event type. - Attributes: - activityType: Type of voice activity detected. + Parameters: + input_tokens: Total tokens used for all input modalities. + output_tokens: Total tokens used for all output modalities. + total_tokens: Sum of input and output tokens. + modality_details: Optional list of token usage per modality. + cache_read_input_tokens: Optional tokens read from cache. + cache_write_input_tokens: Optional tokens written to cache. """ - activityType: Literal["speech_started", "speech_stopped", "timeout"] - - -class UsageMetricsEvent(TypedDict): - """Token usage and performance tracking. - - Provides standardized usage metrics across providers for cost monitoring - and performance optimization. - - Attributes: - totalTokens: Total tokens used in the interaction. - inputTokens: Tokens used for input processing. - outputTokens: Tokens used for output generation. 
- audioTokens: Tokens used specifically for audio processing. + def __init__( + self, + input_tokens: int, + output_tokens: int, + total_tokens: int, + modality_details: Optional[List[ModalityUsage]] = None, + cache_read_input_tokens: Optional[int] = None, + cache_write_input_tokens: Optional[int] = None, + ): + data: Dict[str, Any] = { + "type": "multimodal_usage", + "inputTokens": input_tokens, + "outputTokens": output_tokens, + "totalTokens": total_tokens, + } + if modality_details is not None: + data["modality_details"] = modality_details + if cache_read_input_tokens is not None: + data["cacheReadInputTokens"] = cache_read_input_tokens + if cache_write_input_tokens is not None: + data["cacheWriteInputTokens"] = cache_write_input_tokens + super().__init__(data) + + @property + def input_tokens(self) -> int: + return cast(int, self.get("inputTokens")) + + @property + def output_tokens(self) -> int: + return cast(int, self.get("outputTokens")) + + @property + def total_tokens(self) -> int: + return cast(int, self.get("totalTokens")) + + @property + def modality_details(self) -> List[ModalityUsage]: + return cast(List[ModalityUsage], self.get("modality_details", [])) + + @property + def cache_read_input_tokens(self) -> Optional[int]: + return cast(Optional[int], self.get("cacheReadInputTokens")) + + @property + def cache_write_input_tokens(self) -> Optional[int]: + return cast(Optional[int], self.get("cacheWriteInputTokens")) + + +class SessionEndEvent(TypedEvent): + """Session terminated. + + Parameters: + reason: Why the session ended. 
""" - totalTokens: Optional[int] - inputTokens: Optional[int] - outputTokens: Optional[int] - audioTokens: Optional[int] + def __init__( + self, reason: Literal["client_disconnect", "timeout", "error", "complete"] + ): + super().__init__({"type": "bidirectional_session_end", "reason": reason}) + @property + def reason(self) -> str: + return cast(str, self.get("reason")) -class BidirectionalStreamEvent(StreamEvent, total=False): - """Bidirectional stream event extending existing StreamEvent. - Extends the existing StreamEvent type with bidirectional-specific events - while maintaining full backward compatibility with existing Strands streaming. +class ErrorEvent(TypedEvent): + """Error occurred during the session. - Attributes: - audioOutput: Audio output from the model. - audioInput: Audio input sent to the model. - imageInput: Image input sent to the model. - textOutput: Text output from the model. - transcript: Audio transcription (input or output). - interruptionDetected: User interruption detection. - BidirectionalConnectionStart: connection start event. - BidirectionalConnectionEnd: connection end event. - voiceActivity: Voice activity detection events. - usageMetrics: Token usage and performance metrics. + Similar to strands.types._events.ForceStopEvent, this event wraps exceptions + that occur during bidirectional streaming sessions. + + Parameters: + error: The exception that occurred. + code: Optional error code for programmatic handling (defaults to exception class name). + details: Optional additional error information. 
""" - audioOutput: Optional[AudioOutputEvent] - audioInput: Optional[AudioInputEvent] - imageInput: Optional[ImageInputEvent] - textOutput: Optional[TextOutputEvent] - transcript: Optional[TranscriptEvent] - interruptionDetected: Optional[InterruptionDetectedEvent] - BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent] - BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent] - voiceActivity: Optional[VoiceActivityEvent] - usageMetrics: Optional[UsageMetricsEvent] + def __init__( + self, + error: Exception, + code: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, + ): + super().__init__( + { + "bidirectional_error": True, + "error": error, + "error_message": str(error), + "error_code": code or type(error).__name__, + "error_details": details, + } + ) + + @property + def error(self) -> Exception: + return cast(Exception, self.get("error")) + + @property + def code(self) -> str: + return cast(str, self.get("error_code")) + + @property + def message(self) -> str: + return cast(str, self.get("error_message")) + + @property + def details(self) -> Optional[Dict[str, Any]]: + return cast(Optional[Dict[str, Any]], self.get("error_details")) + + +# ============================================================================ +# Type Unions +# ============================================================================ + +# Note: ToolResultEvent and ToolUseStreamEvent are reused from strands.types._events + +InputEvent = Union[TextInputEvent, AudioInputEvent, ImageInputEvent] + +OutputEvent = Union[ + SessionStartEvent, + TurnStartEvent, + AudioStreamEvent, + TranscriptStreamEvent, + InterruptionEvent, + TurnCompleteEvent, + MultimodalUsage, + SessionEndEvent, + ErrorEvent, +] diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index b894509c91..f13e2cf04f 100644 --- 
a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -19,6 +19,7 @@ ImageInputEvent, TextInputEvent, ) +from strands.types._events import ToolResultEvent from strands.types.tools import ToolResult @@ -188,7 +189,7 @@ async def test_send_all_content_types(mock_genai_client, model): await model.connect() # Test text input - text_input: TextInputEvent = {"text": "Hello", "role": "user"} + text_input = TextInputEvent(text="Hello", role="user") await model.send(text_input) mock_live_session.send_client_content.assert_called_once() call_args = mock_live_session.send_client_content.call_args @@ -197,31 +198,32 @@ async def test_send_all_content_types(mock_genai_client, model): assert content.parts[0].text == "Hello" # Test audio input - audio_input: AudioInputEvent = { - "audioData": b"audio_bytes", - "format": "pcm", - "sampleRate": 16000, - "channels": 1, - } + audio_input = AudioInputEvent( + audio=b"audio_bytes", + format="pcm", + sample_rate=16000, + channels=1, + ) await model.send(audio_input) mock_live_session.send_realtime_input.assert_called_once() # Test image input - image_input: ImageInputEvent = { - "imageData": b"image_bytes", - "mimeType": "image/jpeg", - "encoding": "raw", - } + image_input = ImageInputEvent( + image=b"image_bytes", + mime_type="image/jpeg", + encoding="raw", + ) await model.send(image_input) mock_live_session.send.assert_called_once() # Test tool result + from strands.types._events import ToolResultEvent tool_result: ToolResult = { "toolUseId": "tool-123", "status": "success", "content": [{"text": "Result: 42"}], } - await model.send(tool_result) + await model.send(ToolResultEvent(tool_result)) mock_live_session.send_tool_response.assert_called_once() await model.close() @@ -233,7 +235,7 @@ async def test_send_edge_cases(mock_genai_client, model): _, mock_live_session, _ = mock_genai_client # Test send when inactive - 
text_input: TextInputEvent = {"text": "Hello", "role": "user"} + text_input = TextInputEvent(text="Hello", role="user") await model.send(text_input) mock_live_session.send_client_content.assert_not_called() @@ -251,6 +253,11 @@ async def test_send_edge_cases(mock_genai_client, model): @pytest.mark.asyncio async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): """Test that receive() emits connection start and end events.""" + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + SessionStartEvent, + SessionEndEvent, + ) + _, mock_live_session, _ = mock_genai_client mock_live_session.receive.return_value = agenerator([]) @@ -266,18 +273,24 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): # Verify connection start and end assert len(events) >= 2 - assert "BidirectionalConnectionStart" in events[0] - assert events[0]["BidirectionalConnectionStart"]["connectionId"] == model.session_id - assert "BidirectionalConnectionEnd" in events[-1] + assert isinstance(events[0], SessionStartEvent) + assert events[0].session_id == model.session_id + assert isinstance(events[-1], SessionEndEvent) @pytest.mark.asyncio async def test_event_conversion(mock_genai_client, model): """Test conversion of all Gemini Live event types to standard format.""" + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + TranscriptStreamEvent, + AudioStreamEvent, + InterruptionEvent, + ) + _, _, _ = mock_genai_client await model.connect() - # Test text output + # Test text output (now converted to transcript) mock_text = unittest.mock.Mock() mock_text.text = "Hello from Gemini" mock_text.data = None @@ -285,9 +298,10 @@ async def test_event_conversion(mock_genai_client, model): mock_text.server_content = None text_event = model._convert_gemini_live_event(mock_text) - assert "textOutput" in text_event - assert text_event["textOutput"]["text"] == "Hello from Gemini" - assert 
text_event["textOutput"]["role"] == "assistant" + assert isinstance(text_event, TranscriptStreamEvent) + assert text_event.text == "Hello from Gemini" + assert text_event.source == "assistant" + assert text_event.is_final is True # Test audio output mock_audio = unittest.mock.Mock() @@ -297,9 +311,9 @@ async def test_event_conversion(mock_genai_client, model): mock_audio.server_content = None audio_event = model._convert_gemini_live_event(mock_audio) - assert "audioOutput" in audio_event - assert audio_event["audioOutput"]["audioData"] == b"audio_data" - assert audio_event["audioOutput"]["format"] == "pcm" + assert isinstance(audio_event, AudioStreamEvent) + assert audio_event.audio == b"audio_data" + assert audio_event.format == "pcm" # Test tool call mock_func_call = unittest.mock.Mock() @@ -334,8 +348,8 @@ async def test_event_conversion(mock_genai_client, model): mock_interrupt.server_content = mock_server_content interrupt_event = model._convert_gemini_live_event(mock_interrupt) - assert "interruptionDetected" in interrupt_event - assert interrupt_event["interruptionDetected"]["reason"] == "user_input" + assert isinstance(interrupt_event, InterruptionEvent) + assert interrupt_event.reason == "user_speech" await model.close() diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 10066a6938..bc2b0961c9 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -131,36 +131,42 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model @pytest.mark.asyncio async def test_send_all_content_types(nova_model, mock_client, mock_stream): """Test sending all content types through unified send() method.""" + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + TextInputEvent, + 
AudioInputEvent, + ) + from strands.types._events import ToolResultEvent + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client await nova_model.connect() # Test text content - text_event = {"text": "Hello, Nova!", "role": "user"} + text_event = TextInputEvent(text="Hello, Nova!", role="user") await nova_model.send(text_event) # Should send contentStart, textInput, and contentEnd assert mock_stream.input_stream.send.call_count >= 3 # Test audio content - audio_event = { - "audioData": b"audio data", - "format": "pcm", - "sampleRate": 16000, - "channels": 1 - } + audio_event = AudioInputEvent( + audio=b"audio data", + format="pcm", + sample_rate=16000, + channels=1 + ) await nova_model.send(audio_event) # Should start audio connection and send audio assert nova_model.audio_connection_active assert mock_stream.input_stream.send.called # Test tool result - tool_result = { + tool_result: ToolResult = { "toolUseId": "tool-123", "status": "success", "content": [{"text": "Weather is sunny"}] } - await nova_model.send(tool_result) + await nova_model.send(ToolResultEvent(tool_result)) # Should send contentStart, toolResult, and contentEnd assert mock_stream.input_stream.send.called @@ -170,19 +176,25 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): @pytest.mark.asyncio async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): """Test send() edge cases and error handling.""" + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + TextInputEvent, + ImageInputEvent, + ) + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client # Test send when inactive - text_event = {"text": "Hello", "role": "user"} + text_event = TextInputEvent(text="Hello", role="user") await nova_model.send(text_event) # Should not raise # Test image content (not supported) await nova_model.connect() - 
image_event = { - "imageData": b"image data", - "mimeType": "image/jpeg" - } + image_event = ImageInputEvent( + image=b"image data", + mime_type="image/jpeg", + encoding="raw" + ) await nova_model.send(image_event) # Should log warning about unsupported image input assert any("not supported" in record.message.lower() for record in caplog.records) @@ -319,6 +331,8 @@ async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): @pytest.mark.asyncio async def test_silence_detection(nova_model, mock_client, mock_stream): """Test that silence detection automatically ends audio input.""" + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioInputEvent + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client nova_model.silence_threshold = 0.1 # Short threshold for testing @@ -326,12 +340,12 @@ async def test_silence_detection(nova_model, mock_client, mock_stream): await nova_model.connect() # Send audio to start connection - audio_event = { - "audioData": b"audio data", - "format": "pcm", - "sampleRate": 16000, - "channels": 1 - } + audio_event = AudioInputEvent( + audio=b"audio data", + format="pcm", + sample_rate=16000, + channels=1 + ) await nova_model.send(audio_event) assert nova_model.audio_connection_active diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 1209150ba9..7495a4489b 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -222,11 +222,13 @@ async def async_connect(*args, **kwargs): @pytest.mark.asyncio async def test_send_all_content_types(mock_websockets_connect, model): """Test sending all content types through unified send() method.""" + from strands.types._events import 
ToolResultEvent + _, mock_ws = mock_websockets_connect await model.connect() # Test text input - text_input: TextInputEvent = {"text": "Hello", "role": "user"} + text_input = TextInputEvent(text="Hello", role="user") await model.send(text_input) calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] @@ -236,12 +238,12 @@ async def test_send_all_content_types(mock_websockets_connect, model): assert len(response_create) > 0 # Test audio input - audio_input: AudioInputEvent = { - "audioData": b"audio_bytes", - "format": "pcm", - "sampleRate": 24000, - "channels": 1, - } + audio_input = AudioInputEvent( + audio=b"audio_bytes", + format="pcm", + sample_rate=24000, + channels=1, + ) await model.send(audio_input) calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] @@ -257,7 +259,7 @@ async def test_send_all_content_types(mock_websockets_connect, model): "status": "success", "content": [{"text": "Result: 42"}], } - await model.send(tool_result) + await model.send(ToolResultEvent(tool_result)) calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] item_create = [m for m in messages if m.get("type") == "conversation.item.create"] @@ -275,17 +277,17 @@ async def test_send_edge_cases(mock_websockets_connect, model): _, mock_ws = mock_websockets_connect # Test send when inactive - text_input: TextInputEvent = {"text": "Hello", "role": "user"} + text_input = TextInputEvent(text="Hello", role="user") await model.send(text_input) mock_ws.send.assert_not_called() # Test image input (not supported) await model.connect() - image_input: ImageInputEvent = { - "imageData": b"image_bytes", - "mimeType": "image/jpeg", - "encoding": "raw", - } + image_input = ImageInputEvent( + image=b"image_bytes", + mime_type="image/jpeg", + encoding="raw", + ) with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: await model.send(image_input) 
mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") From b9aa1f898a6d9d4c00b19fcaecc28d7b8f2bd131 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 31 Oct 2025 17:33:56 +0100 Subject: [PATCH 02/16] fix: get tests working --- .../bidirectional_streaming/agent/agent.py | 58 ++++++--- .../event_loop/bidirectional_event_loop.py | 50 ++++---- .../models/gemini_live.py | 24 ++-- .../models/novasonic.py | 110 ++++++++---------- .../bidirectional_streaming/models/openai.py | 95 +++++++-------- .../tests/test_bidi_novasonic.py | 57 +++++---- .../tests/test_bidi_openai.py | 68 +++++++---- .../tests/test_gemini_live.py | 110 ++++++++++-------- .../types/bidirectional_streaming.py | 48 ++++---- .../models/test_gemini_live.py | 19 +-- .../models/test_novasonic.py | 94 ++++++++------- .../models/test_openai_realtime.py | 90 ++++++++------ 12 files changed, 454 insertions(+), 369 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index d74860222d..ab08978fbd 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -360,39 +360,67 @@ async def start(self) -> None: logger.debug("Conversation start - initializing session") self._session = await start_bidirectional_connection(self) - async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> None: - """Send input to the model (text, audio, or image). + async def send(self, input_data: str | AudioInputEvent | ImageInputEvent | dict) -> None: + """Send input to the model (text, audio, image, or event dict). Unified method for sending text, audio, and image input to the model during - an active conversation session. + an active conversation session. Accepts TypedEvent instances or plain dicts + (e.g., from WebSocket clients) which are automatically reconstructed. 
Args: - input_data: String for text, AudioInputEvent for audio, or ImageInputEvent for images. + input_data: Can be: + - str: Text message from user + - AudioInputEvent: Audio data with format/sample rate + - ImageInputEvent: Image data with MIME type + - dict: Event dictionary (will be reconstructed to TypedEvent) Raises: ValueError: If no active session or invalid input type. + + Example: + await agent.send("Hello") + await agent.send(AudioInputEvent(audio="base64...", format="pcm", ...)) + await agent.send({"type": "bidirectional_text_input", "text": "Hello", "role": "user"}) """ self._validate_active_session() + # Handle string input if isinstance(input_data, str): # Add user text message to history self.messages.append({"role": "user", "content": input_data}) - logger.debug("Text sent: %d characters", len(input_data)) - # Create TextInputEvent for send() - text_event = {"text": input_data, "role": "user"} + from ..types.bidirectional_streaming import TextInputEvent + text_event = TextInputEvent(text=input_data, role="user") await self._session.model.send(text_event) - elif isinstance(input_data, dict) and "audioData" in input_data: - # Handle audio input - already in AudioInputEvent format - await self._session.model.send(input_data) - elif isinstance(input_data, dict) and "imageData" in input_data: - # Handle image input - already in ImageInputEvent format + return + + # Handle dict - reconstruct TypedEvent for WebSocket integration + if isinstance(input_data, dict) and "type" in input_data: + from ..types.bidirectional_streaming import TextInputEvent + event_type = input_data["type"] + if event_type == "bidirectional_text_input": + input_data = TextInputEvent(text=input_data["text"], role=input_data["role"]) + elif event_type == "bidirectional_audio_input": + input_data = AudioInputEvent( + audio=input_data["audio"], + format=input_data["format"], + sample_rate=input_data["sample_rate"], + channels=input_data["channels"] + ) + elif event_type == 
"bidirectional_image_input": + input_data = ImageInputEvent( + image=input_data["image"], + mime_type=input_data["mime_type"] + ) + else: + raise ValueError(f"Unknown event type: {event_type}") + + # Handle TypedEvent instances + if isinstance(input_data, (AudioInputEvent, ImageInputEvent, TextInputEvent)): await self._session.model.send(input_data) else: raise ValueError( - "Input must be either a string (text), AudioInputEvent " - "(dict with audioData, format, sampleRate, channels), or ImageInputEvent " - "(dict with imageData, mimeType, encoding)" + f"Input must be a string, TypedEvent, or event dict, got: {type(input_data)}" ) async def receive(self) -> AsyncIterable[dict[str, Any]]: diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index d1d6e90b32..27732294a5 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -223,7 +223,9 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: try: while True: event = session.agent._output_queue.get_nowait() - if event.get("audioOutput"): + # Check for audio events + event_type = event.get("type", "") + if event_type == "bidirectional_audio_stream": audio_cleared += 1 else: # Keep non-audio events @@ -267,8 +269,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: strands_event = provider_event - # Handle interruption detection (provider converts raw patterns to interruptionDetected) - if strands_event.get("interruptionDetected"): + # Get event type + event_type = strands_event.get("type", "") + + # Handle interruption detection + if event_type == "bidirectional_interruption": logger.debug("Interruption forwarded") await _handle_interruption(session) # Forward interruption event to agent for 
application-level handling @@ -276,26 +281,23 @@ async def _process_model_events(session: BidirectionalConnection) -> None: continue # Queue tool requests for concurrent execution - if strands_event.get("toolUse"): - tool_name = strands_event["toolUse"].get("name") - logger.debug("Tool usage detected: %s", tool_name) - await session.tool_queue.put(strands_event["toolUse"]) + if event_type == "tool_use": + tool_use = strands_event.get("tool_use") + if tool_use: + tool_name = tool_use.get("name") + logger.debug("Tool usage detected: %s", tool_name) + await session.tool_queue.put(tool_use) continue - # Send output events to Agent for receive() method - if strands_event.get("audioOutput") or strands_event.get("textOutput"): - await session.agent._output_queue.put(strands_event) + # Send all output events to Agent for receive() method + await session.agent._output_queue.put(strands_event) - # Update Agent conversation history using existing patterns - if strands_event.get("messageStop"): - logger.debug("Message added to history") - session.agent.messages.append(strands_event["messageStop"]["message"]) - - # Handle user audio transcripts - add to message history - if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": - user_transcript = strands_event["textOutput"]["text"] - if user_transcript.strip(): # Only add non-empty transcripts - user_message = {"role": "user", "content": user_transcript} + # Update Agent conversation history for user transcripts + if event_type == "bidirectional_transcript_stream": + source = strands_event.get("source") + text = strands_event.get("text", "") + if source == "user" and text.strip(): + user_message = {"role": "user", "content": text} session.agent.messages.append(user_message) logger.debug("User transcript added to history") @@ -434,8 +436,8 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_result = tool_event.tool_result tool_use_id = 
tool_result.get("toolUseId") - # Send result through send() method - await session.model.send(tool_result) + # Send ToolResultEvent through send() method + await session.model.send(tool_event) logger.debug("Tool result sent: %s", tool_use_id) # Handle streaming events if needed later @@ -464,14 +466,14 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: except Exception as e: logger.error("Tool execution error: %s - %s", tool_name, str(e)) - # Send error result + # Send error result wrapped in ToolResultEvent error_result: ToolResult = { "toolUseId": tool_id, "status": "error", "content": [{"text": f"Error: {str(e)}"}] } try: - await session.model.send(error_result) + await session.model.send(ToolResultEvent(error_result)) logger.debug("Error result sent: %s", tool_id) except Exception: logger.error("Failed to send error result: %s", tool_id) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index f7bfebac8e..02044125f7 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -248,8 +248,10 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic # Handle audio output using SDK's built-in data property if message.data: + # Convert bytes to base64 string for JSON serializability + audio_b64 = base64.b64encode(message.data).decode('utf-8') return AudioStreamEvent( - audio=message.data, + audio=audio_b64, format="pcm", sample_rate=GEMINI_OUTPUT_SAMPLE_RATE, channels=GEMINI_CHANNELS @@ -311,9 +313,12 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: This automatically triggers VAD and can interrupt ongoing responses. 
""" try: + # Decode base64 audio to bytes for SDK + audio_bytes = base64.b64decode(audio_input.audio) + # Create audio blob for the SDK audio_blob = genai_types.Blob( - data=audio_input.audio, + data=audio_bytes, mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}" ) @@ -330,21 +335,10 @@ async def _send_image_content(self, image_input: ImageInputEvent) -> None: Images are sent as base64-encoded data with MIME type. """ try: - # Prepare the message based on encoding - if image_input.encoding == "base64": - # Data is already base64 encoded - if isinstance(image_input.image, bytes): - data_str = image_input.image.decode() - else: - data_str = image_input.image - else: - # Raw bytes - need to base64 encode - data_str = base64.b64encode(image_input.image).decode() - - # Create the message in the format expected by Gemini Live + # Image is already base64 encoded in the event msg = { "mime_type": image_input.mime_type, - "data": data_str + "data": image_input.image } # Send using the same method as the GitHub example diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 63afc33787..fa7465454d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -40,6 +40,7 @@ ImageInputEvent, InterruptionEvent, MultimodalUsage, + OutputEvent, SessionEndEvent, SessionStartEvent, TextInputEvent, @@ -271,12 +272,12 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: logger.debug("Nova events - starting event stream") - # Emit connection start event to Strands event system - connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.prompt_name, - "metadata": {"provider": "nova_sonic", "model_id": self.model_id}, - } - yield {"BidirectionalConnectionStart": connection_start} + # Emit session start event + yield SessionStartEvent( + 
session_id=self.prompt_name, + model=self.model_id, + capabilities=["audio", "tools"] + ) try: while self._active: @@ -296,14 +297,10 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: except Exception as e: logger.error("Error receiving Nova Sonic event: %s", e) logger.error(traceback.format_exc()) + yield ErrorEvent(error=e) finally: - # Emit connection end event when exiting - connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.prompt_name, - "reason": "connection_complete", - "metadata": {"provider": "nova_sonic"}, - } - yield {"BidirectionalConnectionEnd": connection_end} + # Emit session end event + yield SessionEndEvent(reason="complete") async def send( self, @@ -372,9 +369,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: if self.silence_task and not self.silence_task.done(): self.silence_task.cancel() - # Convert audio to Nova Sonic base64 format - nova_audio_data = base64.b64encode(audio_input.audio).decode("utf-8") - + # Audio is already base64 encoded in the event # Send audio input event audio_event = json.dumps( { @@ -382,7 +377,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: "audioInput": { "promptName": self.prompt_name, "contentName": self.audio_content_name, - "content": nova_audio_data, + "content": audio_input.audio, } } } @@ -513,82 +508,79 @@ async def close(self) -> None: finally: logger.debug("Nova connection closed") - def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | None: - """Convert Nova Sonic events to provider-agnostic format.""" + def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: + """Convert Nova Sonic events to TypedEvent format.""" # Handle audio output if "audioOutput" in nova_event: + # Audio is already base64 string from Nova Sonic audio_content = nova_event["audioOutput"]["content"] - audio_bytes = base64.b64decode(audio_content) - - audio_output: AudioOutputEvent = { - 
"audioData": audio_bytes, - "format": "pcm", - "sampleRate": 24000, - "channels": 1, - "encoding": "base64", - } - - return {"audioOutput": audio_output} + return AudioStreamEvent( + audio=audio_content, + format="pcm", + sample_rate=24000, + channels=1 + ) - # Handle text output + # Handle text output (transcripts) elif "textOutput" in nova_event: text_content = nova_event["textOutput"]["content"] # Use stored role from contentStart event, fallback to event role role = getattr(self, "_current_role", nova_event["textOutput"].get("role", "assistant")) - # Check for Nova Sonic interruption pattern (matches working sample) + # Check for Nova Sonic interruption pattern if '{ "interrupted" : true }' in text_content: logger.debug("Nova interruption detected in text") - interruption: InterruptionDetectedEvent = {"reason": "user_input"} - return {"interruptionDetected": interruption} - - # Show transcription for user speech - ALWAYS show these regardless of DEBUG flag - if role == "USER": - print(f"User: {text_content}") - elif role == "ASSISTANT": - print(f"Assistant: {text_content}") + return InterruptionEvent(reason="user_speech", turn_id=None) - text_output: TextOutputEvent = {"text": text_content, "role": role.lower()} - - return {"textOutput": text_output} + return TranscriptStreamEvent( + text=text_content, + source="user" if role == "USER" else "assistant", + is_final=True + ) # Handle tool use elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] - tool_use_event: ToolUse = { "toolUseId": tool_use["toolUseId"], "name": tool_use["toolName"], "input": json.loads(tool_use["content"]), } - - return {"toolUse": tool_use_event} + # Return dict with tool_use for event loop processing + return {"type": "tool_use", "tool_use": tool_use_event} # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": logger.debug("Nova interruption stop reason") + return InterruptionEvent(reason="user_speech", turn_id=None) - interruption: InterruptionDetectedEvent 
= {"reason": "user_input"} - - return {"interruptionDetected": interruption} - - # Handle usage events - convert to standardized format + # Handle usage events - convert to multimodal usage format elif "usageEvent" in nova_event: usage_data = nova_event["usageEvent"] - usage_metrics: UsageMetricsEvent = { - "totalTokens": usage_data.get("totalTokens", 0), - "inputTokens": usage_data.get("totalInputTokens", 0), - "outputTokens": usage_data.get("totalOutputTokens", 0), - "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens", 0) - } - return {"usageMetrics": usage_metrics} + total_input = usage_data.get("totalInputTokens", 0) + total_output = usage_data.get("totalOutputTokens", 0) + + return MultimodalUsage( + input_tokens=total_input, + output_tokens=total_output, + total_tokens=usage_data.get("totalTokens", total_input + total_output) + ) # Handle content start events (track role) elif "contentStart" in nova_event: role = nova_event["contentStart"].get("role", "unknown") # Store role for subsequent text output events self._current_role = role - return None + # Emit turn start event + return TurnStartEvent(turn_id=str(uuid.uuid4())) + + # Handle content stop events + elif "contentStop" in nova_event: + stop_reason = nova_event["contentStop"].get("stopReason", "complete") + return TurnCompleteEvent( + turn_id=str(uuid.uuid4()), + stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete" + ) # Handle other events else: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index be954deadc..15d1bbf86c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -173,15 +173,21 @@ def _require_active(self) -> bool: """Check if session is active.""" return self._active - def _create_text_event(self, text: str, role: str) -> 
dict[str, any]: - """Create standardized text output event.""" - text_output: TextOutputEvent = {"text": text, "role": role} - return {"textOutput": text_output} + def _create_text_event(self, text: str, role: str) -> TranscriptStreamEvent: + """Create standardized transcript event.""" + return TranscriptStreamEvent( + text=text, + source="user" if role == "user" else "assistant", + is_final=True + ) - def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: - """Create standardized voice activity event.""" - voice_activity: VoiceActivityEvent = {"activityType": activity_type} - return {"voiceActivity": voice_activity} + def _create_voice_activity_event(self, activity_type: str) -> InterruptionEvent | None: + """Create standardized interruption event for voice activity.""" + # Only speech_started triggers interruption + if activity_type == "speech_started": + return InterruptionEvent(reason="user_speech", turn_id=None) + # Other voice activity events are logged but don't create events + return None def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: """Build session configuration for OpenAI Realtime API.""" @@ -273,12 +279,13 @@ async def _process_responses(self) -> None: logger.debug("OpenAI Realtime response processor stopped") async def receive(self) -> AsyncIterable[OutputEvent]: - """Receive OpenAI events and convert to Strands format.""" - connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.session_id, - "metadata": {"provider": "openai_realtime", "model": self.model}, - } - yield {"BidirectionalConnectionStart": connection_start} + """Receive OpenAI events and convert to Strands TypedEvent format.""" + # Emit session start event + yield SessionStartEvent( + session_id=self.session_id, + model=self.model, + capabilities=["audio", "tools"] + ) try: while self._active: @@ -292,29 +299,24 @@ async def receive(self) -> AsyncIterable[OutputEvent]: except Exception as e: 
logger.error("Error receiving OpenAI Realtime event: %s", e) + yield ErrorEvent(error=e) finally: - connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.session_id, - "reason": "connection_complete", - "metadata": {"provider": "openai_realtime"}, - } - yield {"BidirectionalConnectionEnd": connection_end} + # Emit session end event + yield SessionEndEvent(reason="complete") - def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] | None: - """Convert OpenAI events to Strands format.""" + def _convert_openai_event(self, openai_event: dict[str, any]) -> OutputEvent | None: + """Convert OpenAI events to Strands TypedEvent format.""" event_type = openai_event.get("type") # Audio output if event_type == "response.output_audio.delta": - audio_data = base64.b64decode(openai_event["delta"]) - audio_output: AudioOutputEvent = { - "audioData": audio_data, - "format": "pcm", - "sampleRate": 24000, - "channels": 1, - "encoding": None, - } - return {"audioOutput": audio_output} + # Audio is already base64 string from OpenAI + return AudioStreamEvent( + audio=openai_event["delta"], + format="pcm", + sample_rate=24000, + channels=1 + ) # Assistant text output events - combine multiple similar events elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: @@ -359,7 +361,8 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, } del self._function_call_buffer[call_id] - return {"toolUse": tool_use} + # Return dict with tool_use for event loop processing + return {"type": "tool_use", "tool_use": tool_use} except (json.JSONDecodeError, KeyError) as e: logger.warning("Error parsing function arguments for %s: %s", call_id, e) del self._function_call_buffer[call_id] @@ -385,23 +388,14 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] elif event_type == 
"conversation.item.done": logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) - + # This event signals turn completion - emit TurnCompleteEvent item = openai_event.get("item", {}) if item.get("type") == "message" and item.get("role") == "assistant": - content_parts = item.get("content", []) - if content_parts: - message_content = [] - for content_part in content_parts: - if content_part.get("type") == "output_text": - message_content.append({"type": "text", "text": content_part.get("text", "")}) - elif content_part.get("type") == "output_audio": - transcript = content_part.get("transcript", "") - if transcript: - message_content.append({"type": "text", "text": transcript}) - - if message_content: - message = {"role": "assistant", "content": message_content} - return {"messageStop": {"message": message}} + item_id = item.get("id", "unknown") + return TurnCompleteEvent( + turn_id=item_id, + stop_reason="complete" + ) return None # Response output events - combine similar events @@ -452,6 +446,7 @@ async def send( return try: + # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first if isinstance(content, TextInputEvent): await self._send_text_content(content.text) elif isinstance(content, AudioInputEvent): @@ -464,14 +459,14 @@ async def send( if tool_result: await self._send_tool_result(tool_result) else: - logger.warning(f"Unknown content type: {type(content)}") + logger.warning(f"Unknown content type: {type(content).__name__}") except Exception as e: logger.error(f"Error sending content: {e}") async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: """Internal: Send audio content to OpenAI for processing.""" - audio_base64 = base64.b64encode(audio_input.audio).decode("utf-8") - await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) + # Audio is already base64 encoded in the event + await self._send_event({"type": "input_audio_buffer.append", 
"audio": audio_input.audio}) async def _send_text_content(self, text: str) -> None: """Internal: Send text content to OpenAI for processing.""" diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py index b0a41f20d4..b538fc0238 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py @@ -5,6 +5,7 @@ """ import asyncio +import base64 import sys from pathlib import Path @@ -129,33 +130,36 @@ async def receive(agent, context): """Receive and process events from agent.""" try: async for event in agent.receive(): - # Handle audio output - if "audioOutput" in event: + # Get event type + event_type = event.get("type", "unknown") + + # Handle audio stream events (bidirectional_audio_stream) + if event_type == "bidirectional_audio_stream": if not context.get("interrupted", False): - context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) + # Decode base64 audio string to bytes for playback + audio_b64 = event["audio"] + audio_data = base64.b64decode(audio_b64) + context["audio_out"].put_nowait(audio_data) - # Handle interruption events - elif "interruptionDetected" in event: + # Handle interruption events (bidirectional_interruption) + elif event_type == "bidirectional_interruption": context["interrupted"] = True - elif "interrupted" in event: - context["interrupted"] = True - - # Handle text output with interruption detection - elif "textOutput" in event: - text_content = event["textOutput"].get("content", "") - role = event["textOutput"].get("role", "unknown") - - # Check for text-based interruption patterns - if '{ "interrupted" : true }' in text_content: - context["interrupted"] = True - elif "interrupted" in text_content.lower(): - context["interrupted"] = True - # Log text output - if role.upper() == "USER": + # Handle 
transcript events (bidirectional_transcript_stream) + elif event_type == "bidirectional_transcript_stream": + text_content = event.get("text", "") + source = event.get("source", "unknown") + + # Log transcript output + if source == "user": print(f"User: {text_content}") - elif role.upper() == "ASSISTANT": + elif source == "assistant": print(f"Assistant: {text_content}") + + # Handle turn complete events (bidirectional_turn_complete) + elif event_type == "bidirectional_turn_complete": + # Reset interrupted state since the turn is complete + context["interrupted"] = False except asyncio.CancelledError: pass @@ -167,7 +171,16 @@ async def send(agent, context): while time.time() - context["start_time"] < context["duration"]: try: audio_bytes = context["audio_in"].get_nowait() - audio_event = {"audioData": audio_bytes, "format": "pcm", "sampleRate": 16000, "channels": 1} + # Create audio event using TypedEvent + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioInputEvent + + audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') + audio_event = AudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=16000, + channels=1 + ) await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) # Restored to working timing diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py index 90e82c2bc9..d270637be5 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -2,6 +2,7 @@ """Test OpenAI Realtime API speech-to-speech interaction.""" import asyncio +import base64 import os import sys import time @@ -118,35 +119,48 @@ async def receive(agent, context): if not context["active"]: break - # Handle audio output - if "audioOutput" in event: - audio_data = event["audioOutput"]["audioData"] + 
# Get event type + event_type = event.get("type", "unknown") + + # Handle audio stream events (bidirectional_audio_stream) + if event_type == "bidirectional_audio_stream": + # Decode base64 audio string to bytes for playback + audio_b64 = event["audio"] + audio_data = base64.b64decode(audio_b64) if not context.get("interrupted", False): await context["audio_out"].put(audio_data) - # Handle text output (transcripts) - elif "textOutput" in event: - text_output = event["textOutput"] - role = text_output.get("role", "assistant") - text = text_output.get("text", "").strip() + # Handle transcript events (bidirectional_transcript_stream) + elif event_type == "bidirectional_transcript_stream": + source = event.get("source", "assistant") + text = event.get("text", "").strip() if text: - if role == "user": - print(f"User: {text}") - elif role == "assistant": - print(f"Assistant: {text}") + if source == "user": + print(f"🎤 User: {text}") + elif source == "assistant": + print(f"🔊 Assistant: {text}") - # Handle interruption detection - elif "interruptionDetected" in event: + # Handle interruption events (bidirectional_interruption) + elif event_type == "bidirectional_interruption": context["interrupted"] = True + print("⚠️ Interruption detected") + + # Handle session start events (bidirectional_session_start) + elif event_type == "bidirectional_session_start": + print(f"✓ Session started: {event.get('model', 'unknown')}") - # Handle connection events - elif "BidirectionalConnectionStart" in event: - pass # Silent connection start - elif "BidirectionalConnectionEnd" in event: + # Handle session end events (bidirectional_session_end) + elif event_type == "bidirectional_session_end": + print(f"✓ Session ended: {event.get('reason', 'unknown')}") context["active"] = False break + + # Handle turn complete events (bidirectional_turn_complete) + elif event_type == "bidirectional_turn_complete": + # Reset interrupted state since the turn is complete + context["interrupted"] = False 
except asyncio.CancelledError: pass @@ -163,13 +177,17 @@ async def send(agent, context): try: audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) - # Create audio event in expected format - audio_event = { - "audioData": audio_bytes, - "format": "pcm", - "sampleRate": 24000, - "channels": 1 - } + # Create audio event using TypedEvent + # Encode audio bytes to base64 string for JSON serializability + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioInputEvent + + audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') + audio_event = AudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=24000, + channels=1 + ) await agent.send(audio_event) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py index 23e97bd5d2..0bd283eb99 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -145,56 +145,58 @@ async def receive(agent, context): """Receive and process events from agent.""" try: async for event in agent.receive(): - # Debug: Log all event types - event_types = [k for k in event.keys() if not k.startswith('_')] - if event_types: - logger.debug(f"Received event types: {event_types}") + # Debug: Log event type and keys + event_type = event.get("type", "unknown") + event_keys = list(event.keys()) + logger.debug(f"Received event type: {event_type}, keys: {event_keys}") - # Handle audio output - if "audioOutput" in event: + # Handle audio stream events (bidirectional_audio_stream) + if event_type == "bidirectional_audio_stream": if not context.get("interrupted", False): - context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) - - # Handle interruption events - elif "interruptionDetected" in event: - context["interrupted"] = True - elif "interrupted" in 
event: + # Decode base64 audio string to bytes for playback + audio_b64 = event["audio"] + audio_data = base64.b64decode(audio_b64) + context["audio_out"].put_nowait(audio_data) + logger.info(f"🔊 Audio queued for playback: {len(audio_data)} bytes") + + # Handle interruption events (bidirectional_interruption) + elif event_type == "bidirectional_interruption": context["interrupted"] = True + logger.info("Interruption detected") - # Handle text output - elif "textOutput" in event: - text_content = event["textOutput"].get("text", "") - role = event["textOutput"].get("role", "unknown") - - # Check for text-based interruption patterns - if '{ "interrupted" : true }' in text_content: - context["interrupted"] = True - elif "interrupted" in text_content.lower(): - context["interrupted"] = True - - # Log text output - if role.upper() == "USER": - print(f"User: {text_content}") - elif role.upper() == "ASSISTANT": - print(f"Assistant: {text_content}") - - # Handle transcript events (audio transcriptions) - elif "transcript" in event: - transcript_text = event["transcript"].get("text", "") - transcript_role = event["transcript"].get("role", "unknown") - transcript_type = event["transcript"].get("type", "unknown") + # Handle transcript events (bidirectional_transcript_stream) + elif event_type == "bidirectional_transcript_stream": + transcript_text = event.get("text", "") + transcript_source = event.get("source", "unknown") + is_final = event.get("is_final", False) - # Print transcripts with special formatting to distinguish from text output - if transcript_role.upper() == "USER": - print(f"🎤 User (transcript): {transcript_text}") - elif transcript_role.upper() == "ASSISTANT": - print(f"🔊 Assistant (transcript): {transcript_text}") + # Print transcripts with special formatting + if transcript_source == "user": + print(f"🎤 User: {transcript_text}") + elif transcript_source == "assistant": + print(f"🔊 Assistant: {transcript_text}") - # Handle turn complete events - elif 
"turnComplete" in event: - logger.debug("Turn complete event received - model ready for next input") + # Handle turn complete events (bidirectional_turn_complete) + elif event_type == "bidirectional_turn_complete": + logger.debug("Turn complete - model ready for next input") # Reset interrupted state since the turn is complete context["interrupted"] = False + + # Handle session start events (bidirectional_session_start) + elif event_type == "bidirectional_session_start": + logger.info(f"Session started: {event.get('model', 'unknown')}") + + # Handle session end events (bidirectional_session_end) + elif event_type == "bidirectional_session_end": + logger.info(f"Session ended: {event.get('reason', 'unknown')}") + + # Handle error events (bidirectional_error) + elif event_type == "bidirectional_error": + logger.error(f"Error: {event.get('error_message', 'unknown')}") + + # Handle turn start events (bidirectional_turn_start) + elif event_type == "bidirectional_turn_start": + logger.debug(f"Turn started: {event.get('turn_id', 'unknown')}") except asyncio.CancelledError: pass @@ -246,11 +248,12 @@ async def get_frames(context): # Send frame to agent as image input try: - image_event = { - "imageData": frame["data"], - "mimeType": frame["mime_type"], - "encoding": "base64" - } + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ImageInputEvent + + image_event = ImageInputEvent( + image=frame["data"], # Already base64 encoded + mime_type=frame["mime_type"] + ) await context["agent"].send(image_event) print("📸 Frame sent to model") except Exception as e: @@ -272,7 +275,16 @@ async def send(agent, context): while time.time() - context["start_time"] < context["duration"]: try: audio_bytes = context["audio_in"].get_nowait() - audio_event = {"audioData": audio_bytes, "format": "pcm", "sampleRate": 16000, "channels": 1} + # Create audio event using TypedEvent + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming 
import AudioInputEvent + + audio_b64 = base64.b64encode(audio_bytes).decode('utf-8') + audio_event = AudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=16000, + channels=1 + ) await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) @@ -304,7 +316,7 @@ async def main(duration=180): model = GeminiLiveModel( model_id="gemini-2.5-flash-native-audio-preview-09-2025", api_key=api_key, - params={ + live_config={ "response_modalities": ["AUDIO"], "output_audio_transcription": {}, # Enable output transcription "input_audio_transcription": {} # Enable input transcription diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index e7af3ad433..160b15a27a 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -9,12 +9,14 @@ - Session lifecycle management - Provider-agnostic event types - Type-safe discriminated unions with TypedEvent +- JSON-serializable events (audio/images stored as base64 strings) Audio format normalization: - Supports PCM, WAV, Opus, and MP3 formats - Standardizes sample rates (16kHz, 24kHz, 48kHz) - Normalizes channel configurations (mono/stereo) - Abstracts provider-specific encodings +- Audio data stored as base64-encoded strings for JSON compatibility """ from typing import Any, Dict, List, Literal, Optional, Union, cast @@ -69,7 +71,7 @@ class AudioInputEvent(TypedEvent): Used for sending audio data through the send() method. Parameters: - audio: Raw audio bytes to send to model (not base64 encoded). + audio: Base64-encoded audio string to send to model. format: Audio format from SUPPORTED_AUDIO_FORMATS. sample_rate: Sample rate from SUPPORTED_SAMPLE_RATES. channels: Channel count from SUPPORTED_CHANNELS. 
@@ -77,7 +79,7 @@ class AudioInputEvent(TypedEvent): def __init__( self, - audio: bytes, + audio: str, format: Literal["pcm", "wav", "opus", "mp3"], sample_rate: Literal[16000, 24000, 48000], channels: Literal[1, 2], @@ -93,8 +95,8 @@ def __init__( ) @property - def audio(self) -> bytes: - return cast(bytes, self.get("audio")) + def audio(self) -> str: + return cast(str, self.get("audio")) @property def format(self) -> str: @@ -112,42 +114,34 @@ def channels(self) -> int: class ImageInputEvent(TypedEvent): """Image input event for sending images/video frames to the model. - Used for sending image data through the send() method. Supports both - raw image bytes and base64-encoded data. + Used for sending image data through the send() method. Parameters: - image: Image bytes (raw or base64-encoded string). + image: Base64-encoded image string. mime_type: MIME type (e.g., "image/jpeg", "image/png"). - encoding: How the image data is encoded. """ def __init__( self, - image: Union[bytes, str], + image: str, mime_type: str, - encoding: Literal["base64", "raw"], ): super().__init__( { "type": "bidirectional_image_input", "image": image, "mime_type": mime_type, - "encoding": encoding, } ) @property - def image(self) -> Union[bytes, str]: - return cast(Union[bytes, str], self.get("image")) + def image(self) -> str: + return cast(str, self.get("image")) @property def mime_type(self) -> str: return cast(str, self.get("mime_type")) - @property - def encoding(self) -> str: - return cast(str, self.get("encoding")) - # ============================================================================ # Output Events (received via session.receive_events()) @@ -205,7 +199,7 @@ class AudioStreamEvent(TypedEvent): """Streaming audio output from the model. Parameters: - audio: Raw audio data as bytes (not base64 encoded). + audio: Base64-encoded audio string. format: Audio encoding format. sample_rate: Number of audio samples per second in Hz. 
channels: Number of audio channels (1=mono, 2=stereo). @@ -213,7 +207,7 @@ class AudioStreamEvent(TypedEvent): def __init__( self, - audio: bytes, + audio: str, format: Literal["pcm", "wav", "opus", "mp3"], sample_rate: Literal[16000, 24000, 48000], channels: Literal[1, 2], @@ -229,8 +223,8 @@ def __init__( ) @property - def audio(self) -> bytes: - return cast(bytes, self.get("audio")) + def audio(self) -> str: + return cast(str, self.get("audio")) @property def format(self) -> str: @@ -436,8 +430,11 @@ class ErrorEvent(TypedEvent): Similar to strands.types._events.ForceStopEvent, this event wraps exceptions that occur during bidirectional streaming sessions. + Note: The Exception object is not stored in the event data to maintain JSON + serializability. Only the error message, code, and details are stored. + Parameters: - error: The exception that occurred. + error: The exception that occurred (used to extract message and type). code: Optional error code for programmatic handling (defaults to exception class name). details: Optional additional error information. 
""" @@ -450,18 +447,13 @@ def __init__( ): super().__init__( { - "bidirectional_error": True, - "error": error, + "type": "bidirectional_error", "error_message": str(error), "error_code": code or type(error).__name__, "error_details": details, } ) - @property - def error(self) -> Exception: - return cast(Exception, self.get("error")) - @property def code(self) -> str: return cast(str, self.get("error_code")) diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 86b75fd21d..f48f04910d 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -197,9 +197,11 @@ async def test_send_all_content_types(mock_genai_client, model): assert content.role == "user" assert content.parts[0].text == "Hello" - # Test audio input + # Test audio input (base64 encoded) + import base64 + audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') audio_input = AudioInputEvent( - audio=b"audio_bytes", + audio=audio_b64, format="pcm", sample_rate=16000, channels=1, @@ -207,11 +209,11 @@ async def test_send_all_content_types(mock_genai_client, model): await model.send(audio_input) mock_live_session.send_realtime_input.assert_called_once() - # Test image input + # Test image input (base64 encoded, no encoding parameter) + image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') image_input = ImageInputEvent( - image=b"image_bytes", + image=image_b64, mime_type="image/jpeg", - encoding="raw", ) await model.send(image_input) mock_live_session.send.assert_called_once() @@ -303,7 +305,8 @@ async def test_event_conversion(mock_genai_client, model): assert text_event.source == "assistant" assert text_event.is_final is True - # Test audio output + # Test audio output (now returns base64 encoded string) + import base64 mock_audio = unittest.mock.Mock() 
mock_audio.text = None mock_audio.data = b"audio_data" @@ -312,7 +315,9 @@ async def test_event_conversion(mock_genai_client, model): audio_event = model._convert_gemini_live_event(mock_audio) assert isinstance(audio_event, AudioStreamEvent) - assert audio_event.audio == b"audio_data" + # Audio is now base64 encoded + expected_b64 = base64.b64encode(b"audio_data").decode('utf-8') + assert audio_event.audio == expected_b64 assert audio_event.format == "pcm" # Test tool call diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 17f6b8e57d..83c5e2f829 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -148,9 +148,10 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): # Should send contentStart, textInput, and contentEnd assert mock_stream.input_stream.send.call_count >= 3 - # Test audio content + # Test audio content (base64 encoded) + audio_b64 = base64.b64encode(b"audio data").decode('utf-8') audio_event = AudioInputEvent( - audio=b"audio data", + audio=audio_b64, format="pcm", sample_rate=16000, channels=1 @@ -188,12 +189,13 @@ async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): text_event = TextInputEvent(text="Hello", role="user") await nova_model.send(text_event) # Should not raise - # Test image content (not supported) + # Test image content (not supported, base64 encoded, no encoding parameter) await nova_model.connect() + import base64 + image_b64 = base64.b64encode(b"image data").decode('utf-8') image_event = ImageInputEvent( - image=b"image data", + image=image_b64, mime_type="image/jpeg", - encoding="raw" ) await nova_model.send(image_event) # Should log warning about unsupported image input @@ -224,36 +226,41 @@ async def mock_wait_for(*args, **kwargs): async for 
event in nova_model.receive(): events.append(event) - # Should have connection start and end + # Should have session start and end (new TypedEvent format) assert len(events) >= 2 - assert "BidirectionalConnectionStart" in events[0] - assert events[0]["BidirectionalConnectionStart"]["connectionId"] == nova_model.prompt_name - assert "BidirectionalConnectionEnd" in events[-1] + assert events[0].get("type") == "bidirectional_session_start" + assert events[0].get("session_id") == nova_model.prompt_name + assert events[-1].get("type") == "bidirectional_session_end" @pytest.mark.asyncio async def test_event_conversion(nova_model): """Test conversion of all Nova Sonic event types to standard format.""" - # Test audio output + # Test audio output (now returns AudioStreamEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioStreamEvent audio_bytes = b"test audio data" audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") nova_event = {"audioOutput": {"content": audio_base64}} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert "audioOutput" in result - assert result["audioOutput"]["audioData"] == audio_bytes - assert result["audioOutput"]["format"] == "pcm" - assert result["audioOutput"]["sampleRate"] == 24000 - - # Test text output + assert isinstance(result, AudioStreamEvent) + assert result.get("type") == "bidirectional_audio_stream" + # Audio is kept as base64 string + assert result.get("audio") == audio_base64 + assert result.get("format") == "pcm" + assert result.get("sample_rate") == 24000 + + # Test text output (now returns TranscriptStreamEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert "textOutput" in result - assert result["textOutput"]["text"] 
== "Hello, world!" - assert result["textOutput"]["role"] == "assistant" + assert isinstance(result, TranscriptStreamEvent) + assert result.get("type") == "bidirectional_transcript_stream" + assert result.get("text") == "Hello, world!" + assert result.get("source") == "assistant" - # Test tool use + # Test tool use (now returns dict with tool_use) tool_input = {"location": "Seattle"} nova_event = { "toolUse": { @@ -264,19 +271,23 @@ async def test_event_conversion(nova_model): } result = nova_model._convert_nova_event(nova_event) assert result is not None - assert "toolUse" in result - assert result["toolUse"]["toolUseId"] == "tool-123" - assert result["toolUse"]["name"] == "get_weather" - assert result["toolUse"]["input"] == tool_input - - # Test interruption + assert result.get("type") == "tool_use" + tool_use = result.get("tool_use") + assert tool_use["toolUseId"] == "tool-123" + assert tool_use["name"] == "get_weather" + assert tool_use["input"] == tool_input + + # Test interruption (now returns InterruptionEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent nova_event = {"stopReason": "INTERRUPTED"} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert "interruptionDetected" in result - assert result["interruptionDetected"]["reason"] == "user_input" + assert isinstance(result, InterruptionEvent) + assert result.get("type") == "bidirectional_interruption" + assert result.get("reason") == "user_speech" - # Test usage metrics + # Test usage metrics (now returns MultimodalUsage) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import MultimodalUsage nova_event = { "usageEvent": { "totalTokens": 100, @@ -293,16 +304,19 @@ async def test_event_conversion(nova_model): } result = nova_model._convert_nova_event(nova_event) assert result is not None - assert "usageMetrics" in result - assert result["usageMetrics"]["totalTokens"] == 100 - assert 
result["usageMetrics"]["inputTokens"] == 40 - assert result["usageMetrics"]["outputTokens"] == 60 - assert result["usageMetrics"]["audioTokens"] == 30 - - # Test content start tracks role + assert isinstance(result, MultimodalUsage) + assert result.get("type") == "multimodal_usage" + assert result.get("totalTokens") == 100 + assert result.get("inputTokens") == 40 + assert result.get("outputTokens") == 60 + + # Test content start tracks role and emits TurnStartEvent + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TurnStartEvent nova_event = {"contentStart": {"role": "USER"}} result = nova_model._convert_nova_event(nova_event) - assert result is None # contentStart doesn't emit an event + assert result is not None + assert isinstance(result, TurnStartEvent) + assert result.get("type") == "bidirectional_turn_start" assert nova_model._current_role == "USER" @@ -339,9 +353,11 @@ async def test_silence_detection(nova_model, mock_client, mock_stream): await nova_model.connect() - # Send audio to start connection + # Send audio to start connection (base64 encoded) + import base64 + audio_b64 = base64.b64encode(b"audio data").decode('utf-8') audio_event = AudioInputEvent( - audio=b"audio data", + audio=audio_b64, format="pcm", sample_rate=16000, channels=1 diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 7f799816a1..8640f58338 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -237,9 +237,10 @@ async def test_send_all_content_types(mock_websockets_connect, model): assert len(item_create) > 0 assert len(response_create) > 0 - # Test audio input + # Test audio input (base64 encoded) + audio_b64 = base64.b64encode(b"audio_bytes").decode('utf-8') audio_input = 
AudioInputEvent( - audio=b"audio_bytes", + audio=audio_b64, format="pcm", sample_rate=24000, channels=1, @@ -250,8 +251,8 @@ async def test_send_all_content_types(mock_websockets_connect, model): audio_append = [m for m in messages if m.get("type") == "input_audio_buffer.append"] assert len(audio_append) > 0 assert "audio" in audio_append[0] - decoded = base64.b64decode(audio_append[0]["audio"]) - assert decoded == b"audio_bytes" + # Audio should be passed through as base64 + assert audio_append[0]["audio"] == audio_b64 # Test tool result tool_result: ToolResult = { @@ -281,12 +282,12 @@ async def test_send_edge_cases(mock_websockets_connect, model): await model.send(text_input) mock_ws.send.assert_not_called() - # Test image input (not supported) + # Test image input (not supported, base64 encoded, no encoding parameter) await model.connect() + image_b64 = base64.b64encode(b"image_bytes").decode('utf-8') image_input = ImageInputEvent( - image=b"image_bytes", + image=image_b64, mime_type="image/jpeg", - encoding="raw", ) with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: await model.send(image_input) @@ -315,11 +316,12 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): receive_gen = model.receive() first_event = await anext(receive_gen) - # First event should be connection start - assert "BidirectionalConnectionStart" in first_event - assert first_event["BidirectionalConnectionStart"]["connectionId"] == model.session_id + # First event should be session start (new TypedEvent format) + assert first_event.get("type") == "bidirectional_session_start" + assert first_event.get("session_id") == model.session_id + assert first_event.get("model") == model.model - # Close to trigger connection end + # Close to trigger session end await model.close() # Collect remaining events @@ -330,8 +332,8 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): except 
StopAsyncIteration: pass - # Last event should be connection end - assert "BidirectionalConnectionEnd" in events[-1] + # Last event should be session end (new TypedEvent format) + assert events[-1].get("type") == "bidirectional_session_end" @pytest.mark.asyncio @@ -340,25 +342,29 @@ async def test_event_conversion(mock_websockets_connect, model): _, _ = mock_websockets_connect await model.connect() - # Test audio output + # Test audio output (now returns AudioStreamEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioStreamEvent audio_event = { "type": "response.output_audio.delta", "delta": base64.b64encode(b"audio_data").decode() } converted = model._convert_openai_event(audio_event) - assert "audioOutput" in converted - assert converted["audioOutput"]["audioData"] == b"audio_data" - assert converted["audioOutput"]["format"] == "pcm" + assert isinstance(converted, AudioStreamEvent) + assert converted.get("type") == "bidirectional_audio_stream" + assert converted.get("audio") == base64.b64encode(b"audio_data").decode() + assert converted.get("format") == "pcm" - # Test text output + # Test text output (now returns TranscriptStreamEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent text_event = { "type": "response.output_text.delta", "delta": "Hello from OpenAI" } converted = model._convert_openai_event(text_event) - assert "textOutput" in converted - assert converted["textOutput"]["text"] == "Hello from OpenAI" - assert converted["textOutput"]["role"] == "assistant" + assert isinstance(converted, TranscriptStreamEvent) + assert converted.get("type") == "bidirectional_transcript_stream" + assert converted.get("text") == "Hello from OpenAI" + assert converted.get("source") == "assistant" # Test function call sequence item_added = { @@ -383,18 +389,23 @@ async def test_event_conversion(mock_websockets_connect, model): "call_id": "call-123" } converted = 
model._convert_openai_event(args_done) - assert "toolUse" in converted - assert converted["toolUse"]["toolUseId"] == "call-123" - assert converted["toolUse"]["name"] == "calculator" - assert converted["toolUse"]["input"]["expression"] == "2+2" - - # Test voice activity + # Now returns dict with tool_use + assert isinstance(converted, dict) + assert converted.get("type") == "tool_use" + tool_use = converted.get("tool_use") + assert tool_use["toolUseId"] == "call-123" + assert tool_use["name"] == "calculator" + assert tool_use["input"]["expression"] == "2+2" + + # Test voice activity (now returns InterruptionEvent for speech_started) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent speech_started = { "type": "input_audio_buffer.speech_started" } converted = model._convert_openai_event(speech_started) - assert "voiceActivity" in converted - assert converted["voiceActivity"]["activityType"] == "speech_started" + assert isinstance(converted, InterruptionEvent) + assert converted.get("type") == "bidirectional_interruption" + assert converted.get("reason") == "user_speech" await model.close() @@ -442,16 +453,23 @@ def test_helper_methods(model): assert model._require_active() is True model._active = False - # Test _create_text_event + # Test _create_text_event (now returns TranscriptStreamEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent text_event = model._create_text_event("Hello", "user") - assert "textOutput" in text_event - assert text_event["textOutput"]["text"] == "Hello" - assert text_event["textOutput"]["role"] == "user" + assert isinstance(text_event, TranscriptStreamEvent) + assert text_event.get("type") == "bidirectional_transcript_stream" + assert text_event.get("text") == "Hello" + assert text_event.get("source") == "user" - # Test _create_voice_activity_event + # Test _create_voice_activity_event (now returns InterruptionEvent for 
speech_started) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent voice_event = model._create_voice_activity_event("speech_started") - assert "voiceActivity" in voice_event - assert voice_event["voiceActivity"]["activityType"] == "speech_started" + assert isinstance(voice_event, InterruptionEvent) + assert voice_event.get("type") == "bidirectional_interruption" + assert voice_event.get("reason") == "user_speech" + + # Other voice activities return None + assert model._create_voice_activity_event("speech_stopped") is None @pytest.mark.asyncio From 0799be8f21f437ef0d3c09956eb17bf3c2f25046 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 3 Nov 2025 12:17:21 +0300 Subject: [PATCH 03/16] feat: add usage to openai --- .../bidirectional_streaming/models/openai.py | 114 +++++++++++++++--- .../models/test_openai_realtime.py | 46 ++++--- 2 files changed, 121 insertions(+), 39 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 15d1bbf86c..181c01c27b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -291,9 +291,8 @@ async def receive(self) -> AsyncIterable[OutputEvent]: while self._active: try: openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) - provider_event = self._convert_openai_event(openai_event) - if provider_event: - yield provider_event + for event in self._convert_openai_event(openai_event) or []: + yield event except asyncio.TimeoutError: continue @@ -304,35 +303,41 @@ async def receive(self) -> AsyncIterable[OutputEvent]: # Emit session end event yield SessionEndEvent(reason="complete") - def _convert_openai_event(self, openai_event: dict[str, any]) -> OutputEvent | None: + def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEvent] | 
None: """Convert OpenAI events to Strands TypedEvent format.""" event_type = openai_event.get("type") + # Turn start - response begins + if event_type == "response.created": + response = openai_event.get("response", {}) + response_id = response.get("id", str(uuid.uuid4())) + return [TurnStartEvent(turn_id=response_id)] + # Audio output - if event_type == "response.output_audio.delta": + elif event_type == "response.output_audio.delta": # Audio is already base64 string from OpenAI - return AudioStreamEvent( + return [AudioStreamEvent( audio=openai_event["delta"], format="pcm", sample_rate=24000, channels=1 - ) + )] # Assistant text output events - combine multiple similar events elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: - return self._create_text_event(openai_event["delta"], "assistant") + return [self._create_text_event(openai_event["delta"], "assistant")] # User transcription events - combine multiple similar events elif event_type in ["conversation.item.input_audio_transcription.delta", "conversation.item.input_audio_transcription.completed"]: text_key = "delta" if "delta" in event_type else "transcript" text = openai_event.get(text_key, "") - return self._create_text_event(text, "user") if text.strip() else None + return [self._create_text_event(text, "user")] if text.strip() else None elif event_type == "conversation.item.input_audio_transcription.segment": segment_data = openai_event.get("segment", {}) text = segment_data.get("text", "") - return self._create_text_event(text, "user") if text.strip() else None + return [self._create_text_event(text, "user")] if text.strip() else None elif event_type == "conversation.item.input_audio_transcription.failed": error_info = openai_event.get("error", {}) @@ -362,7 +367,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> OutputEvent | N } del self._function_call_buffer[call_id] # Return dict with tool_use for event loop processing - return {"type": 
"tool_use", "tool_use": tool_use} + return [{"type": "tool_use", "tool_use": tool_use}] except (json.JSONDecodeError, KeyError) as e: logger.warning("Error parsing function arguments for %s: %s", call_id, e) del self._function_call_buffer[call_id] @@ -377,7 +382,84 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> OutputEvent | N "input_audio_buffer.speech_stopped": "speech_stopped", "input_audio_buffer.timeout_triggered": "timeout" } - return self._create_voice_activity_event(activity_map[event_type]) + event = self._create_voice_activity_event(activity_map[event_type]) + return [event] if event else None + + # Turn complete and usage - response finished + elif event_type == "response.done": + response = openai_event.get("response", {}) + response_id = response.get("id", "unknown") + status = response.get("status", "completed") + usage = response.get("usage") + + # Map OpenAI status to our stop_reason + stop_reason_map = { + "completed": "complete", + "cancelled": "interrupted", + "failed": "error", + "incomplete": "interrupted" + } + + # Build list of events to return + events = [] + + # Always add turn complete event + events.append(TurnCompleteEvent( + turn_id=response_id, + stop_reason=stop_reason_map.get(status, "complete") + )) + + # Add usage event if available + if usage: + input_details = usage.get("input_token_details", {}) + output_details = usage.get("output_token_details", {}) + + # Build modality details + modality_details = [] + + # Text modality + text_input = input_details.get("text_tokens", 0) + text_output = output_details.get("text_tokens", 0) + if text_input > 0 or text_output > 0: + modality_details.append({ + "modality": "text", + "input_tokens": text_input, + "output_tokens": text_output + }) + + # Audio modality + audio_input = input_details.get("audio_tokens", 0) + audio_output = output_details.get("audio_tokens", 0) + if audio_input > 0 or audio_output > 0: + modality_details.append({ + "modality": "audio", + 
"input_tokens": audio_input, + "output_tokens": audio_output + }) + + # Image modality + image_input = input_details.get("image_tokens", 0) + if image_input > 0: + modality_details.append({ + "modality": "image", + "input_tokens": image_input, + "output_tokens": 0 + }) + + # Cached tokens + cached_tokens = input_details.get("cached_tokens", 0) + + # Add usage event + events.append(MultimodalUsage( + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + modality_details=modality_details if modality_details else None, + cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None + )) + + # Return list of events + return events # Lifecycle events (log only) - combine multiple similar events elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: @@ -388,14 +470,6 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> OutputEvent | N elif event_type == "conversation.item.done": logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) - # This event signals turn completion - emit TurnCompleteEvent - item = openai_event.get("item", {}) - if item.get("type") == "message" and item.get("role") == "assistant": - item_id = item.get("id", "unknown") - return TurnCompleteEvent( - turn_id=item_id, - stop_reason="complete" - ) return None # Response output events - combine similar events diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 8640f58338..60e88aa0fa 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -342,29 +342,33 @@ async def test_event_conversion(mock_websockets_connect, model): _, _ = mock_websockets_connect await model.connect() - # Test audio 
output (now returns AudioStreamEvent) + # Test audio output (now returns list with AudioStreamEvent) from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import AudioStreamEvent audio_event = { "type": "response.output_audio.delta", "delta": base64.b64encode(b"audio_data").decode() } converted = model._convert_openai_event(audio_event) - assert isinstance(converted, AudioStreamEvent) - assert converted.get("type") == "bidirectional_audio_stream" - assert converted.get("audio") == base64.b64encode(b"audio_data").decode() - assert converted.get("format") == "pcm" - - # Test text output (now returns TranscriptStreamEvent) + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], AudioStreamEvent) + assert converted[0].get("type") == "bidirectional_audio_stream" + assert converted[0].get("audio") == base64.b64encode(b"audio_data").decode() + assert converted[0].get("format") == "pcm" + + # Test text output (now returns list with TranscriptStreamEvent) from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TranscriptStreamEvent text_event = { "type": "response.output_text.delta", "delta": "Hello from OpenAI" } converted = model._convert_openai_event(text_event) - assert isinstance(converted, TranscriptStreamEvent) - assert converted.get("type") == "bidirectional_transcript_stream" - assert converted.get("text") == "Hello from OpenAI" - assert converted.get("source") == "assistant" + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], TranscriptStreamEvent) + assert converted[0].get("type") == "bidirectional_transcript_stream" + assert converted[0].get("text") == "Hello from OpenAI" + assert converted[0].get("source") == "assistant" # Test function call sequence item_added = { @@ -389,23 +393,27 @@ async def test_event_conversion(mock_websockets_connect, model): "call_id": "call-123" } converted = 
model._convert_openai_event(args_done) - # Now returns dict with tool_use - assert isinstance(converted, dict) - assert converted.get("type") == "tool_use" - tool_use = converted.get("tool_use") + # Now returns list with dict containing tool_use + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], dict) + assert converted[0].get("type") == "tool_use" + tool_use = converted[0].get("tool_use") assert tool_use["toolUseId"] == "call-123" assert tool_use["name"] == "calculator" assert tool_use["input"]["expression"] == "2+2" - # Test voice activity (now returns InterruptionEvent for speech_started) + # Test voice activity (now returns list with InterruptionEvent for speech_started) from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent speech_started = { "type": "input_audio_buffer.speech_started" } converted = model._convert_openai_event(speech_started) - assert isinstance(converted, InterruptionEvent) - assert converted.get("type") == "bidirectional_interruption" - assert converted.get("reason") == "user_speech" + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], InterruptionEvent) + assert converted[0].get("type") == "bidirectional_interruption" + assert converted[0].get("reason") == "user_speech" await model.close() From a2f29b37fde4d566ec588bd615d8e200c96df54e Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 3 Nov 2025 12:17:58 +0300 Subject: [PATCH 04/16] feat: add usage to gemini --- .../models/gemini_live.py | 45 ++++++++++++++++++- .../models/test_gemini_live.py | 4 +- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 02044125f7..39e6deed72 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ 
b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -30,6 +30,7 @@ ErrorEvent, ImageInputEvent, InterruptionEvent, + MultimodalUsage, SessionEndEvent, SessionStartEvent, TextInputEvent, @@ -205,6 +206,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic - inputTranscription: User's speech transcribed to text - outputTranscription: Model's audio transcribed to text - modelTurn text: Text response from the model + - usageMetadata: Token usage information """ try: # Handle interruption first (from server_content) @@ -267,7 +269,48 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic } return {"toolUse": tool_use_event} - # Silently ignore setup_complete, turn_complete, generation_complete, and usage_metadata messages + # Handle usage metadata + if hasattr(message, 'usage_metadata') and message.usage_metadata: + usage = message.usage_metadata + + # Build modality details from token details + modality_details = [] + + # Process prompt tokens details + if usage.prompt_tokens_details: + for detail in usage.prompt_tokens_details: + if detail.modality and detail.token_count: + modality_details.append({ + "modality": str(detail.modality).lower(), + "input_tokens": detail.token_count, + "output_tokens": 0 + }) + + # Process response tokens details + if usage.response_tokens_details: + for detail in usage.response_tokens_details: + if detail.modality and detail.token_count: + # Find or create modality entry + modality_str = str(detail.modality).lower() + existing = next((m for m in modality_details if m["modality"] == modality_str), None) + if existing: + existing["output_tokens"] = detail.token_count + else: + modality_details.append({ + "modality": modality_str, + "input_tokens": 0, + "output_tokens": detail.token_count + }) + + return MultimodalUsage( + input_tokens=usage.prompt_token_count or 0, + output_tokens=usage.response_token_count or 0, + total_tokens=usage.total_token_count or 
0, + modality_details=modality_details if modality_details else None, + cache_read_input_tokens=usage.cached_content_token_count if usage.cached_content_token_count else None + ) + + # Silently ignore setup_complete and generation_complete messages return None except Exception as e: diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index f48f04910d..d3bf965f4e 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -292,7 +292,7 @@ async def test_event_conversion(mock_genai_client, model): _, _, _ = mock_genai_client await model.connect() - # Test text output (now converted to transcript) + # Test text output (converted to transcript) mock_text = unittest.mock.Mock() mock_text.text = "Hello from Gemini" mock_text.data = None @@ -305,7 +305,7 @@ async def test_event_conversion(mock_genai_client, model): assert text_event.source == "assistant" assert text_event.is_final is True - # Test audio output (now returns base64 encoded string) + # Test audio output (base64 encoded) import base64 mock_audio = unittest.mock.Mock() mock_audio.text = None From 44e5a631a015dcccbc60e38953acc121045352f8 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 4 Nov 2025 17:17:02 +0300 Subject: [PATCH 05/16] fix: update websockets in openai --- pyproject.toml | 4 +-- .../bidirectional_streaming/agent/agent.py | 33 ++++++++++++------- .../bidirectional_streaming/models/openai.py | 3 +- .../models/test_openai_realtime.py | 2 +- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e079ec263b..7810d09c7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ bidirectional-streaming-nova = [ ] bidirectional-streaming-openai = [ "pyaudio>=0.2.13", - "websockets>=12.0,<14.0", + 
"websockets>=14.0,<16.0", ] bidirectional-streaming = [ "pyaudio>=0.2.13", @@ -71,7 +71,7 @@ bidirectional-streaming = [ "smithy-aws-core>=0.0.1", "pytz", "aws_sdk_bedrock_runtime", - "websockets>=12.0,<14.0", + "websockets>=14.0,<16.0", ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 0b62d87fc8..bbe3f3da2c 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -31,7 +31,13 @@ from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel -from ..types.bidirectional_streaming import AudioInputEvent, ImageInputEvent, OutputEvent +from ..types.bidirectional_streaming import ( + AudioInputEvent, + ImageInputEvent, + InputEvent, + OutputEvent, + TextInputEvent, +) logger = logging.getLogger(__name__) @@ -389,14 +395,18 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent | dict) # Add user text message to history self.messages.append({"role": "user", "content": input_data}) logger.debug("Text sent: %d characters", len(input_data)) - from ..types.bidirectional_streaming import TextInputEvent text_event = TextInputEvent(text=input_data, role="user") await self._session.model.send(text_event) return - # Handle dict - reconstruct TypedEvent for WebSocket integration + # Handle InputEvent instances (TextInputEvent, AudioInputEvent, ImageInputEvent) + # Check this before dict since TypedEvent inherits from dict + if isinstance(input_data, (TextInputEvent, AudioInputEvent, ImageInputEvent)): + await self._session.model.send(input_data) + return + + # Handle plain dict - reconstruct TypedEvent for WebSocket integration if 
isinstance(input_data, dict) and "type" in input_data: - from ..types.bidirectional_streaming import TextInputEvent event_type = input_data["type"] if event_type == "bidirectional_text_input": input_data = TextInputEvent(text=input_data["text"], role=input_data["role"]) @@ -414,14 +424,15 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent | dict) ) else: raise ValueError(f"Unknown event type: {event_type}") - - # Handle TypedEvent instances - if isinstance(input_data, (AudioInputEvent, ImageInputEvent, TextInputEvent)): + + # Send the reconstructed TypedEvent await self._session.model.send(input_data) - else: - raise ValueError( - f"Input must be a string, TypedEvent, or event dict, got: {type(input_data)}" - ) + return + + # If we get here, input type is invalid + raise ValueError( + f"Input must be a string, InputEvent (TextInputEvent/AudioInputEvent/ImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" + ) async def receive(self) -> AsyncIterable[dict[str, Any]]: """Receive events from the model including audio, text, and tool calls. 
diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 1c4318defa..92da277291 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -13,7 +13,6 @@ from typing import AsyncIterable, Union import websockets -from websockets.client import WebSocketClientProtocol from websockets.exceptions import ConnectionClosed from ....types.content import Messages @@ -149,7 +148,7 @@ async def connect( if self.project: headers.append(("OpenAI-Project", self.project)) - self.websocket = await websockets.connect(url, extra_headers=headers) + self.websocket = await websockets.connect(url, additional_headers=headers) logger.info("WebSocket connected successfully") # Configure session diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 6909a1f5c2..60e88aa0fa 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -174,7 +174,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp model_org = OpenAIRealtimeModel(api_key="test-key", organization="org-123") await model_org.connect() call_kwargs = mock_connect.call_args.kwargs - headers = call_kwargs.get("extra_headers", []) + headers = call_kwargs.get("additional_headers", []) org_header = [h for h in headers if h[0] == "OpenAI-Organization"] assert len(org_header) == 1 assert org_header[0][1] == "org-123" From 631b620ea85c0dcfcecd1f40b54cf9f50883a099 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 4 Nov 2025 17:17:12 +0300 Subject: [PATCH 06/16] fix: fix integ test --- .../test_bidirectional_agent.py | 41 +++++++++++++++---- 
.../utils/audio_generator.py | 10 ++++- .../utils/test_context.py | 24 ++++++++--- 3 files changed, 61 insertions(+), 14 deletions(-) diff --git a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py index 9ae64514d6..80b32b1782 100644 --- a/tests_integ/bidirectional_streaming/test_bidirectional_agent.py +++ b/tests_integ/bidirectional_streaming/test_bidirectional_agent.py @@ -11,29 +11,56 @@ import os import pytest -from strands_tools import calculator +from strands import tool from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel -from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicModel +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel from .utils.test_context import BidirectionalTestContext logger = logging.getLogger(__name__) +# Simple calculator tool for testing +@tool +def calculator(operation: str, x: float, y: float) -> float: + """Perform basic arithmetic operations. 
+ + Args: + operation: The operation to perform (add, subtract, multiply, divide) + x: First number + y: Second number + + Returns: + Result of the operation + """ + if operation == "add": + return x + y + elif operation == "subtract": + return x - y + elif operation == "multiply": + return x * y + elif operation == "divide": + if y == 0: + raise ValueError("Cannot divide by zero") + return x / y + else: + raise ValueError(f"Unknown operation: {operation}") + + # Provider configurations PROVIDER_CONFIGS = { "nova_sonic": { - "model_class": NovaSonicBidirectionalModel, + "model_class": NovaSonicModel, "model_kwargs": {"region": "us-east-1"}, "silence_duration": 2.5, # Nova Sonic needs 2+ seconds of silence "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], "skip_reason": "AWS credentials not available", }, "openai": { - "model_class": OpenAIRealtimeBidirectionalModel, + "model_class": OpenAIRealtimeModel, "model_kwargs": { "model": "gpt-4o-realtime-preview-2024-12-17", "session": { @@ -60,7 +87,7 @@ # The model responds with audio but the test infrastructure expects text/transcripts # TODO: Fix Gemini Live event emission to yield both transcript and audio events # "gemini_live": { - # "model_class": GeminiLiveBidirectionalModel, + # "model_class": GeminiLiveModel, # "model_kwargs": { # "model_id": "gemini-2.5-flash-native-audio-preview-09-2025", # "params": { diff --git a/tests_integ/bidirectional_streaming/utils/audio_generator.py b/tests_integ/bidirectional_streaming/utils/audio_generator.py index 605a2aaa90..c3ad3f965c 100644 --- a/tests_integ/bidirectional_streaming/utils/audio_generator.py +++ b/tests_integ/bidirectional_streaming/utils/audio_generator.py @@ -120,10 +120,16 @@ def create_audio_input_event( Returns: AudioInputEvent dict ready for agent.send(). 
""" + import base64 + + # Convert bytes to base64 string for JSON compatibility + audio_b64 = base64.b64encode(audio_data).decode('utf-8') + return { - "audioData": audio_data, + "type": "bidirectional_audio_input", + "audio": audio_b64, "format": format, - "sampleRate": sample_rate, + "sample_rate": sample_rate, "channels": channels, } diff --git a/tests_integ/bidirectional_streaming/utils/test_context.py b/tests_integ/bidirectional_streaming/utils/test_context.py index 4c91e2fc1b..687aef1b5a 100644 --- a/tests_integ/bidirectional_streaming/utils/test_context.py +++ b/tests_integ/bidirectional_streaming/utils/test_context.py @@ -227,19 +227,24 @@ def get_events(self, event_type: str | None = None) -> list[dict]: def get_text_outputs(self) -> list[str]: """Extract text outputs from collected events. - Handles both textOutput events (Nova Sonic, OpenAI) and transcript events (Gemini Live). + Handles both new TypedEvent format and legacy event formats. Returns: List of text content strings. """ texts = [] for event in self.get_events(): # Drain queue first - # Handle textOutput events (Nova Sonic, OpenAI) - if "textOutput" in event: + # Handle new TypedEvent format (bidirectional_transcript_stream) + if event.get("type") == "bidirectional_transcript_stream": + text = event.get("text", "") + if text: + texts.append(text) + # Handle legacy textOutput events (Nova Sonic, OpenAI) + elif "textOutput" in event: text = event["textOutput"].get("text", "") if text: texts.append(text) - # Handle transcript events (Gemini Live) + # Handle legacy transcript events (Gemini Live) elif "transcript" in event: text = event["transcript"].get("text", "") if text: @@ -252,11 +257,20 @@ def get_audio_outputs(self) -> list[bytes]: Returns: List of audio data bytes. 
""" + import base64 + # Drain queue first to get latest events events = self.get_events() audio_data = [] for event in events: - if "audioOutput" in event: + # Handle new TypedEvent format (bidirectional_audio_stream) + if event.get("type") == "bidirectional_audio_stream": + audio_b64 = event.get("audio") + if audio_b64: + # Decode base64 to bytes + audio_data.append(base64.b64decode(audio_b64)) + # Handle legacy audioOutput events + elif "audioOutput" in event: data = event["audioOutput"].get("audioData") if data: audio_data.append(data) From 6e33d9a4bb3c28834816e6e34f23489410becd43 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 4 Nov 2025 17:55:28 +0300 Subject: [PATCH 07/16] feat(types): add tool use types --- .../bidirectional_streaming/__init__.py | 12 ++++++++++++ .../event_loop/bidirectional_event_loop.py | 18 +++++++++++++----- .../models/gemini_live.py | 8 ++++++-- .../models/novasonic.py | 9 ++++++--- .../bidirectional_streaming/models/openai.py | 9 ++++++--- 5 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index f75834a76c..678dfc0d4d 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -30,6 +30,13 @@ TurnStartEvent, ) +# Re-export standard agent events for tool handling +from ...types._events import ( + ToolResultEvent, + ToolStreamEvent, + ToolUseStreamEvent, +) + __all__ = [ # Main interface "BidirectionalAgent", @@ -58,6 +65,11 @@ "ErrorEvent", "OutputEvent", + # Tool Event types (reused from standard agent) + "ToolUseStreamEvent", + "ToolResultEvent", + "ToolStreamEvent", + # Model interface "BidirectionalModel", ] diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 
b3b1ee8ca9..e618245e14 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -281,12 +281,15 @@ async def _process_model_events(session: BidirectionalConnection) -> None: continue # Queue tool requests for concurrent execution - if event_type == "tool_use": - tool_use = strands_event.get("tool_use") + # Check for ToolUseStreamEvent (standard agent event) + if "current_tool_use" in strands_event: + tool_use = strands_event.get("current_tool_use") if tool_use: tool_name = tool_use.get("name") logger.debug("Tool usage detected: %s", tool_name) await session.tool_queue.put(tool_use) + # Forward ToolUseStreamEvent to output queue for client visibility + await session.agent._output_queue.put(strands_event) continue # Send all output events to Agent for receive() method @@ -436,14 +439,19 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_result = tool_event.tool_result tool_use_id = tool_result.get("toolUseId") - # Send ToolResultEvent through send() method + # Send ToolResultEvent through send() method to model await session.model.send(tool_event) - logger.debug("Tool result sent: %s", tool_use_id) + logger.debug("Tool result sent to model: %s", tool_use_id) + + # Also forward ToolResultEvent to output queue for client visibility + await session.agent._output_queue.put(tool_event.as_dict()) + logger.debug("Tool result sent to client: %s", tool_use_id) # Handle streaming events if needed later elif isinstance(tool_event, ToolStreamEvent): logger.debug("Tool stream event: %s", tool_event) - pass + # Forward tool stream events to output queue + await session.agent._output_queue.put(tool_event.as_dict()) # Add tool result message to conversation history if tool_results: diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py 
b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index ac546e010e..1475edaac1 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -23,7 +23,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent +from ....types._events import ToolResultEvent, ToolUseStreamEvent from ..types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, @@ -267,7 +267,11 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic "name": func_call.name, "input": func_call.args or {} } - return {"toolUse": tool_use_event} + # Return ToolUseStreamEvent for consistency with standard agent + return ToolUseStreamEvent( + delta={"toolUse": tool_use_event}, + current_tool_use=tool_use_event + ) # Handle usage metadata if hasattr(message, 'usage_metadata') and message.usage_metadata: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index d054142fbd..033eff4e99 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -32,7 +32,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent +from ....types._events import ToolResultEvent, ToolUseStreamEvent from ..types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, @@ -547,8 +547,11 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: "name": tool_use["toolName"], "input": json.loads(tool_use["content"]), } - # Return dict with tool_use for event loop processing - return {"type": "tool_use", "tool_use": tool_use_event} + # Return ToolUseStreamEvent 
for consistency with standard agent + return ToolUseStreamEvent( + delta={"toolUse": tool_use_event}, + current_tool_use=tool_use_event + ) # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 92da277291..393deb0bd4 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -17,7 +17,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse -from ....types._events import ToolResultEvent +from ....types._events import ToolResultEvent, ToolUseStreamEvent from ..types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, @@ -365,8 +365,11 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, } del self._function_call_buffer[call_id] - # Return dict with tool_use for event loop processing - return [{"type": "tool_use", "tool_use": tool_use}] + # Return ToolUseStreamEvent for consistency with standard agent + return [ToolUseStreamEvent( + delta={"toolUse": tool_use}, + current_tool_use=tool_use + )] except (json.JSONDecodeError, KeyError) as e: logger.warning("Error parsing function arguments for %s: %s", call_id, e) del self._function_call_buffer[call_id] From bd5401f081443dcb433f6be6bf1c85a69e9391a3 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 13:44:50 +0300 Subject: [PATCH 08/16] Return typed dicts on agent and refactor error event --- .../bidirectional_streaming/agent/agent.py | 13 +++---- .../event_loop/bidirectional_event_loop.py | 4 +- .../types/bidirectional_streaming.py | 38 ++++++++++++------- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git 
a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index bbe3f3da2c..f0205f8a8d 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -434,24 +434,21 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent | dict) f"Input must be a string, InputEvent (TextInputEvent/AudioInputEvent/ImageInputEvent), or event dict with 'type' field, got: {type(input_data)}" ) - async def receive(self) -> AsyncIterable[dict[str, Any]]: + async def receive(self) -> AsyncIterable["OutputEvent"]: """Receive events from the model including audio, text, and tool calls. Yields model output events processed by background tasks including audio output, text responses, tool calls, and session updates. Yields: - dict: Event dictionaries from the model session. Each event is a TypedEvent - converted to a dictionary for consistency with the standard Agent API. + OutputEvent: TypedEvent objects from the model session. Events are + JSON-serializable by default (use json.dumps(event) for transport). 
""" while self._session and self._session.active: try: event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1) - # Convert TypedEvent to dict for consistency with Agent.stream_async - if hasattr(event, 'as_dict'): - yield event.as_dict() - else: - yield event + # Return TypedEvent objects directly (JSON-serializable by default) + yield event except asyncio.TimeoutError: continue diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index e618245e14..8af2515ef2 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -444,14 +444,14 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: logger.debug("Tool result sent to model: %s", tool_use_id) # Also forward ToolResultEvent to output queue for client visibility - await session.agent._output_queue.put(tool_event.as_dict()) + await session.agent._output_queue.put(tool_event) logger.debug("Tool result sent to client: %s", tool_use_id) # Handle streaming events if needed later elif isinstance(tool_event, ToolStreamEvent): logger.debug("Tool stream event: %s", tool_event) # Forward tool stream events to output queue - await session.agent._output_queue.put(tool_event.as_dict()) + await session.agent._output_queue.put(tool_event) # Add tool result message to conversation history if tool_results: diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 160b15a27a..c7ca1515f8 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -427,44 +427,54 @@ def 
reason(self) -> str: class ErrorEvent(TypedEvent): """Error occurred during the session. - Similar to strands.types._events.ForceStopEvent, this event wraps exceptions - that occur during bidirectional streaming sessions. - - Note: The Exception object is not stored in the event data to maintain JSON - serializability. Only the error message, code, and details are stored. + Stores the full Exception object as an instance attribute for debugging while + keeping the event dict JSON-serializable. The exception can be accessed via + the `error` property for re-raising or type-based error handling. Parameters: - error: The exception that occurred (used to extract message and type). - code: Optional error code for programmatic handling (defaults to exception class name). + error: The exception that occurred. details: Optional additional error information. """ def __init__( self, error: Exception, - code: Optional[str] = None, details: Optional[Dict[str, Any]] = None, ): + # Store serializable data in dict (for JSON serialization) super().__init__( { "type": "bidirectional_error", - "error_message": str(error), - "error_code": code or type(error).__name__, - "error_details": details, + "message": str(error), + "code": type(error).__name__, + "details": details, } ) + # Store exception as instance attribute (not serialized) + self._error = error + + @property + def error(self) -> Exception: + """The original exception that occurred. + + Can be used for re-raising or type-based error handling. 
+ """ + return self._error @property def code(self) -> str: - return cast(str, self.get("error_code")) + """Error code derived from exception class name.""" + return cast(str, self.get("code")) @property def message(self) -> str: - return cast(str, self.get("error_message")) + """Human-readable error message from the exception.""" + return cast(str, self.get("message")) @property def details(self) -> Optional[Dict[str, Any]]: - return cast(Optional[Dict[str, Any]], self.get("error_details")) + """Additional error context beyond the exception itself.""" + return cast(Optional[Dict[str, Any]], self.get("details")) # ============================================================================ From 8a74a93c963f13419ac5a453551b36f72840c6b5 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 14:09:56 +0300 Subject: [PATCH 09/16] refactor: change multimodal usage event to usage event --- .../experimental/bidirectional_streaming/__init__.py | 4 ++-- .../bidirectional_streaming/models/gemini_live.py | 4 ++-- .../bidirectional_streaming/models/novasonic.py | 4 ++-- .../bidirectional_streaming/models/openai.py | 4 ++-- .../bidirectional_streaming/types/__init__.py | 4 ++-- .../types/bidirectional_streaming.py | 11 ++++++----- .../bidirectional_streaming/models/test_novasonic.py | 8 ++++---- 7 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 678dfc0d4d..86a1139d0a 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -20,7 +20,7 @@ InputEvent, InterruptionEvent, ModalityUsage, - MultimodalUsage, + UsageEvent, OutputEvent, SessionEndEvent, SessionStartEvent, @@ -59,7 +59,7 @@ "TranscriptStreamEvent", "InterruptionEvent", "TurnCompleteEvent", - "MultimodalUsage", + "UsageEvent", "ModalityUsage", "SessionEndEvent", "ErrorEvent", diff 
--git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 1475edaac1..da3387e1d1 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -30,7 +30,7 @@ ErrorEvent, ImageInputEvent, InterruptionEvent, - MultimodalUsage, + UsageEvent, SessionEndEvent, SessionStartEvent, TextInputEvent, @@ -306,7 +306,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic "output_tokens": detail.token_count }) - return MultimodalUsage( + return UsageEvent( input_tokens=usage.prompt_token_count or 0, output_tokens=usage.response_token_count or 0, total_tokens=usage.total_token_count or 0, diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 033eff4e99..c6790e5062 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -39,7 +39,7 @@ ErrorEvent, ImageInputEvent, InterruptionEvent, - MultimodalUsage, + UsageEvent, OutputEvent, SessionEndEvent, SessionStartEvent, @@ -564,7 +564,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: total_input = usage_data.get("totalInputTokens", 0) total_output = usage_data.get("totalOutputTokens", 0) - return MultimodalUsage( + return UsageEvent( input_tokens=total_input, output_tokens=total_output, total_tokens=usage_data.get("totalTokens", total_input + total_output) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 393deb0bd4..016ce32b15 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ 
b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -24,7 +24,7 @@ ErrorEvent, ImageInputEvent, InterruptionEvent, - MultimodalUsage, + UsageEvent, OutputEvent, SessionEndEvent, SessionStartEvent, @@ -452,7 +452,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven cached_tokens = input_details.get("cached_tokens", 0) # Add usage event - events.append(MultimodalUsage( + events.append(UsageEvent( input_tokens=usage.get("input_tokens", 0), output_tokens=usage.get("output_tokens", 0), total_tokens=usage.get("total_tokens", 0), diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 52034db1b3..2721ceab99 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -14,7 +14,7 @@ InputEvent, InterruptionEvent, ModalityUsage, - MultimodalUsage, + UsageEvent, OutputEvent, SessionEndEvent, SessionStartEvent, @@ -37,7 +37,7 @@ "TranscriptStreamEvent", "InterruptionEvent", "TurnCompleteEvent", - "MultimodalUsage", + "UsageEvent", "ModalityUsage", "SessionEndEvent", "ErrorEvent", diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index c7ca1515f8..fbec640b88 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -345,10 +345,11 @@ class ModalityUsage(dict): output_tokens: int -class MultimodalUsage(TypedEvent): - """Token usage event with modality breakdown for multimodal streaming. +class UsageEvent(TypedEvent): + """Token usage event with modality breakdown for bidirectional streaming. - Combines TypedEvent behavior with Usage fields for a unified event type. 
+ Tracks token consumption across different modalities (audio, text, images) + during bidirectional streaming sessions. Parameters: input_tokens: Total tokens used for all input modalities. @@ -369,7 +370,7 @@ def __init__( cache_write_input_tokens: Optional[int] = None, ): data: Dict[str, Any] = { - "type": "multimodal_usage", + "type": "bidirectional_usage", "inputTokens": input_tokens, "outputTokens": output_tokens, "totalTokens": total_tokens, @@ -492,7 +493,7 @@ def details(self) -> Optional[Dict[str, Any]]: TranscriptStreamEvent, InterruptionEvent, TurnCompleteEvent, - MultimodalUsage, + UsageEvent, SessionEndEvent, ErrorEvent, ] diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 6c77457c26..1e07eb4494 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -286,8 +286,8 @@ async def test_event_conversion(nova_model): assert result.get("type") == "bidirectional_interruption" assert result.get("reason") == "user_speech" - # Test usage metrics (now returns MultimodalUsage) - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import MultimodalUsage + # Test usage metrics (now returns UsageEvent) + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import UsageEvent nova_event = { "usageEvent": { "totalTokens": 100, @@ -304,8 +304,8 @@ async def test_event_conversion(nova_model): } result = nova_model._convert_nova_event(nova_event) assert result is not None - assert isinstance(result, MultimodalUsage) - assert result.get("type") == "multimodal_usage" + assert isinstance(result, UsageEvent) + assert result.get("type") == "bidirectional_usage" assert result.get("totalTokens") == 100 assert result.get("inputTokens") == 40 assert result.get("outputTokens") == 60 From 
433c610e188f725df95efe1c559026afd09dc655 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 14:39:04 +0300 Subject: [PATCH 10/16] refactor(bidi): change session terminology to connection --- .../bidirectional_streaming/__init__.py | 8 ++-- .../models/gemini_live.py | 18 ++++---- .../models/novasonic.py | 42 +++++++++--------- .../bidirectional_streaming/models/openai.py | 18 ++++---- .../bidirectional_streaming/types/__init__.py | 8 ++-- .../types/bidirectional_streaming.py | 43 ++++++++++++------- .../models/test_novasonic.py | 4 +- 7 files changed, 77 insertions(+), 64 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 86a1139d0a..1b901b0a28 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -15,6 +15,8 @@ from .types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, ErrorEvent, ImageInputEvent, InputEvent, @@ -22,8 +24,6 @@ ModalityUsage, UsageEvent, OutputEvent, - SessionEndEvent, - SessionStartEvent, TextInputEvent, TranscriptStreamEvent, TurnCompleteEvent, @@ -53,7 +53,8 @@ "InputEvent", # Output Event types - "SessionStartEvent", + "ConnectionStartEvent", + "ConnectionCloseEvent", "TurnStartEvent", "AudioStreamEvent", "TranscriptStreamEvent", @@ -61,7 +62,6 @@ "TurnCompleteEvent", "UsageEvent", "ModalityUsage", - "SessionEndEvent", "ErrorEvent", "OutputEvent", diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index da3387e1d1..5819b84eb3 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -27,12 +27,12 @@ from ..types.bidirectional_streaming import ( 
AudioInputEvent, AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, ErrorEvent, ImageInputEvent, InterruptionEvent, UsageEvent, - SessionEndEvent, - SessionStartEvent, TextInputEvent, TranscriptStreamEvent, TurnCompleteEvent, @@ -89,7 +89,7 @@ def __init__( # Connection state (initialized in connect()) self.live_session = None self.live_session_context_manager = None - self.session_id = None + self.connection_id = None self._active = False async def connect( @@ -112,7 +112,7 @@ async def connect( try: # Initialize connection state - self.session_id = str(uuid.uuid4()) + self.connection_id = str(uuid.uuid4()) self._active = True # Build live config @@ -163,9 +163,9 @@ async def _send_message_history(self, messages: Messages) -> None: async def receive(self) -> AsyncIterable[Dict[str, Any]]: """Receive Gemini Live API events and convert to provider-agnostic format.""" - # Emit session start event - yield SessionStartEvent( - session_id=self.session_id, + # Emit connection start event + yield ConnectionStartEvent( + connection_id=self.connection_id, model=self.model_id, capabilities=["audio", "tools", "images"] ) @@ -196,8 +196,8 @@ async def receive(self) -> AsyncIterable[Dict[str, Any]]: logger.error("Fatal error in receive loop: %s", e) yield ErrorEvent(error=e) finally: - # Emit session end event when exiting - yield SessionEndEvent(reason="complete") + # Emit connection close event when exiting + yield ConnectionCloseEvent(connection_id=self.connection_id, reason="complete") def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dict[str, Any]]: """Convert Gemini Live API events to provider-agnostic format. 
diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index c6790e5062..da57e0c579 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -36,13 +36,13 @@ from ..types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, ErrorEvent, ImageInputEvent, InterruptionEvent, UsageEvent, OutputEvent, - SessionEndEvent, - SessionStartEvent, TextInputEvent, TranscriptStreamEvent, TurnCompleteEvent, @@ -111,7 +111,7 @@ def __init__( # Connection state (initialized in connect()) self.stream = None - self.session_id = None + self.connection_id = None self._active = False # Nova Sonic requires unique content names @@ -155,7 +155,7 @@ async def connect( await self._initialize_client() # Initialize connection state - self.session_id = str(uuid.uuid4()) + self.connection_id = str(uuid.uuid4()) self._active = True self.audio_content_name = str(uuid.uuid4()) self._event_queue = asyncio.Queue() @@ -170,7 +170,7 @@ async def connect( logger.error("Stream is None") raise ValueError("Stream cannot be None") - logger.debug("Nova Sonic connection initialized with session: %s", self.session_id) + logger.debug("Nova Sonic connection initialized with connection_id: %s", self.connection_id) # Send initialization events system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." 
@@ -272,9 +272,9 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: logger.debug("Nova events - starting event stream") - # Emit session start event - yield SessionStartEvent( - session_id=self.session_id, + # Emit connection start event + yield ConnectionStartEvent( + connection_id=self.connection_id, model=self.model_id, capabilities=["audio", "tools"] ) @@ -299,8 +299,8 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: logger.error(traceback.format_exc()) yield ErrorEvent(error=e) finally: - # Emit session end event - yield SessionEndEvent(reason="complete") + # Emit connection close event + yield ConnectionCloseEvent(connection_id=self.connection_id, reason="complete") async def send( self, @@ -345,7 +345,7 @@ async def _start_audio_connection(self) -> None: { "event": { "contentStart": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": self.audio_content_name, "type": "AUDIO", "interactive": True, @@ -376,7 +376,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: { "event": { "audioInput": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": self.audio_content_name, "content": audio_input.audio, } @@ -409,7 +409,7 @@ async def _end_audio_input(self) -> None: logger.debug("Nova audio connection end") audio_content_end = json.dumps( - {"event": {"contentEnd": {"promptName": self.session_id, "contentName": self.audio_content_name}}} + {"event": {"contentEnd": {"promptName": self.connection_id, "contentName": self.audio_content_name}}} ) await self._send_nova_event(audio_content_end) @@ -434,7 +434,7 @@ async def _send_interrupt(self) -> None: { "event": { "audioInput": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": self.audio_content_name, "stopReason": "INTERRUPTED", } @@ -600,7 +600,7 @@ def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: prompt_start_event = { "event": { "promptStart": { - 
"promptName": self.session_id, + "promptName": self.connection_id, "textOutputConfiguration": NOVA_TEXT_CONFIG, "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, } @@ -644,7 +644,7 @@ def _get_text_content_start_event(self, content_name: str, role: str = "USER") - { "event": { "contentStart": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": content_name, "type": "TEXT", "role": role, @@ -661,7 +661,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> { "event": { "contentStart": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": content_name, "interactive": False, "type": "TOOL", @@ -679,7 +679,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> def _get_text_input_event(self, content_name: str, text: str) -> str: """Generate text input event.""" return json.dumps( - {"event": {"textInput": {"promptName": self.session_id, "contentName": content_name, "content": text}}} + {"event": {"textInput": {"promptName": self.connection_id, "contentName": content_name, "content": text}}} ) def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: @@ -688,7 +688,7 @@ def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> s { "event": { "toolResult": { - "promptName": self.session_id, + "promptName": self.connection_id, "contentName": content_name, "content": json.dumps(result), } @@ -698,11 +698,11 @@ def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> s def _get_content_end_event(self, content_name: str) -> str: """Generate content end event.""" - return json.dumps({"event": {"contentEnd": {"promptName": self.session_id, "contentName": content_name}}}) + return json.dumps({"event": {"contentEnd": {"promptName": self.connection_id, "contentName": content_name}}}) def _get_prompt_end_event(self) -> str: """Generate prompt end event.""" - return json.dumps({"event": 
{"promptEnd": {"promptName": self.session_id}}}) + return json.dumps({"event": {"promptEnd": {"promptName": self.connection_id}}}) def _get_connection_end_event(self) -> str: """Generate connection end event.""" diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 016ce32b15..52a3cdf790 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -21,13 +21,13 @@ from ..types.bidirectional_streaming import ( AudioInputEvent, AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, ErrorEvent, ImageInputEvent, InterruptionEvent, UsageEvent, OutputEvent, - SessionEndEvent, - SessionStartEvent, TextInputEvent, TranscriptStreamEvent, TurnCompleteEvent, @@ -103,7 +103,7 @@ def __init__( # Connection state (initialized in connect()) self.websocket = None - self.session_id = None + self.connection_id = None self._active = False self._event_queue = None @@ -134,7 +134,7 @@ async def connect( try: # Initialize connection state - self.session_id = str(uuid.uuid4()) + self.connection_id = str(uuid.uuid4()) self._active = True self._event_queue = asyncio.Queue() self._function_call_buffer = {} @@ -279,9 +279,9 @@ async def _process_responses(self) -> None: async def receive(self) -> AsyncIterable[OutputEvent]: """Receive OpenAI events and convert to Strands TypedEvent format.""" - # Emit session start event - yield SessionStartEvent( - session_id=self.session_id, + # Emit connection start event + yield ConnectionStartEvent( + connection_id=self.connection_id, model=self.model, capabilities=["audio", "tools"] ) @@ -299,8 +299,8 @@ async def receive(self) -> AsyncIterable[OutputEvent]: logger.error("Error receiving OpenAI Realtime event: %s", e) yield ErrorEvent(error=e) finally: - # Emit session end event - yield SessionEndEvent(reason="complete") + # Emit connection close event 
+ yield ConnectionCloseEvent(connection_id=self.connection_id, reason="complete") def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEvent] | None: """Convert OpenAI events to Strands TypedEvent format.""" diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 2721ceab99..9ab16dd38a 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -9,6 +9,8 @@ SUPPORTED_SAMPLE_RATES, AudioInputEvent, AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, ErrorEvent, ImageInputEvent, InputEvent, @@ -16,8 +18,6 @@ ModalityUsage, UsageEvent, OutputEvent, - SessionEndEvent, - SessionStartEvent, TextInputEvent, TranscriptStreamEvent, TurnCompleteEvent, @@ -31,7 +31,8 @@ "ImageInputEvent", "InputEvent", # Output Events - "SessionStartEvent", + "ConnectionStartEvent", + "ConnectionCloseEvent", "TurnStartEvent", "AudioStreamEvent", "TranscriptStreamEvent", @@ -39,7 +40,6 @@ "TurnCompleteEvent", "UsageEvent", "ModalityUsage", - "SessionEndEvent", "ErrorEvent", "OutputEvent", # Constants diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index fbec640b88..5dea738d9e 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -148,28 +148,28 @@ def mime_type(self) -> str: # ============================================================================ -class SessionStartEvent(TypedEvent): - """Session established and ready for interaction. +class ConnectionStartEvent(TypedEvent): + """Streaming connection established and ready for interaction. Parameters: - session_id: Unique identifier for this session. 
+ connection_id: Unique identifier for this streaming connection. model: Model identifier (e.g., "gpt-realtime", "gemini-2.0-flash-live"). capabilities: List of supported features (e.g., ["audio", "tools", "images"]). """ - def __init__(self, session_id: str, model: str, capabilities: List[str]): + def __init__(self, connection_id: str, model: str, capabilities: List[str]): super().__init__( { - "type": "bidirectional_session_start", - "session_id": session_id, + "type": "bidirectional_connection_start", + "connection_id": connection_id, "model": model, "capabilities": capabilities, } ) @property - def session_id(self) -> str: - return cast(str, self.get("session_id")) + def connection_id(self) -> str: + return cast(str, self.get("connection_id")) @property def model(self) -> str: @@ -408,17 +408,30 @@ def cache_write_input_tokens(self) -> Optional[int]: return cast(Optional[int], self.get("cacheWriteInputTokens")) -class SessionEndEvent(TypedEvent): - """Session terminated. +class ConnectionCloseEvent(TypedEvent): + """Streaming connection closed. Parameters: - reason: Why the session ended. + connection_id: Unique identifier for this streaming connection (matches ConnectionStartEvent). + reason: Why the connection was closed. 
""" def __init__( - self, reason: Literal["client_disconnect", "timeout", "error", "complete"] + self, + connection_id: str, + reason: Literal["client_disconnect", "timeout", "error", "complete"], ): - super().__init__({"type": "bidirectional_session_end", "reason": reason}) + super().__init__( + { + "type": "bidirectional_connection_close", + "connection_id": connection_id, + "reason": reason, + } + ) + + @property + def connection_id(self) -> str: + return cast(str, self.get("connection_id")) @property def reason(self) -> str: @@ -487,13 +500,13 @@ def details(self) -> Optional[Dict[str, Any]]: InputEvent = Union[TextInputEvent, AudioInputEvent, ImageInputEvent] OutputEvent = Union[ - SessionStartEvent, + ConnectionStartEvent, TurnStartEvent, AudioStreamEvent, TranscriptStreamEvent, InterruptionEvent, TurnCompleteEvent, UsageEvent, - SessionEndEvent, + ConnectionCloseEvent, ErrorEvent, ] diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 1e07eb4494..851afd92a5 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -415,12 +415,12 @@ async def test_event_templates(nova_model): assert "inferenceConfiguration" in event["event"]["sessionStart"] # Test prompt start event - nova_model.session_id = "test-session" + nova_model.connection_id = "test-connection" event_json = nova_model._get_prompt_start_event([]) event = json.loads(event_json) assert "event" in event assert "promptStart" in event["event"] - assert event["event"]["promptStart"]["promptName"] == "test-session" + assert event["event"]["promptStart"]["promptName"] == "test-connection" # Test text input event content_name = "test-content" From e3389fb4d2ff243833c25c83f4fab7383bee8f7c Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 15:01:38 +0300 
Subject: [PATCH 11/16] refactor: rename turn events to response events --- .../bidirectional_streaming/__init__.py | 8 ++-- .../models/gemini_live.py | 10 ++--- .../models/novasonic.py | 14 +++---- .../bidirectional_streaming/models/openai.py | 14 +++---- .../bidirectional_streaming/types/__init__.py | 8 ++-- .../types/bidirectional_streaming.py | 42 +++++++++---------- .../models/test_gemini_live.py | 2 +- .../models/test_novasonic.py | 2 +- .../models/test_openai_realtime.py | 4 +- 9 files changed, 52 insertions(+), 52 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 1b901b0a28..31b9ead326 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -24,10 +24,10 @@ ModalityUsage, UsageEvent, OutputEvent, + ResponseCompleteEvent, + ResponseStartEvent, TextInputEvent, TranscriptStreamEvent, - TurnCompleteEvent, - TurnStartEvent, ) # Re-export standard agent events for tool handling @@ -55,11 +55,11 @@ # Output Event types "ConnectionStartEvent", "ConnectionCloseEvent", - "TurnStartEvent", + "ResponseStartEvent", + "ResponseCompleteEvent", "AudioStreamEvent", "TranscriptStreamEvent", "InterruptionEvent", - "TurnCompleteEvent", "UsageEvent", "ModalityUsage", "ErrorEvent", diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 5819b84eb3..ad2ca678dc 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -35,8 +35,8 @@ UsageEvent, TextInputEvent, TranscriptStreamEvent, - TurnCompleteEvent, - TurnStartEvent, + ResponseCompleteEvent, + ResponseStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -222,7 +222,7 @@ def _convert_gemini_live_event(self, 
message: LiveServerMessage) -> Optional[Dic logger.debug(f"Input transcription detected: {transcription_text}") return TranscriptStreamEvent( text=transcription_text, - source="user", + role="user", is_final=True ) @@ -235,7 +235,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic logger.debug(f"Output transcription detected: {transcription_text}") return TranscriptStreamEvent( text=transcription_text, - source="assistant", + role="assistant", is_final=True ) @@ -244,7 +244,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic logger.debug(f"Text output as transcript: {message.text}") return TranscriptStreamEvent( text=message.text, - source="assistant", + role="assistant", is_final=True ) diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index da57e0c579..f180850201 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -45,8 +45,8 @@ OutputEvent, TextInputEvent, TranscriptStreamEvent, - TurnCompleteEvent, - TurnStartEvent, + ResponseCompleteEvent, + ResponseStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -535,7 +535,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: return TranscriptStreamEvent( text=text_content, - source="user" if role == "USER" else "assistant", + role="user" if role == "USER" else "assistant", is_final=True ) @@ -575,14 +575,14 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: role = nova_event["contentStart"].get("role", "unknown") # Store role for subsequent text output events self._current_role = role - # Emit turn start event - return TurnStartEvent(turn_id=str(uuid.uuid4())) + # Emit response start event + return ResponseStartEvent(response_id=str(uuid.uuid4())) # Handle content 
stop events elif "contentStop" in nova_event: stop_reason = nova_event["contentStop"].get("stopReason", "complete") - return TurnCompleteEvent( - turn_id=str(uuid.uuid4()), + return ResponseCompleteEvent( + response_id=str(uuid.uuid4()), stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete" ) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 52a3cdf790..33c89ba6c0 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -30,8 +30,8 @@ OutputEvent, TextInputEvent, TranscriptStreamEvent, - TurnCompleteEvent, - TurnStartEvent, + ResponseCompleteEvent, + ResponseStartEvent, ) from .bidirectional_model import BidirectionalModel @@ -176,7 +176,7 @@ def _create_text_event(self, text: str, role: str) -> TranscriptStreamEvent: """Create standardized transcript event.""" return TranscriptStreamEvent( text=text, - source="user" if role == "user" else "assistant", + role="user" if role == "user" else "assistant", is_final=True ) @@ -310,7 +310,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven if event_type == "response.created": response = openai_event.get("response", {}) response_id = response.get("id", str(uuid.uuid4())) - return [TurnStartEvent(turn_id=response_id)] + return [ResponseStartEvent(response_id=response_id)] # Audio output elif event_type == "response.output_audio.delta": @@ -405,9 +405,9 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven # Build list of events to return events = [] - # Always add turn complete event - events.append(TurnCompleteEvent( - turn_id=response_id, + # Always add response complete event + events.append(ResponseCompleteEvent( + response_id=response_id, stop_reason=stop_reason_map.get(status, "complete") )) diff --git 
a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 9ab16dd38a..0a2abb68f3 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -18,10 +18,10 @@ ModalityUsage, UsageEvent, OutputEvent, + ResponseCompleteEvent, + ResponseStartEvent, TextInputEvent, TranscriptStreamEvent, - TurnCompleteEvent, - TurnStartEvent, ) __all__ = [ @@ -33,11 +33,11 @@ # Output Events "ConnectionStartEvent", "ConnectionCloseEvent", - "TurnStartEvent", + "ResponseStartEvent", + "ResponseCompleteEvent", "AudioStreamEvent", "TranscriptStreamEvent", "InterruptionEvent", - "TurnCompleteEvent", "UsageEvent", "ModalityUsage", "ErrorEvent", diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 5dea738d9e..5641200e70 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -180,19 +180,19 @@ def capabilities(self) -> List[str]: return cast(List[str], self.get("capabilities")) -class TurnStartEvent(TypedEvent): +class ResponseStartEvent(TypedEvent): """Model starts generating a response. Parameters: - turn_id: Unique identifier for this turn (used in turn.complete). + response_id: Unique identifier for this response (used in response.complete). 
""" - def __init__(self, turn_id: str): - super().__init__({"type": "bidirectional_turn_start", "turn_id": turn_id}) + def __init__(self, response_id: str): + super().__init__({"type": "bidirectional_response_start", "response_id": response_id}) @property - def turn_id(self) -> str: - return cast(str, self.get("turn_id")) + def response_id(self) -> str: + return cast(str, self.get("response_id")) class AudioStreamEvent(TypedEvent): @@ -244,18 +244,18 @@ class TranscriptStreamEvent(TypedEvent): Parameters: text: Transcribed text from audio. - source: Who is speaking ("user" or "assistant"). + role: Who is speaking ("user" or "assistant"). Aligns with Message.role convention. is_final: Whether this is the final/complete transcript. """ def __init__( - self, text: str, source: Literal["user", "assistant"], is_final: bool + self, text: str, role: Literal["user", "assistant"], is_final: bool ): super().__init__( { "type": "bidirectional_transcript_stream", "text": text, - "source": source, + "role": role, "is_final": is_final, } ) @@ -265,8 +265,8 @@ def text(self) -> str: return cast(str, self.get("text")) @property - def source(self) -> str: - return cast(str, self.get("source")) + def role(self) -> str: + return cast(str, self.get("role")) @property def is_final(self) -> bool: @@ -301,30 +301,30 @@ def turn_id(self) -> Optional[str]: return cast(Optional[str], self.get("turn_id")) -class TurnCompleteEvent(TypedEvent): +class ResponseCompleteEvent(TypedEvent): """Model finished generating response. Parameters: - turn_id: ID of the turn that completed (matches turn.start). - stop_reason: Why the turn ended. + response_id: ID of the response that completed (matches response.start). + stop_reason: Why the response ended. 
""" def __init__( self, - turn_id: str, + response_id: str, stop_reason: Literal["complete", "interrupted", "tool_use", "error"], ): super().__init__( { - "type": "bidirectional_turn_complete", - "turn_id": turn_id, + "type": "bidirectional_response_complete", + "response_id": response_id, "stop_reason": stop_reason, } ) @property - def turn_id(self) -> str: - return cast(str, self.get("turn_id")) + def response_id(self) -> str: + return cast(str, self.get("response_id")) @property def stop_reason(self) -> str: @@ -501,11 +501,11 @@ def details(self) -> Optional[Dict[str, Any]]: OutputEvent = Union[ ConnectionStartEvent, - TurnStartEvent, + ResponseStartEvent, AudioStreamEvent, TranscriptStreamEvent, InterruptionEvent, - TurnCompleteEvent, + ResponseCompleteEvent, UsageEvent, ConnectionCloseEvent, ErrorEvent, diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index d3bf965f4e..5f63193188 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -302,7 +302,7 @@ async def test_event_conversion(mock_genai_client, model): text_event = model._convert_gemini_live_event(mock_text) assert isinstance(text_event, TranscriptStreamEvent) assert text_event.text == "Hello from Gemini" - assert text_event.source == "assistant" + assert text_event.role == "assistant" assert text_event.is_final is True # Test audio output (base64 encoded) diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 851afd92a5..feb320d91c 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -258,7 +258,7 @@ async def 
test_event_conversion(nova_model): assert isinstance(result, TranscriptStreamEvent) assert result.get("type") == "bidirectional_transcript_stream" assert result.get("text") == "Hello, world!" - assert result.get("source") == "assistant" + assert result.get("role") == "assistant" # Test tool use (now returns dict with tool_use) tool_input = {"location": "Seattle"} diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 60e88aa0fa..98c520fdb7 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -368,7 +368,7 @@ async def test_event_conversion(mock_websockets_connect, model): assert isinstance(converted[0], TranscriptStreamEvent) assert converted[0].get("type") == "bidirectional_transcript_stream" assert converted[0].get("text") == "Hello from OpenAI" - assert converted[0].get("source") == "assistant" + assert converted[0].get("role") == "assistant" # Test function call sequence item_added = { @@ -467,7 +467,7 @@ def test_helper_methods(model): assert isinstance(text_event, TranscriptStreamEvent) assert text_event.get("type") == "bidirectional_transcript_stream" assert text_event.get("text") == "Hello" - assert text_event.get("source") == "user" + assert text_event.get("role") == "user" # Test _create_voice_activity_event (now returns InterruptionEvent for speech_started) from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent From 5877c5fce07435d35ed24b9c42cbc0c10681ac66 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 15:22:10 +0300 Subject: [PATCH 12/16] fix: fix bidi tests --- .../models/test_gemini_live.py | 20 +++++++------- .../models/test_novasonic.py | 26 ++++++++++--------- .../models/test_openai_realtime.py | 21 ++++++++------- 3 
files changed, 36 insertions(+), 31 deletions(-) diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 5f63193188..25f11c23c5 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -116,7 +116,7 @@ async def test_connection_lifecycle(mock_genai_client, model, system_prompt, too # Test basic connection await model.connect() assert model._active is True - assert model.session_id is not None + assert model.connection_id is not None assert model.live_session == mock_live_session mock_client.aio.live.connect.assert_called_once() @@ -256,8 +256,8 @@ async def test_send_edge_cases(mock_genai_client, model): async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): """Test that receive() emits connection start and end events.""" from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( - SessionStartEvent, - SessionEndEvent, + ConnectionStartEvent, + ConnectionCloseEvent, ) _, mock_live_session, _ = mock_genai_client @@ -275,9 +275,9 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): # Verify connection start and end assert len(events) >= 2 - assert isinstance(events[0], SessionStartEvent) - assert events[0].session_id == model.session_id - assert isinstance(events[-1], SessionEndEvent) + assert isinstance(events[0], ConnectionStartEvent) + assert events[0].connection_id == model.connection_id + assert isinstance(events[-1], ConnectionCloseEvent) @pytest.mark.asyncio @@ -336,9 +336,11 @@ async def test_event_conversion(mock_genai_client, model): mock_tool.server_content = None tool_event = model._convert_gemini_live_event(mock_tool) - assert "toolUse" in tool_event - assert tool_event["toolUse"]["toolUseId"] == "tool-123" - assert 
tool_event["toolUse"]["name"] == "calculator" + # ToolUseStreamEvent has delta and current_tool_use, not a "type" field + assert "delta" in tool_event + assert "toolUse" in tool_event["delta"] + assert tool_event["delta"]["toolUse"]["toolUseId"] == "tool-123" + assert tool_event["delta"]["toolUse"]["name"] == "calculator" # Test interruption mock_server_content = unittest.mock.Mock() diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index feb320d91c..3865eb353d 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -72,7 +72,7 @@ async def test_model_initialization(model_id, region): assert model.region == region assert model.stream is None assert not model._active - assert model.session_id is None + assert model.connection_id is None @pytest.mark.asyncio @@ -85,7 +85,7 @@ async def test_connection_lifecycle(nova_model, mock_client, mock_stream): await nova_model.connect(system_prompt="Test system prompt") assert nova_model._active assert nova_model.stream == mock_stream - assert nova_model.session_id is not None + assert nova_model.connection_id is not None assert mock_client.invoke_model_with_bidirectional_stream.called # Test close @@ -228,9 +228,9 @@ async def mock_wait_for(*args, **kwargs): # Should have session start and end (new TypedEvent format) assert len(events) >= 2 - assert events[0].get("type") == "bidirectional_session_start" - assert events[0].get("session_id") == nova_model.session_id - assert events[-1].get("type") == "bidirectional_session_end" + assert events[0].get("type") == "bidirectional_connection_start" + assert events[0].get("connection_id") == nova_model.connection_id + assert events[-1].get("type") == "bidirectional_connection_close" @pytest.mark.asyncio @@ -260,7 +260,7 @@ async def 
test_event_conversion(nova_model): assert result.get("text") == "Hello, world!" assert result.get("role") == "assistant" - # Test tool use (now returns dict with tool_use) + # Test tool use (now returns ToolUseStreamEvent from core strands) tool_input = {"location": "Seattle"} nova_event = { "toolUse": { @@ -271,8 +271,10 @@ async def test_event_conversion(nova_model): } result = nova_model._convert_nova_event(nova_event) assert result is not None - assert result.get("type") == "tool_use" - tool_use = result.get("tool_use") + # ToolUseStreamEvent has delta and current_tool_use, not a "type" field + assert "delta" in result + assert "toolUse" in result["delta"] + tool_use = result["delta"]["toolUse"] assert tool_use["toolUseId"] == "tool-123" assert tool_use["name"] == "get_weather" assert tool_use["input"] == tool_input @@ -310,13 +312,13 @@ async def test_event_conversion(nova_model): assert result.get("inputTokens") == 40 assert result.get("outputTokens") == 60 - # Test content start tracks role and emits TurnStartEvent - from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import TurnStartEvent + # Test content start tracks role and emits ResponseStartEvent + from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ResponseStartEvent nova_event = {"contentStart": {"role": "USER"}} result = nova_model._convert_nova_event(nova_event) assert result is not None - assert isinstance(result, TurnStartEvent) - assert result.get("type") == "bidirectional_turn_start" + assert isinstance(result, ResponseStartEvent) + assert result.get("type") == "bidirectional_response_start" assert nova_model._current_role == "USER" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 98c520fdb7..a1c7e65cbe 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py 
+++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -130,7 +130,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp # Test basic connection await model.connect() assert model._active is True - assert model.session_id is not None + assert model.connection_id is not None assert model.websocket == mock_ws assert model._event_queue is not None assert model._response_task is not None @@ -316,9 +316,9 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): receive_gen = model.receive() first_event = await anext(receive_gen) - # First event should be session start (new TypedEvent format) - assert first_event.get("type") == "bidirectional_session_start" - assert first_event.get("session_id") == model.session_id + # First event should be connection start (new TypedEvent format) + assert first_event.get("type") == "bidirectional_connection_start" + assert first_event.get("connection_id") == model.connection_id assert first_event.get("model") == model.model # Close to trigger session end @@ -332,8 +332,8 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model): except StopAsyncIteration: pass - # Last event should be session end (new TypedEvent format) - assert events[-1].get("type") == "bidirectional_session_end" + # Last event should be connection close (new TypedEvent format) + assert events[-1].get("type") == "bidirectional_connection_close" @pytest.mark.asyncio @@ -393,12 +393,13 @@ async def test_event_conversion(mock_websockets_connect, model): "call_id": "call-123" } converted = model._convert_openai_event(args_done) - # Now returns list with dict containing tool_use + # Now returns list with ToolUseStreamEvent assert isinstance(converted, list) assert len(converted) == 1 - assert isinstance(converted[0], dict) - assert converted[0].get("type") == "tool_use" - tool_use = converted[0].get("tool_use") + # ToolUseStreamEvent has delta and current_tool_use, not a 
"type" field + assert "delta" in converted[0] + assert "toolUse" in converted[0]["delta"] + tool_use = converted[0]["delta"]["toolUse"] assert tool_use["toolUseId"] == "call-123" assert tool_use["name"] == "calculator" assert tool_use["input"]["expression"] == "2+2" From 1e0e65ae8124b5d2985f72b3bcefb355d63593f7 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 15:38:34 +0300 Subject: [PATCH 13/16] feat: add json serialization tests for events --- .../bidirectional_streaming/types/__init__.py | 1 + .../types/test_bidirectional_streaming.py | 108 ++++++++++++++++++ 2 files changed, 109 insertions(+) create mode 100644 tests/strands/experimental/bidirectional_streaming/types/__init__.py create mode 100644 tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py diff --git a/tests/strands/experimental/bidirectional_streaming/types/__init__.py b/tests/strands/experimental/bidirectional_streaming/types/__init__.py new file mode 100644 index 0000000000..a1330e552c --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/types/__init__.py @@ -0,0 +1 @@ +"""Tests for bidirectional streaming types.""" diff --git a/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py b/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py new file mode 100644 index 0000000000..0efde88230 --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py @@ -0,0 +1,108 @@ +"""Tests for bidirectional streaming event types. + +This module tests JSON serialization for all bidirectional streaming event types. 
+""" + +import base64 +import json + +import pytest + +from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + AudioInputEvent, + AudioStreamEvent, + ConnectionCloseEvent, + ConnectionStartEvent, + ErrorEvent, + ImageInputEvent, + InterruptionEvent, + ResponseCompleteEvent, + ResponseStartEvent, + TextInputEvent, + TranscriptStreamEvent, + UsageEvent, +) + + +@pytest.mark.parametrize( + "event_class,kwargs,expected_type", + [ + # Input events + (TextInputEvent, {"text": "Hello", "role": "user"}, "bidirectional_text_input"), + ( + AudioInputEvent, + { + "audio": base64.b64encode(b"audio").decode("utf-8"), + "format": "pcm", + "sample_rate": 16000, + "channels": 1, + }, + "bidirectional_audio_input", + ), + ( + ImageInputEvent, + {"image": base64.b64encode(b"image").decode("utf-8"), "mime_type": "image/jpeg"}, + "bidirectional_image_input", + ), + # Output events + ( + ConnectionStartEvent, + {"connection_id": "c1", "model": "m1", "capabilities": ["audio"]}, + "bidirectional_connection_start", + ), + (ResponseStartEvent, {"response_id": "r1"}, "bidirectional_response_start"), + ( + AudioStreamEvent, + { + "audio": base64.b64encode(b"audio").decode("utf-8"), + "format": "pcm", + "sample_rate": 24000, + "channels": 1, + }, + "bidirectional_audio_stream", + ), + ( + TranscriptStreamEvent, + {"text": "Hello", "role": "assistant", "is_final": True}, + "bidirectional_transcript_stream", + ), + (InterruptionEvent, {"reason": "user_speech", "turn_id": None}, "bidirectional_interruption"), + ( + ResponseCompleteEvent, + {"response_id": "r1", "stop_reason": "complete"}, + "bidirectional_response_complete", + ), + ( + UsageEvent, + {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, + "bidirectional_usage", + ), + ( + ConnectionCloseEvent, + {"connection_id": "c1", "reason": "complete"}, + "bidirectional_connection_close", + ), + (ErrorEvent, {"error": ValueError("test"), "details": None}, "bidirectional_error"), + ], +) +def 
test_event_json_serialization(event_class, kwargs, expected_type): + """Test that all event types are JSON serializable and deserializable.""" + # Create event + event = event_class(**kwargs) + + # Verify type field + assert event["type"] == expected_type + + # Serialize to JSON + json_str = json.dumps(event) + + # Deserialize back + data = json.loads(json_str) + + # Verify type preserved + assert data["type"] == expected_type + + # Verify all non-private keys preserved + for key in event.keys(): + if not key.startswith("_"): + assert key in data From 4d091a4659bca8658a7ae6a1416f8a7abdd2f6c3 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 16:04:22 +0300 Subject: [PATCH 14/16] refactor: transcript events to extend model stream event --- .../models/gemini_live.py | 15 +++-- .../models/novasonic.py | 7 ++- .../bidirectional_streaming/models/openai.py | 12 ++-- .../types/bidirectional_streaming.py | 40 +++++++++---- .../models/test_gemini_live.py | 2 + .../models/test_novasonic.py | 2 + .../models/test_openai_realtime.py | 5 ++ .../types/test_bidirectional_streaming.py | 59 ++++++++++++++++++- 8 files changed, 115 insertions(+), 27 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index ad2ca678dc..29b18da9e6 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -166,8 +166,7 @@ async def receive(self) -> AsyncIterable[Dict[str, Any]]: # Emit connection start event yield ConnectionStartEvent( connection_id=self.connection_id, - model=self.model_id, - capabilities=["audio", "tools", "images"] + model=self.model_id ) try: @@ -221,9 +220,11 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic transcription_text = input_transcript.text logger.debug(f"Input transcription detected: 
{transcription_text}") return TranscriptStreamEvent( + delta={"text": transcription_text}, text=transcription_text, role="user", - is_final=True + is_final=True, + current_transcript=transcription_text ) # Handle output transcription (model's audio) - emit as transcript event @@ -234,18 +235,22 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic transcription_text = output_transcript.text logger.debug(f"Output transcription detected: {transcription_text}") return TranscriptStreamEvent( + delta={"text": transcription_text}, text=transcription_text, role="assistant", - is_final=True + is_final=True, + current_transcript=transcription_text ) # Handle text output from model if message.text: logger.debug(f"Text output as transcript: {message.text}") return TranscriptStreamEvent( + delta={"text": message.text}, text=message.text, role="assistant", - is_final=True + is_final=True, + current_transcript=message.text ) # Handle audio output using SDK's built-in data property diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index f180850201..3b4419586f 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -275,8 +275,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: # Emit connection start event yield ConnectionStartEvent( connection_id=self.connection_id, - model=self.model_id, - capabilities=["audio", "tools"] + model=self.model_id ) try: @@ -534,9 +533,11 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None: return InterruptionEvent(reason="user_speech", turn_id=None) return TranscriptStreamEvent( + delta={"text": text_content}, text=text_content, role="user" if role == "USER" else "assistant", - is_final=True + is_final=True, + current_transcript=text_content ) # Handle tool use diff --git 
a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 33c89ba6c0..1f072ac875 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -172,12 +172,14 @@ def _require_active(self) -> bool: """Check if session is active.""" return self._active - def _create_text_event(self, text: str, role: str) -> TranscriptStreamEvent: + def _create_text_event(self, text: str, role: str, is_final: bool = True) -> TranscriptStreamEvent: """Create standardized transcript event.""" return TranscriptStreamEvent( + delta={"text": text}, text=text, role="user" if role == "user" else "assistant", - is_final=True + is_final=is_final, + current_transcript=text if is_final else None ) def _create_voice_activity_event(self, activity_type: str) -> InterruptionEvent | None: @@ -282,8 +284,7 @@ async def receive(self) -> AsyncIterable[OutputEvent]: # Emit connection start event yield ConnectionStartEvent( connection_id=self.connection_id, - model=self.model, - capabilities=["audio", "tools"] + model=self.model ) try: @@ -331,7 +332,8 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven "conversation.item.input_audio_transcription.completed"]: text_key = "delta" if "delta" in event_type else "transcript" text = openai_event.get(text_key, "") - return [self._create_text_event(text, "user")] if text.strip() else None + is_final = "completed" in event_type + return [self._create_text_event(text, "user", is_final=is_final)] if text.strip() else None elif event_type == "conversation.item.input_audio_transcription.segment": segment_data = openai_event.get("segment", {}) diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 5641200e70..355e78c2f4 100644 --- 
a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -21,7 +21,8 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast -from ....types._events import TypedEvent +from ....types._events import ModelStreamEvent, TypedEvent +from ....types.streaming import ContentBlockDelta # Audio format constants SUPPORTED_AUDIO_FORMATS = ["pcm", "wav", "opus", "mp3"] @@ -154,16 +155,14 @@ class ConnectionStartEvent(TypedEvent): Parameters: connection_id: Unique identifier for this streaming connection. model: Model identifier (e.g., "gpt-realtime", "gemini-2.0-flash-live"). - capabilities: List of supported features (e.g., ["audio", "tools", "images"]). """ - def __init__(self, connection_id: str, model: str, capabilities: List[str]): + def __init__(self, connection_id: str, model: str): super().__init__( { "type": "bidirectional_connection_start", "connection_id": connection_id, "model": model, - "capabilities": capabilities, } ) @@ -175,10 +174,6 @@ def connection_id(self) -> str: def model(self) -> str: return cast(str, self.get("model")) - @property - def capabilities(self) -> List[str]: - return cast(List[str], self.get("capabilities")) - class ResponseStartEvent(TypedEvent): """Model starts generating a response. @@ -239,27 +234,44 @@ def channels(self) -> int: return cast(int, self.get("channels")) -class TranscriptStreamEvent(TypedEvent): - """Audio transcription of speech (user or assistant). +class TranscriptStreamEvent(ModelStreamEvent): + """Audio transcription streaming (user or assistant speech). + + Follows the same delta + current state pattern as TextStreamEvent and ToolUseStreamEvent + from core Strands. Supports incremental transcript updates for providers like OpenAI + that send partial transcripts before the final version. Parameters: - text: Transcribed text from audio. 
+ delta: The incremental transcript change (ContentBlockDelta). + text: The delta text (same as delta content for convenience). role: Who is speaking ("user" or "assistant"). Aligns with Message.role convention. is_final: Whether this is the final/complete transcript. + current_transcript: The accumulated transcript text so far (None for first delta). """ def __init__( - self, text: str, role: Literal["user", "assistant"], is_final: bool + self, + delta: ContentBlockDelta, + text: str, + role: Literal["user", "assistant"], + is_final: bool, + current_transcript: Optional[str] = None, ): super().__init__( { "type": "bidirectional_transcript_stream", + "delta": delta, "text": text, "role": role, "is_final": is_final, + "current_transcript": current_transcript, } ) + @property + def delta(self) -> ContentBlockDelta: + return cast(ContentBlockDelta, self.get("delta")) + @property def text(self) -> str: return cast(str, self.get("text")) @@ -272,6 +284,10 @@ def role(self) -> str: def is_final(self) -> bool: return cast(bool, self.get("is_final")) + @property + def current_transcript(self) -> Optional[str]: + return cast(Optional[str], self.get("current_transcript")) + class InterruptionEvent(TypedEvent): """Model generation was interrupted. 
diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 25f11c23c5..107a8a84a7 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -304,6 +304,8 @@ async def test_event_conversion(mock_genai_client, model): assert text_event.text == "Hello from Gemini" assert text_event.role == "assistant" assert text_event.is_final is True + assert text_event.delta == {"text": "Hello from Gemini"} + assert text_event.current_transcript == "Hello from Gemini" # Test audio output (base64 encoded) import base64 diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 3865eb353d..1a2fef4267 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -259,6 +259,8 @@ async def test_event_conversion(nova_model): assert result.get("type") == "bidirectional_transcript_stream" assert result.get("text") == "Hello, world!" assert result.get("role") == "assistant" + assert result.delta == {"text": "Hello, world!"} + assert result.current_transcript == "Hello, world!" 
# Test tool use (now returns ToolUseStreamEvent from core strands) tool_input = {"location": "Seattle"} diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index a1c7e65cbe..2045424e19 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -369,6 +369,8 @@ async def test_event_conversion(mock_websockets_connect, model): assert converted[0].get("type") == "bidirectional_transcript_stream" assert converted[0].get("text") == "Hello from OpenAI" assert converted[0].get("role") == "assistant" + assert converted[0].delta == {"text": "Hello from OpenAI"} + assert converted[0].is_final is True # Test function call sequence item_added = { @@ -469,6 +471,9 @@ def test_helper_methods(model): assert text_event.get("type") == "bidirectional_transcript_stream" assert text_event.get("text") == "Hello" assert text_event.get("role") == "user" + assert text_event.delta == {"text": "Hello"} + assert text_event.is_final is True + assert text_event.current_transcript == "Hello" # Test _create_voice_activity_event (now returns InterruptionEvent for speech_started) from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import InterruptionEvent diff --git a/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py b/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py index 0efde88230..b6290cfcfa 100644 --- a/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py +++ b/tests/strands/experimental/bidirectional_streaming/types/test_bidirectional_streaming.py @@ -47,7 +47,7 @@ # Output events ( ConnectionStartEvent, - {"connection_id": "c1", "model": "m1", "capabilities": ["audio"]}, + {"connection_id": "c1", 
"model": "m1"}, "bidirectional_connection_start", ), (ResponseStartEvent, {"response_id": "r1"}, "bidirectional_response_start"), @@ -63,7 +63,13 @@ ), ( TranscriptStreamEvent, - {"text": "Hello", "role": "assistant", "is_final": True}, + { + "delta": {"text": "Hello"}, + "text": "Hello", + "role": "assistant", + "is_final": True, + "current_transcript": "Hello", + }, "bidirectional_transcript_stream", ), (InterruptionEvent, {"reason": "user_speech", "turn_id": None}, "bidirectional_interruption"), @@ -106,3 +112,52 @@ def test_event_json_serialization(event_class, kwargs, expected_type): for key in event.keys(): if not key.startswith("_"): assert key in data + + + +def test_transcript_stream_event_delta_pattern(): + """Test that TranscriptStreamEvent follows ModelStreamEvent delta pattern.""" + # Test partial transcript (delta) + partial_event = TranscriptStreamEvent( + delta={"text": "Hello"}, + text="Hello", + role="user", + is_final=False, + current_transcript=None, + ) + + assert partial_event.text == "Hello" + assert partial_event.role == "user" + assert partial_event.is_final is False + assert partial_event.current_transcript is None + assert partial_event.delta == {"text": "Hello"} + + # Test final transcript with accumulated text + final_event = TranscriptStreamEvent( + delta={"text": " world"}, + text=" world", + role="user", + is_final=True, + current_transcript="Hello world", + ) + + assert final_event.text == " world" + assert final_event.role == "user" + assert final_event.is_final is True + assert final_event.current_transcript == "Hello world" + assert final_event.delta == {"text": " world"} + + +def test_transcript_stream_event_extends_model_stream_event(): + """Test that TranscriptStreamEvent is a ModelStreamEvent.""" + from strands.types._events import ModelStreamEvent + + event = TranscriptStreamEvent( + delta={"text": "test"}, + text="test", + role="assistant", + is_final=True, + current_transcript="test", + ) + + assert isinstance(event, 
ModelStreamEvent) From e13d51f2a3b140d26c8b9e1b1f514964e46c13f1 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 6 Nov 2025 17:20:20 +0300 Subject: [PATCH 15/16] fix novasonic example script --- .../bidirectional_streaming/tests/test_bidi_novasonic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py index b538fc0238..e5a2e7c468 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py @@ -148,12 +148,12 @@ async def receive(agent, context): # Handle transcript events (bidirectional_transcript_stream) elif event_type == "bidirectional_transcript_stream": text_content = event.get("text", "") - source = event.get("source", "unknown") + role = event.get("role", "unknown") # Log transcript output - if source == "user": + if role == "user": print(f"User: {text_content}") - elif source == "assistant": + elif role == "assistant": print(f"Assistant: {text_content}") # Handle turn complete events (bidirectional_turn_complete) From 43f4428fdcc98121b4921ad2755d004447fe1ea9 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 10 Nov 2025 15:49:51 +0300 Subject: [PATCH 16/16] feat(bidirectional): Add reconnection logic on model provider errors --- .../bidirectional_streaming/agent/agent.py | 8 + .../event_loop/bidirectional_event_loop.py | 220 ++++++++---- .../test_event_loop.py | 335 ++++++++++++++++++ 3 files changed, 496 insertions(+), 67 deletions(-) create mode 100644 tests/strands/experimental/bidirectional_streaming/test_event_loop.py diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index f0205f8a8d..d143b20fb3 100644 --- 
a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -171,6 +171,8 @@ def __init__( hooks: Optional[list[HookProvider]] = None, trace_attributes: Optional[Mapping[str, AttributeValue]] = None, description: Optional[str] = None, + enable_reconnection: bool = True, + max_reconnection_attempts: int = 3, ): """Initialize bidirectional agent with required model and optional configuration. @@ -187,11 +189,17 @@ def __init__( hooks: Hooks to be added to the agent hook registry. trace_attributes: Custom trace attributes to apply to the agent's trace span. description: Description of what the Agent does. + enable_reconnection: Whether to automatically reconnect on connection failures (default: True). + max_reconnection_attempts: Maximum number of reconnection attempts (default: 3). """ self.model = model self.system_prompt = system_prompt self.messages = messages or [] + # Reconnection configuration + self.enable_reconnection = enable_reconnection + self.max_reconnection_attempts = max_reconnection_attempts + # Agent identification self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 8af2515ef2..520f8a058b 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -65,7 +65,7 @@ def __init__(self, model: BidirectionalModel, agent: "BidirectionalAgent") -> No async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: - """Initialize bidirectional session with conycurrent background tasks. 
+ """Initialize bidirectional session with concurrent background tasks. Creates a model-specific session and starts background tasks for processing model events, executing tools, and managing the session lifecycle. @@ -90,7 +90,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec # This is critical - Nova Sonic needs response processing during initialization logger.debug("Starting background processors for concurrent processing") session.background_tasks = [ - asyncio.create_task(_process_model_events(session)), # Handle model responses + asyncio.create_task(_process_model_stream(session)), # Handle model responses asyncio.create_task(_process_tool_execution(session)), # Execute tools concurrently ] @@ -216,24 +216,20 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: except asyncio.QueueEmpty: break - # Also clear the agent's audio output queue + # Clear audio events from agent's output queue audio_cleared = 0 - # Create a temporary list to hold non-audio events temp_events = [] - try: - while True: + + while not session.agent._output_queue.empty(): + try: event = session.agent._output_queue.get_nowait() - # Check for audio events - event_type = event.get("type", "") - if event_type == "bidirectional_audio_stream": + if event.get("type") == "bidirectional_audio_stream": audio_cleared += 1 else: - # Keep non-audio events temp_events.append(event) - except asyncio.QueueEmpty: - pass - - # Put back non-audio events + except asyncio.QueueEmpty: + break + for event in temp_events: session.agent._output_queue.put_nowait(event) @@ -248,67 +244,104 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: logger.debug("Interruption handled - tools cancelled: %d, audio cleared: %d", cancelled_tools, cleared_count) -async def _process_model_events(session: BidirectionalConnection) -> None: - """Process model events and convert them to Strands format. 
+async def _handle_model_event(session: BidirectionalConnection, event: dict) -> None: + """Handle a single model event. + + Args: + session: BidirectionalConnection containing model. + event: Event dictionary from model. + """ + event_type = event.get("type", "") + + # Handle interruption detection + if event_type == "bidirectional_interruption": + logger.debug("Interruption forwarded") + await _handle_interruption(session) + await session.agent._output_queue.put(event) + return + + # Queue tool requests for concurrent execution + if "current_tool_use" in event: + tool_use = event.get("current_tool_use") + if tool_use: + tool_name = tool_use.get("name") + logger.debug("Tool usage detected: %s", tool_name) + await session.tool_queue.put(tool_use) + await session.agent._output_queue.put(event) + return - Background task that handles all model responses, converts provider-specific - events to standardized formats, and manages interruption detection. + # Send all output events to Agent for receive() method + await session.agent._output_queue.put(event) + # Update Agent conversation history for user transcripts + if event_type == "bidirectional_transcript_stream": + source = event.get("source") + text = event.get("text", "") + if source == "user" and text.strip(): + user_message = {"role": "user", "content": text} + session.agent.messages.append(user_message) + logger.debug("User transcript added to history") + + +async def _handle_connection_error(session: BidirectionalConnection, error: Exception) -> bool: + """Handle connection errors with automatic reconnection. + Args: session: BidirectionalConnection containing model. + error: Exception that occurred. + + Returns: + True if reconnected successfully, False if should propagate error. 
""" - logger.debug("Model events processor started") + # Check if this is a reconnectable error and reconnection is enabled + if not (_is_reconnectable_error(error) and session.agent.enable_reconnection): + logger.error("Model events error: %s", str(error)) + traceback.print_exc() + session.active = False + return False + + logger.warning("Connection lost: %s, attempting reconnection...", str(error)) + try: - async for provider_event in session.model.receive(): - if not session.active: - break + await _reconnect_session(session) + logger.info("Reconnection successful, resuming event processing") + return True + + except Exception as reconnect_error: + logger.error("Reconnection failed: %s", str(reconnect_error)) + session.active = False + raise - # Basic validation - skip invalid events - if not isinstance(provider_event, dict): - continue - - strands_event = provider_event - # Get event type - event_type = strands_event.get("type", "") - - # Handle interruption detection - if event_type == "bidirectional_interruption": - logger.debug("Interruption forwarded") - await _handle_interruption(session) - # Forward interruption event to agent for application-level handling - await session.agent._output_queue.put(strands_event) - continue - - # Queue tool requests for concurrent execution - # Check for ToolUseStreamEvent (standard agent event) - if "current_tool_use" in strands_event: - tool_use = strands_event.get("current_tool_use") - if tool_use: - tool_name = tool_use.get("name") - logger.debug("Tool usage detected: %s", tool_name) - await session.tool_queue.put(tool_use) - # Forward ToolUseStreamEvent to output queue for client visibility - await session.agent._output_queue.put(strands_event) - continue - - # Send all output events to Agent for receive() method - await session.agent._output_queue.put(strands_event) - - # Update Agent conversation history for user transcripts - if event_type == "bidirectional_transcript_stream": - source = strands_event.get("source") 
async def _process_model_stream(session: BidirectionalConnection) -> None:
    """Process the model event stream and dispatch each event.

    Background task that consumes ``session.model.receive()`` for as long as
    the session is active. Every dict event is handed to
    ``_handle_model_event``; anything else is logged and skipped. When the
    stream raises, ``_handle_connection_error`` decides whether the stream is
    restarted (it returned True) or the error propagates.

    Args:
        session: BidirectionalConnection containing model.

    Raises:
        Exception: The original stream error when ``_handle_connection_error``
            returns False, or whatever that handler raises itself.
    """
    logger.debug("Model stream processor started")

    try:
        while session.active:
            # NOTE(review): if receive() finishes without raising while the
            # session is still active, it is called again immediately -
            # confirm providers never end the stream that way.
            try:
                async for provider_event in session.model.receive():
                    if not session.active:
                        return

                    # Basic validation - skip invalid events
                    if not isinstance(provider_event, dict):
                        logger.warning(
                            "Skipping invalid event (not a dict): %s",
                            type(provider_event).__name__,
                        )
                        continue

                    await _handle_model_event(session, provider_event)

            except Exception as e:
                # True means the connection was re-established: loop around
                # and start consuming the fresh stream.
                if not await _handle_connection_error(session, e):
                    raise
    finally:
        # Logged on every exit path (normal stop, early return, error, cancel).
        logger.debug("Model stream processor stopped")


def _is_reconnectable_error(error: Exception) -> bool:
    """Check if error is reconnectable (connection-related).

    Args:
        error: Exception to check.

    Returns:
        True if error is reconnectable, False otherwise.
    """
    # ConnectionResetError and BrokenPipeError are subclasses of
    # ConnectionError, so this single check covers all three.
    return isinstance(error, ConnectionError)


async def _reconnect_session(session: BidirectionalConnection) -> None:
    """Reconnect session after connection failure.

    Closes the old connection (errors while closing are logged and ignored),
    then calls ``session.model.connect`` up to
    ``session.agent.max_reconnection_attempts`` times, pausing one second
    between failed attempts. Uses the agent's existing state (messages,
    system_prompt, tools).

    Args:
        session: BidirectionalConnection to reconnect.

    Raises:
        ConnectionError: If ``max_reconnection_attempts`` is less than 1, so
            no attempt can be made.
        Exception: The error of the last attempt if all attempts fail.
    """
    # Close old connection (ignore errors)
    try:
        await session.model.close()
    except Exception as e:
        logger.debug("Error closing old connection: %s", str(e))

    max_attempts = session.agent.max_reconnection_attempts

    # Without this guard the loop below would not run at all and the function
    # would return as if the reconnection had succeeded.
    if max_attempts < 1:
        raise ConnectionError(
            f"Reconnection not attempted: max_reconnection_attempts is {max_attempts}"
        )

    for attempt in range(1, max_attempts + 1):
        try:
            logger.debug("Reconnection attempt %d/%d", attempt, max_attempts)

            # Reconnect using agent's existing state
            await session.model.connect(
                system_prompt=session.agent.system_prompt,
                tools=session.agent.tool_registry.get_all_tool_specs(),
                messages=session.agent.messages,
            )
        except Exception as e:
            logger.warning("Reconnection attempt %d failed: %s", attempt, str(e))
            if attempt == max_attempts:  # Last attempt
                logger.error("All %d reconnection attempts failed", max_attempts)
                raise
            await asyncio.sleep(1)  # Brief pause between attempts
        else:
            logger.info("Reconnected successfully after %d attempts", attempt)
            return
"""Unit tests for bidirectional streaming event loop."""

import asyncio
from unittest.mock import AsyncMock, Mock, patch

import pytest

from strands.experimental.bidirectional_streaming.event_loop.bidirectional_event_loop import (
    BidirectionalConnection,
    _handle_connection_error,
    _handle_model_event,
    _is_reconnectable_error,
    _reconnect_session,
)


def _make_event_session(**agent_attrs):
    """Build a session whose mock agent has a real output queue.

    Args:
        **agent_attrs: Extra attributes to set on the mock agent.

    Returns:
        BidirectionalConnection wired to an AsyncMock model and the mock agent.
    """
    agent = Mock()
    agent._output_queue = asyncio.Queue()
    for attr_name, attr_value in agent_attrs.items():
        setattr(agent, attr_name, attr_value)
    return BidirectionalConnection(model=AsyncMock(), agent=agent)


def _make_reconnect_session(*, connect=None, max_attempts=3, enable_reconnection=True):
    """Build a session whose mocks expose everything the reconnection code reads.

    Args:
        connect: Mock used as ``model.connect``; a plain AsyncMock by default.
        max_attempts: Value for ``agent.max_reconnection_attempts``.
        enable_reconnection: Value for ``agent.enable_reconnection``.

    Returns:
        BidirectionalConnection wired to the configured mocks.
    """
    model = AsyncMock()
    model.close = AsyncMock()
    model.connect = AsyncMock() if connect is None else connect

    agent = Mock()
    agent.enable_reconnection = enable_reconnection
    agent.max_reconnection_attempts = max_attempts
    agent.system_prompt = "test prompt"
    agent.messages = []
    agent.tool_registry = Mock()
    agent.tool_registry.get_all_tool_specs = Mock(return_value=[])

    return BidirectionalConnection(model=model, agent=agent)


async def _next_output(session):
    """Return the next forwarded event, failing fast instead of hanging."""
    return await asyncio.wait_for(session.agent._output_queue.get(), timeout=1.0)


@pytest.fixture
def fake_sleep():
    """Replace asyncio.sleep so the pause between retries costs no real time."""
    with patch("asyncio.sleep", new_callable=AsyncMock) as sleep_mock:
        yield sleep_mock


class TestHandleModelEvent:
    """Test individual model event handling."""

    @pytest.mark.asyncio
    async def test_handle_interruption_event(self):
        """Interruption events should trigger interruption handling."""
        session = _make_event_session()
        session.interruption_lock = asyncio.Lock()
        session.interrupted = False
        session.pending_tool_tasks = {}
        session.audio_output_queue = asyncio.Queue()

        event = {"type": "bidirectional_interruption", "reason": "user_speech"}

        await _handle_model_event(session, event)

        # Event should be forwarded to output queue
        output_event = await _next_output(session)
        assert output_event["type"] == "bidirectional_interruption"

    @pytest.mark.asyncio
    async def test_handle_tool_use_event(self):
        """Tool use events should be queued for execution."""
        session = _make_event_session()
        session.tool_queue = asyncio.Queue()

        event = {
            "type": "tool_use_stream",
            "current_tool_use": {
                "name": "calculator",
                "toolUseId": "tool-123",
                "input": {"expression": "2+2"},
            },
        }

        await _handle_model_event(session, event)

        # Tool should be queued
        tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=1.0)
        assert tool_use["name"] == "calculator"

        # Event should be forwarded to output queue
        output_event = await _next_output(session)
        assert "current_tool_use" in output_event

    @pytest.mark.asyncio
    async def test_handle_transcript_event_user(self):
        """User transcript events should update conversation history."""
        session = _make_event_session(messages=[])

        event = {
            "type": "bidirectional_transcript_stream",
            "source": "user",
            "text": "Hello world",
            "is_final": True,
        }

        await _handle_model_event(session, event)

        # User message should be added to history
        assert len(session.agent.messages) == 1
        assert session.agent.messages[0]["role"] == "user"
        assert session.agent.messages[0]["content"] == "Hello world"

        # Event should be forwarded to output queue
        output_event = await _next_output(session)
        assert output_event["type"] == "bidirectional_transcript_stream"

    @pytest.mark.asyncio
    async def test_handle_transcript_event_assistant(self):
        """Assistant transcript events should not update history."""
        session = _make_event_session(messages=[])

        event = {
            "type": "bidirectional_transcript_stream",
            "source": "assistant",
            "text": "Hello back",
            "is_final": True,
        }

        await _handle_model_event(session, event)

        # Assistant messages should not be added to history
        assert len(session.agent.messages) == 0

        # Event should still be forwarded
        output_event = await _next_output(session)
        assert output_event["type"] == "bidirectional_transcript_stream"

    @pytest.mark.asyncio
    async def test_handle_audio_stream_event(self):
        """Audio stream events should be forwarded to output queue."""
        session = _make_event_session()

        event = {
            "type": "bidirectional_audio_stream",
            "audio": "base64data",
            "format": "pcm",
            "sample_rate": 16000,
            "channels": 1,
        }

        await _handle_model_event(session, event)

        # Event should be forwarded
        output_event = await _next_output(session)
        assert output_event["type"] == "bidirectional_audio_stream"
        assert output_event["audio"] == "base64data"


class TestReconnectableErrorDetection:
    """Test error type detection for reconnection."""

    def test_connection_error_is_reconnectable(self):
        """ConnectionError should be reconnectable."""
        error = ConnectionError("Connection lost")
        assert _is_reconnectable_error(error) is True

    def test_connection_reset_error_is_reconnectable(self):
        """ConnectionResetError should be reconnectable."""
        error = ConnectionResetError("Connection reset by peer")
        assert _is_reconnectable_error(error) is True

    def test_broken_pipe_error_is_reconnectable(self):
        """BrokenPipeError should be reconnectable."""
        error = BrokenPipeError("Broken pipe")
        assert _is_reconnectable_error(error) is True

    def test_value_error_not_reconnectable(self):
        """ValueError should not be reconnectable."""
        error = ValueError("Invalid value")
        assert _is_reconnectable_error(error) is False

    def test_runtime_error_not_reconnectable(self):
        """RuntimeError should not be reconnectable."""
        error = RuntimeError("Runtime error")
        assert _is_reconnectable_error(error) is False


class TestReconnectSession:
    """Test session reconnection logic."""

    @pytest.mark.asyncio
    async def test_reconnect_success_first_attempt(self):
        """Reconnection should succeed on first attempt."""
        session = _make_reconnect_session()

        await _reconnect_session(session)

        # Verify close and connect were called
        session.model.close.assert_called_once()
        session.model.connect.assert_called_once_with(
            system_prompt="test prompt",
            tools=[],
            messages=[],
        )

    @pytest.mark.asyncio
    async def test_reconnect_success_after_retries(self, fake_sleep):
        """Reconnection should succeed after failed attempts."""
        connect = AsyncMock(
            side_effect=[
                ConnectionError("Failed 1"),
                ConnectionError("Failed 2"),
                None,  # Success
            ]
        )
        session = _make_reconnect_session(connect=connect)

        await _reconnect_session(session)

        # Should have tried 3 times
        assert session.model.connect.call_count == 3

    @pytest.mark.asyncio
    async def test_reconnect_fails_after_max_attempts(self, fake_sleep):
        """Reconnection should fail after max attempts."""
        connect = AsyncMock(side_effect=ConnectionError("Always fails"))
        session = _make_reconnect_session(connect=connect)

        with pytest.raises(ConnectionError):
            await _reconnect_session(session)

        # Should have tried exactly 3 times
        assert session.model.connect.call_count == 3

    @pytest.mark.asyncio
    async def test_reconnect_respects_max_attempts_config(self, fake_sleep):
        """Reconnection should respect configured max_reconnection_attempts."""
        connect = AsyncMock(side_effect=ConnectionError("Always fails"))
        session = _make_reconnect_session(connect=connect, max_attempts=5)  # Custom value

        with pytest.raises(ConnectionError):
            await _reconnect_session(session)

        # Should have tried exactly 5 times (not 3)
        assert session.model.connect.call_count == 5


class TestHandleConnectionError:
    """Test connection error handling logic."""

    @pytest.mark.asyncio
    async def test_reconnectable_error_with_reconnection_enabled(self):
        """Should reconnect when error is reconnectable and feature is enabled."""
        session = _make_reconnect_session(enable_reconnection=True)

        error = ConnectionError("Connection lost")
        result = await _handle_connection_error(session, error)

        assert result is True
        assert session.active is True
        session.model.connect.assert_called_once()

    @pytest.mark.asyncio
    async def test_reconnectable_error_with_reconnection_disabled(self):
        """Should not reconnect when feature is disabled."""
        session = _make_reconnect_session(enable_reconnection=False)

        error = ConnectionError("Connection lost")
        result = await _handle_connection_error(session, error)

        assert result is False
        assert session.active is False
        session.model.connect.assert_not_called()

    @pytest.mark.asyncio
    async def test_non_reconnectable_error(self):
        """Should not reconnect for non-reconnectable errors."""
        session = _make_reconnect_session(enable_reconnection=True)

        error = ValueError("Invalid value")
        result = await _handle_connection_error(session, error)

        assert result is False
        assert session.active is False
        session.model.connect.assert_not_called()