From 1ab25c6d5203b6bb5c7a669436b3e8d26b8bfa75 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 14:03:00 +0100 Subject: [PATCH 01/11] refactor(model): unify bidirectional model and session --- .../bidirectional_streaming/__init__.py | 14 +- .../bidirectional_streaming/agent/agent.py | 16 +- .../event_loop/bidirectional_event_loop.py | 25 +- .../models/__init__.py | 11 +- .../models/bidirectional_model.py | 125 ++-- .../models/gemini_live.py | 249 +++---- .../models/novasonic.py | 272 ++++---- .../bidirectional_streaming/models/openai.py | 246 ++++--- .../types/bidirectional_streaming.py | 14 + tests/strands/experimental/__init__.py | 1 + .../bidirectional_streaming/__init__.py | 1 + .../models/__init__.py | 1 + .../models/test_gemini_live.py | 500 ++++++++++++++ .../models/test_novasonic.py | 551 +++++++++++++++ .../models/test_openai_realtime.py | 625 ++++++++++++++++++ 15 files changed, 2176 insertions(+), 475 deletions(-) create mode 100644 tests/strands/experimental/bidirectional_streaming/__init__.py create mode 100644 tests/strands/experimental/bidirectional_streaming/models/__init__.py create mode 100644 tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py create mode 100644 tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py create mode 100644 tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 3c47dd9574..e31bc670e1 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -3,10 +3,11 @@ # Main components - Primary user interface from .agent.agent import BidirectionalAgent -# Advanced interfaces (for custom implementations) +# Unified model interface (for custom implementations) from .models.bidirectional_model import 
BidirectionalModel, BidirectionalModelSession # Model providers - What users need to create models +from .models.gemini_live import GeminiLiveBidirectionalModel from .models.novasonic import NovaSonicBidirectionalModel from .models.openai import OpenAIRealtimeBidirectionalModel @@ -15,7 +16,9 @@ AudioInputEvent, AudioOutputEvent, BidirectionalStreamEvent, + ImageInputEvent, InterruptionDetectedEvent, + TextInputEvent, TextOutputEvent, UsageMetricsEvent, VoiceActivityEvent, @@ -26,19 +29,22 @@ "BidirectionalAgent", # Model providers + "GeminiLiveBidirectionalModel", "NovaSonicBidirectionalModel", "OpenAIRealtimeBidirectionalModel", # Event types "AudioInputEvent", - "AudioOutputEvent", + "AudioOutputEvent", + "ImageInputEvent", + "TextInputEvent", "TextOutputEvent", "InterruptionDetectedEvent", "BidirectionalStreamEvent", "VoiceActivityEvent", "UsageMetricsEvent", - # Model interface + # Unified model interface "BidirectionalModel", - "BidirectionalModelSession", + "BidirectionalModelSession", # Backwards compatibility alias ] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 820a6c490e..62528d4722 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -379,13 +379,15 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> Non self.messages.append({"role": "user", "content": input_data}) logger.debug("Text sent: %d characters", len(input_data)) - await self._session.model_session.send_text_content(input_data) + # Create TextInputEvent for unified send() + text_event = {"text": input_data, "role": "user"} + await self._session.model_session.send(text_event) elif isinstance(input_data, dict) and "audioData" in input_data: - # Handle audio input - await self._session.model_session.send_audio_content(input_data) + # Handle audio input - already in 
AudioInputEvent format + await self._session.model_session.send(input_data) elif isinstance(input_data, dict) and "imageData" in input_data: - # Handle image input (ImageInputEvent) - await self._session.model_session.send_image_content(input_data) + # Handle image input - already in ImageInputEvent format + await self._session.model_session.send(input_data) else: raise ValueError( "Input must be either a string (text), AudioInputEvent " @@ -419,7 +421,9 @@ async def interrupt(self) -> None: ValueError: If no active session. """ self._validate_active_session() - await self._session.model_session.send_interrupt() + # Interruption is now handled internally by models through audio/event processing + # No explicit interrupt method needed in unified interface + logger.debug("Interrupt requested - handled by model's audio processing") async def end(self) -> None: """End the conversation session and cleanup all resources. diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index bbf5fb425f..521ebc0dd5 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -21,7 +21,7 @@ from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse -from ..models.bidirectional_model import BidirectionalModelSession +from ..models.bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -37,11 +37,11 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. """ - def __init__(self, model_session: BidirectionalModelSession, agent: "BidirectionalAgent") -> None: - """Initialize session with model session and agent reference. 
+ def __init__(self, model_session: BidirectionalModel, agent: "BidirectionalAgent") -> None: + """Initialize session with model and agent reference. Args: - model_session: Provider-specific bidirectional model session. + model_session: Bidirectional model instance (unified interface). agent: BidirectionalAgent instance for tool registry access. """ self.model_session = model_session @@ -76,12 +76,15 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec Returns: BidirectionalConnection: Active session with background tasks running. """ - logger.debug("Starting bidirectional session - initializing model session") + logger.debug("Starting bidirectional session - initializing model connection") - # Create provider-specific session - model_session = await agent.model.create_bidirectional_connection( + # Connect to model using unified interface + await agent.model.connect( system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages ) + + # Use the model directly (unified interface - no separate session) + model_session = agent.model # Create session wrapper for background processing session = BidirectionalConnection(model_session=model_session, agent=agent) @@ -257,7 +260,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: """ logger.debug("Model events processor started") try: - async for provider_event in session.model_session.receive_events(): + async for provider_event in session.model_session.receive(): if not session.active: break @@ -434,8 +437,8 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_result = tool_event.tool_result tool_use_id = tool_result.get("toolUseId") - # Send result through provider-specific session - await session.model_session.send_tool_result(tool_use_id, tool_result) + # Send result through unified send() method + await session.model_session.send(tool_result) logger.debug("Tool result sent: %s", 
tool_use_id) # Handle streaming events if needed later @@ -471,7 +474,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: "content": [{"text": f"Error: {str(e)}"}] } try: - await session.model_session.send_tool_result(tool_id, error_result) + await session.model_session.send(error_result) logger.debug("Error result sent: %s", tool_id) except Exception: logger.error("Failed to send error result: %s", tool_id) diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index c5287d15d8..e2745310c5 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,17 +1,14 @@ """Bidirectional model interfaces and implementations.""" from .bidirectional_model import BidirectionalModel, BidirectionalModelSession -from .gemini_live import GeminiLiveBidirectionalModel, GeminiLiveSession -from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession -from .openai import OpenAIRealtimeBidirectionalModel, OpenAIRealtimeSession +from .gemini_live import GeminiLiveBidirectionalModel +from .novasonic import NovaSonicBidirectionalModel +from .openai import OpenAIRealtimeBidirectionalModel __all__ = [ "BidirectionalModel", - "BidirectionalModelSession", + "BidirectionalModelSession", # Backwards compatibility alias "GeminiLiveBidirectionalModel", - "GeminiLiveSession", "NovaSonicBidirectionalModel", - "NovaSonicSession", "OpenAIRealtimeBidirectionalModel", - "OpenAIRealtimeSession", ] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 42485561b3..3af05e1138 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py 
@@ -1,11 +1,10 @@ -"""Bidirectional model interface for real-time streaming conversations. +"""Unified bidirectional streaming interface. -Defines the interface for models that support bidirectional streaming capabilities. -Provides abstractions for different model providers with connection-based communication -patterns that support real-time audio and text interaction. +Single layer combining model and session abstractions for simpler implementation. +Providers implement this directly without separate model/session classes. Features: -- connection-based persistent connections +- Unified model interface (no separate session class) - Real-time bidirectional communication - Provider-agnostic event normalization - Tool execution integration @@ -13,101 +12,85 @@ import abc import logging -from typing import AsyncIterable +from typing import AsyncIterable, Union from ....types.content import Messages -from ....types.tools import ToolSpec -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent +from ....types.tools import ToolResult, ToolSpec +from ..types.bidirectional_streaming import ( + AudioInputEvent, + BidirectionalStreamEvent, + ImageInputEvent, + TextInputEvent, +) logger = logging.getLogger(__name__) -class BidirectionalModelSession(abc.ABC): - """Abstract interface for model-specific bidirectional communication connections. +class BidirectionalModel(abc.ABC): + """Unified interface for bidirectional streaming models. - Defines the contract for managing persistent streaming connections with individual - model providers, handling audio/text input, receiving events, and managing - tool execution results. + Combines model configuration and session communication in a single abstraction. + Providers implement this directly without separate model/session classes. """ @abc.abstractmethod - async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive events from the model in standardized format. 
- - Converts provider-specific events to a common format that can be - processed uniformly by the event loop. - """ - raise NotImplementedError + async def connect( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> None: + """Establish bidirectional connection with the model. - @abc.abstractmethod - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to the model during an active connection. + Initializes the connection state and prepares for real-time communication. + This replaces the old create_bidirectional_connection pattern. - Handles audio encoding and provider-specific formatting while presenting - a simple AudioInputEvent interface. + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Provider-specific configuration options. """ raise NotImplementedError - # TODO: remove with interface unification - async def send_image_content(self, image_input: ImageInputEvent) -> None: - """Send image content to the model during an active connection. - - Handles image encoding and provider-specific formatting while presenting - a simple ImageInputEvent interface. - """ - raise NotImplementedError - @abc.abstractmethod - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content to the model during ongoing generation. + async def close(self) -> None: + """Close connection and cleanup resources. - Allows natural interruption and follow-up questions without requiring - connection restart. + Terminates the active connection and releases any held resources. """ raise NotImplementedError @abc.abstractmethod - async def send_interrupt(self) -> None: - """Send interruption signal to stop generation immediately. 
- - Enables responsive conversational experiences where users can - naturally interrupt during model responses. - """ - raise NotImplementedError + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive events from the model in standardized format. - @abc.abstractmethod - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool execution result to the model. + Yields provider-agnostic events that can be processed uniformly + by the event loop. Converts provider-specific events to common format. - Formats and sends tool results according to the provider's specific protocol. - Handles both successful results and error cases through the result dictionary. + Yields: + BidirectionalStreamEvent: Standardized event dictionaries. """ raise NotImplementedError @abc.abstractmethod - async def close(self) -> None: - """Close the connection and cleanup resources.""" - raise NotImplementedError + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Send structured content to the model. + Unified method for sending all types of content. Implementations should + dispatch to appropriate internal handlers based on content type. -class BidirectionalModel(abc.ABC): - """Interface for models that support bidirectional streaming. - - Defines the contract for creating persistent streaming connections that support - real-time audio and text communication with AI models. - """ - - @abc.abstractmethod - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create a bidirectional connection with the model. + Args: + content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). 
- Establishes a persistent connection for real-time communication while - abstracting provider-specific initialization requirements. + Example: + await model.send(TextInputEvent(text="Hello", role="user")) + await model.send(AudioInputEvent(audioData=bytes, format="pcm", ...)) + await model.send(ToolResult(toolUseId="123", status="success", ...)) """ raise NotImplementedError + + +# Backwards compatibility alias - will be removed in future version +BidirectionalModelSession = BidirectionalModel diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 64c4d73483..578de5a2b6 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -1,11 +1,11 @@ """Gemini Live API bidirectional model provider using official Google GenAI SDK. -Implements the BidirectionalModel interface for Google's Gemini Live API using the +Implements the unified BidirectionalModel interface for Google's Gemini Live API using the official Google GenAI SDK for simplified and robust WebSocket communication. 
Key improvements over custom WebSocket implementation: - Uses official google-genai SDK with native Live API support -- Simplified session management with client.aio.live.connect() +- Unified model interface (no separate session class) - Built-in tool integration and event handling - Automatic WebSocket connection management and error handling - Native support for audio/text streaming and interruption @@ -15,14 +15,14 @@ import base64 import logging import uuid -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import Any, AsyncIterable, Dict, List, Optional, Union from google import genai from google.genai import types as genai_types from google.genai.types import LiveServerMessage, LiveServerContent from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse +from ....types.tools import ToolResult, ToolSpec, ToolUse from ..types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, @@ -30,10 +30,11 @@ BidirectionalConnectionStartEvent, ImageInputEvent, InterruptionDetectedEvent, + TextInputEvent, TextOutputEvent, TranscriptEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -43,45 +44,70 @@ GEMINI_CHANNELS = 1 -class GeminiLiveSession(BidirectionalModelSession): - """Gemini Live API session using official Google GenAI SDK. +class GeminiLiveBidirectionalModel(BidirectionalModel): + """Unified Gemini Live API implementation using official Google GenAI SDK. + Combines model configuration and connection state in a single class. Provides a clean interface to Gemini Live API using the official SDK, eliminating custom WebSocket handling and providing robust error handling. """ - def __init__(self, client: genai.Client, model_id: str, config: Dict[str, Any]): - """Initialize Gemini Live API session. 
+ def __init__( + self, + model_id: str = "models/gemini-2.0-flash-live-preview-04-09", + api_key: Optional[str] = None, + **config + ): + """Initialize Gemini Live API bidirectional model. Args: - client: Gemini client instance - model_id: Model identifier - config: Model configuration including live config + model_id: Gemini Live model identifier. + api_key: Google AI API key for authentication. + **config: Additional configuration. """ - self.client = client + # Model configuration self.model_id = model_id + self.api_key = api_key self.config = config - self.session_id = str(uuid.uuid4()) - self._active = True + + # Create Gemini client with proper API version + client_kwargs = {} + if api_key: + client_kwargs["api_key"] = api_key + + # Use v1alpha for Live API as it has better model support + client_kwargs["http_options"] = {"api_version": "v1alpha"} + + self.client = genai.Client(**client_kwargs) + + # Connection state (initialized in connect()) self.live_session = None self.live_session_cm = None - - + self.session_id = None + self._active = False - async def initialize( + async def connect( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None + messages: Optional[Messages] = None, + **kwargs ) -> None: - """Initialize Gemini Live API session by creating the connection.""" + """Establish bidirectional connection with Gemini Live API. + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. 
+ """ try: - # Build live config - live_config = self.config.get("live_config") + # Initialize connection state + self.session_id = str(uuid.uuid4()) + self._active = True - if live_config is None: - raise ValueError("live_config is required but not found in session config") + # Build live config + live_config = self._build_live_config(system_prompt, tools, **kwargs) # Create the context manager self.live_session_cm = self.client.aio.live.connect( @@ -96,9 +122,8 @@ async def initialize( if messages: await self._send_message_history(messages) - except Exception as e: - logger.error("Error initializing Gemini Live session: %s", e) + logger.error("Error connecting to Gemini Live: %s", e) raise async def _send_message_history(self, messages: Messages) -> None: @@ -125,13 +150,13 @@ async def _send_message_history(self, messages: Messages) -> None: content = genai_types.Content(role=role, parts=content_parts) await self.live_session.send_client_content(turns=content) - async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + async def receive(self) -> AsyncIterable[Dict[str, Any]]: """Receive Gemini Live API events and convert to provider-agnostic format.""" # Emit connection start event connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.session_id, - "metadata": {"provider": "gemini_live", "model_id": self.config.get("model_id")} + "metadata": {"provider": "gemini_live", "model_id": self.model_id} } yield {"BidirectionalConnectionStart": connection_start} @@ -251,15 +276,43 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) return None - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content using Gemini Live API. 
+ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Unified send method for all content types. - Gemini Live expects continuous audio streaming via send_realtime_input. - This automatically triggers VAD and can interrupt ongoing responses. + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). """ if not self._active: return + try: + if isinstance(content, dict): + # Dispatch based on content structure + if "text" in content and "role" in content: + # TextInputEvent + await self._send_text_content(content["text"]) + elif "audioData" in content: + # AudioInputEvent + await self._send_audio_content(content) + elif "imageData" in content or "image_url" in content: + # ImageInputEvent + await self._send_image_content(content) + elif "toolUseId" in content and "status" in content: + # ToolResult + await self._send_tool_result(content) + else: + logger.warning(f"Unknown content type with keys: {content.keys()}") + except Exception as e: + logger.error(f"Error sending content: {e}") + + async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Internal: Send audio content using Gemini Live API. + + Gemini Live expects continuous audio streaming via send_realtime_input. + This automatically triggers VAD and can interrupt ongoing responses. + """ try: # Create audio blob for the SDK audio_blob = genai_types.Blob( @@ -273,18 +326,15 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: except Exception as e: logger.error("Error sending audio content: %s", e) - async def send_image_content(self, image_input: ImageInputEvent) -> None: - """Send image content using Gemini Live API. + async def _send_image_content(self, image_input: ImageInputEvent) -> None: + """Internal: Send image content using Gemini Live API. 
Sends image frames following the same pattern as the GitHub example. Images are sent as base64-encoded data with MIME type. """ - if not self._active: - return - try: # Prepare the message based on encoding - if image_input["encoding"] == "base64": + if image_input.get("encoding") == "base64": # Data is already base64 encoded if isinstance(image_input["imageData"], bytes): data_str = image_input["imageData"].decode() @@ -306,11 +356,8 @@ async def send_image_content(self, image_input: ImageInputEvent) -> None: except Exception as e: logger.error("Error sending image content: %s", e) - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content using Gemini Live API.""" - if not self._active: - return - + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Gemini Live API.""" try: # Create content with text content = genai_types.Content( @@ -324,36 +371,25 @@ async def send_text_content(self, text: str, **kwargs) -> None: except Exception as e: logger.error("Error sending text content: %s", e) - async def send_interrupt(self) -> None: - """Send interruption signal to Gemini Live API. - - Gemini Live uses automatic VAD-based interruption. When new audio input - is detected, it automatically interrupts the ongoing generation. - We don't need to send explicit interrupt signals like Nova Sonic. - """ - if not self._active: - return - + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Gemini Live API.""" try: - # Gemini Live handles interruption automatically through VAD - # When new audio input is sent via send_realtime_input, it automatically - # interrupts any ongoing generation. No explicit interrupt signal needed. 
- logger.debug("Interrupt requested - Gemini Live handles this automatically via VAD") + tool_use_id = tool_result.get("toolUseId") + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = {"result": block["text"]} + break - except Exception as e: - logger.error("Error in interrupt handling: %s", e) - - async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: - """Send tool result using Gemini Live API.""" - if not self._active: - return - - try: # Create function response func_response = genai_types.FunctionResponse( id=tool_use_id, name=tool_use_id, # Gemini uses name as identifier - response=result + response=result_data ) # Send tool response @@ -361,11 +397,6 @@ async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> No except Exception as e: logger.error("Error sending tool result: %s", e) - async def send_tool_error(self, tool_use_id: str, error: str) -> None: - """Send tool error using Gemini Live API.""" - error_result = {"error": error} - await self.send_tool_result(tool_use_id, error_result) - async def close(self) -> None: """Close Gemini Live API connection.""" if not self._active: @@ -378,69 +409,7 @@ async def close(self) -> None: if self.live_session_cm: await self.live_session_cm.__aexit__(None, None, None) except Exception as e: - logger.error("Error closing Gemini Live session: %s", e) - raise - - -class GeminiLiveBidirectionalModel(BidirectionalModel): - """Gemini Live API model implementation using official Google GenAI SDK. - - Provides access to Google's Gemini Live API through the bidirectional - streaming interface, using the official SDK for robust and simple integration. 
- """ - - def __init__( - self, - model_id: str = "models/gemini-2.0-flash-live-preview-04-09", - api_key: Optional[str] = None, - **config - ): - """Initialize Gemini Live API bidirectional model. - - Args: - model_id: Gemini Live model identifier. - api_key: Google AI API key for authentication. - **config: Additional configuration. - """ - self.model_id = model_id - self.api_key = api_key - self.config = config - - # Create Gemini client with proper API version - client_kwargs = {} - if api_key: - client_kwargs["api_key"] = api_key - - # Use v1alpha for Live API as it has better model support - client_kwargs["http_options"] = {"api_version": "v1alpha"} - - self.client = genai.Client(**client_kwargs) - - async def create_bidirectional_connection( - self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, - **kwargs - ) -> BidirectionalModelSession: - """Create Gemini Live API bidirectional connection using official SDK.""" - - try: - # Build configuration - live_config = self._build_live_config(system_prompt, tools, **kwargs) - - # Create session config - session_config = self._get_session_config() - session_config["live_config"] = live_config - - # Create and initialize session wrapper - session = GeminiLiveSession(self.client, self.model_id, session_config) - await session.initialize(system_prompt, tools, messages) - - return session - - except Exception as e: - logger.error("Failed to create Gemini Live connection: %s", e) + logger.error("Error closing Gemini Live connection: %s", e) raise def _build_live_config( @@ -488,12 +457,4 @@ def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_t for tool_spec in tool_specs ], ), - ] - - def _get_session_config(self) -> Dict[str, Any]: - """Get session configuration for Gemini Live API.""" - return { - "model_id": self.model_id, - "params": self.config.get("params"), - **self.config - } \ No newline at end of file + ] \ No 
newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 134ff73fd2..62b53a127b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -1,9 +1,11 @@ """Nova Sonic bidirectional model provider for real-time streaming conversations. -Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the +Implements the unified BidirectionalModel interface for Amazon's Nova Sonic, handling the complex event sequencing and audio processing required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. +Unified model interface - combines configuration and connection state in single class. + Nova Sonic specifics: - Hierarchical event sequences: connectionStart → promptStart → content streaming - Base64-encoded audio format with hex encoding @@ -19,25 +21,31 @@ import time import traceback import uuid -from typing import AsyncIterable +from typing import AsyncIterable, Union from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme -from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk, InvokeModelWithBidirectionalStreamOperationOutput +from aws_sdk_bedrock_runtime.models import ( + BidirectionalInputPayloadPart, + InvokeModelWithBidirectionalStreamInputChunk, + InvokeModelWithBidirectionalStreamOperationOutput, +) from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse +from ....types.tools import ToolResult, ToolSpec, ToolUse from ..types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, 
BidirectionalConnectionEndEvent, BidirectionalConnectionStartEvent, + ImageInputEvent, InterruptionDetectedEvent, + TextInputEvent, TextOutputEvent, UsageMetricsEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -72,29 +80,36 @@ RESPONSE_TIMEOUT = 1.0 -class NovaSonicSession(BidirectionalModelSession): - """Nova Sonic connection implementation handling the provider's specific protocol. +class NovaSonicBidirectionalModel(BidirectionalModel): + """Unified Nova Sonic implementation for bidirectional streaming. + Combines model configuration and connection state in a single class. Manages Nova Sonic's complex event sequencing, audio format conversion, and - tool execution patterns while providing the standard BidirectionalModelSession - interface. + tool execution patterns while providing the standard BidirectionalModel interface. """ - def __init__(self, stream: InvokeModelWithBidirectionalStreamOperationOutput, config: dict[str, any]) -> None: - """Initialize Nova Sonic connection. + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config: any) -> None: + """Initialize Nova Sonic bidirectional model. Args: - stream: Nova Sonic bidirectional stream operation output from AWS SDK. - config: Model configuration. + model_id: Nova Sonic model identifier. + region: AWS region. + **config: Additional configuration. 
""" - self.stream = stream + # Model configuration + self.model_id = model_id + self.region = region self.config = config - self.prompt_name = str(uuid.uuid4()) - self._active = True + self._client = None + + # Connection state (initialized in connect()) + self.stream = None + self.prompt_name = None + self._active = False # Nova Sonic requires unique content names - self.audio_content_name = str(uuid.uuid4()) - self.text_content_name = str(uuid.uuid4()) + self.audio_content_name = None + self.text_content_name = None # Audio connection state self.audio_connection_active = False @@ -102,33 +117,67 @@ def __init__(self, stream: InvokeModelWithBidirectionalStreamOperationOutput, co self.silence_threshold = SILENCE_THRESHOLD self.silence_task = None - # Validate stream - if not stream: - logger.error("Stream is None") - raise ValueError("Stream cannot be None") + # Background task and event queue + self._response_task = None + self._event_queue = None - logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) - async def initialize( + async def connect( self, system_prompt: str | None = None, tools: list[ToolSpec] | None = None, messages: Messages | None = None, + **kwargs, ) -> None: - """Initialize Nova Sonic connection with required protocol sequence.""" + """Establish bidirectional connection to Nova Sonic. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + logger.debug("Nova connection create - starting") + try: - system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." 
+ # Initialize client if needed + if not self._client: + await self._initialize_client() + + # Initialize connection state + self.prompt_name = str(uuid.uuid4()) + self._active = True + self.audio_content_name = str(uuid.uuid4()) + self.text_content_name = str(uuid.uuid4()) + self._event_queue = asyncio.Queue() + + # Start Nova Sonic bidirectional stream + self.stream = await self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + ) + + # Validate stream + if not self.stream: + logger.error("Stream is None") + raise ValueError("Stream cannot be None") + logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) + + # Send initialization events + system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." init_events = self._build_initialization_events(system_prompt, tools or [], messages) logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) await self._send_initialization_events(init_events) - logger.info("Nova Sonic connection initialized successfully") + # Start background response processor self._response_task = asyncio.create_task(self._process_responses()) + logger.info("Nova Sonic connection established successfully") + except Exception as e: - logger.error("Error during Nova Sonic initialization: %s", e) + logger.error("Nova connection create error: %s", str(e)) raise def _build_initialization_events( @@ -206,7 +255,7 @@ def _log_event_type(self, nova_event: dict[str, any]) -> None: audio_bytes = base64.b64decode(audio_content) logger.debug("Nova audio output: %d bytes", len(audio_bytes)) - async def receive_events(self) -> AsyncIterable[dict[str, any]]: + async def receive(self) -> AsyncIterable[dict[str, any]]: """Receive Nova Sonic events and convert to provider-agnostic format.""" if not self.stream: logger.error("Stream is None") @@ -217,14 +266,10 @@ async def receive_events(self) -> 
AsyncIterable[dict[str, any]]: # Emit connection start event to Strands event system connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.prompt_name, - "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")}, + "metadata": {"provider": "nova_sonic", "model_id": self.model_id}, } yield {"BidirectionalConnectionStart": connection_start} - # Initialize event queue if not already done - if not hasattr(self, "_event_queue"): - self._event_queue = asyncio.Queue() - try: while self._active: try: @@ -252,8 +297,39 @@ async def receive_events(self) -> AsyncIterable[dict[str, any]]: } yield {"BidirectionalConnectionEnd": connection_end} - async def start_audio_connection(self) -> None: - """Start audio input connection (call once before sending audio chunks).""" + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Unified send method for all content types. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). 
+ """ + if not self._active: + return + + try: + if isinstance(content, dict): + # Dispatch based on content structure + if "text" in content and "role" in content: + # TextInputEvent + await self._send_text_content(content["text"]) + elif "audioData" in content: + # AudioInputEvent + await self._send_audio_content(content) + elif "imageData" in content or "image_url" in content: + # ImageInputEvent - not supported by Nova Sonic + logger.warning("Image input not supported by Nova Sonic") + elif "toolUseId" in content and "status" in content: + # ToolResult + await self._send_tool_result(content) + else: + logger.warning(f"Unknown content type with keys: {content.keys()}") + except Exception as e: + logger.error(f"Error sending content: {e}") + + async def _start_audio_connection(self) -> None: + """Internal: Start audio input connection (call once before sending audio chunks).""" if self.audio_connection_active: return @@ -277,14 +353,11 @@ async def start_audio_connection(self) -> None: await self._send_nova_event(audio_content_start) self.audio_connection_active = True - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio using Nova Sonic protocol-specific format.""" - if not self._active: - return - + async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Internal: Send audio using Nova Sonic protocol-specific format.""" # Start audio connection if not already active if not self.audio_connection_active: - await self.start_audio_connection() + await self._start_audio_connection() # Update last audio time and cancel any pending silence task self.last_audio_time = time.time() @@ -313,19 +386,19 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: self.silence_task = asyncio.create_task(self._check_silence()) async def _check_silence(self) -> None: - """Check for silence and automatically end audio connection.""" + """Internal: Check for silence and automatically end audio 
connection.""" try: await asyncio.sleep(self.silence_threshold) if self.audio_connection_active and self.last_audio_time: elapsed = time.time() - self.last_audio_time if elapsed >= self.silence_threshold: logger.debug("Nova silence detected: %.2f seconds", elapsed) - await self.end_audio_input() + await self._end_audio_input() except asyncio.CancelledError: pass - async def end_audio_input(self) -> None: - """End current audio input connection to trigger Nova Sonic processing.""" + async def _end_audio_input(self) -> None: + """Internal: End current audio input connection to trigger Nova Sonic processing.""" if not self.audio_connection_active: return @@ -338,11 +411,8 @@ async def end_audio_input(self) -> None: await self._send_nova_event(audio_content_end) self.audio_connection_active = False - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content using Nova Sonic format.""" - if not self._active: - return - + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Nova Sonic format.""" content_name = str(uuid.uuid4()) events = [ self._get_text_content_start_event(content_name), @@ -353,37 +423,45 @@ async def send_text_content(self, text: str, **kwargs) -> None: for event in events: await self._send_nova_event(event) - async def send_interrupt(self) -> None: - """Send interruption signal to Nova Sonic.""" - if not self._active: - return - + async def _send_interrupt(self) -> None: + """Internal: Send interruption signal to Nova Sonic.""" # Nova Sonic handles interruption through special input events - interrupt_event = { - "event": { - "audioInput": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name, - "stopReason": "INTERRUPTED", + interrupt_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "stopReason": "INTERRUPTED", + } } } - } + ) await self._send_nova_event(interrupt_event) - 
async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool result using Nova Sonic toolResult format.""" - if not self._active: - return + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Nova Sonic toolResult format.""" + tool_use_id = tool_result.get("toolUseId") logger.debug("Nova tool result send: %s", tool_use_id) + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = {"result": block["text"]} + break + content_name = str(uuid.uuid4()) events = [ self._get_tool_content_start_event(content_name, tool_use_id), - self._get_tool_result_event(content_name, result), + self._get_tool_result_event(content_name, result_data), self._get_content_end_event(content_name), ] - for _i, event in enumerate(events): + for event in events: await self._send_nova_event(event) async def close(self) -> None: @@ -405,7 +483,7 @@ async def close(self) -> None: try: # End audio connection if active if self.audio_connection_active: - await self.end_audio_input() + await self._end_audio_input() # Send cleanup events cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] @@ -640,60 +718,6 @@ async def _send_nova_event(self, event: str) -> None: logger.error("Event was: %s", event) raise - -class NovaSonicBidirectionalModel(BidirectionalModel): - """Nova Sonic model implementation for bidirectional streaming. - - Provides access to Amazon's Nova Sonic model through the bidirectional - streaming interface, handling AWS authentication and connection management. - """ - - def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config: any) -> None: - """Initialize Nova Sonic bidirectional model. - - Args: - model_id: Nova Sonic model identifier. - region: AWS region. 
- **config: Additional configuration. - """ - self.model_id = model_id - self.region = region - self.config = config - self._client = None - - logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) - - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create Nova Sonic bidirectional connection.""" - logger.debug("Nova connection create - starting") - - # Initialize client if needed - if not self._client: - await self._initialize_client() - - # Start Nova Sonic bidirectional stream - try: - stream = await self._client.invoke_model_with_bidirectional_stream( - InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - ) - - # Create and initialize connection - connection = NovaSonicSession(stream, self.config) - await connection.initialize(system_prompt, tools, messages) - - logger.debug("Nova connection created") - return connection - except Exception as e: - logger.error("Nova connection create error: %s", str(e)) - logger.error("Failed to create Nova Sonic connection: %s", e) - raise - async def _initialize_client(self) -> None: """Initialize Nova Sonic client.""" try: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 7d009b1c74..0208ee162d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -2,6 +2,8 @@ Provides real-time audio and text communication through OpenAI's Realtime API with WebSocket connections, voice activity detection, and function calling. + +Unified model interface - combines configuration and connection state in single class. 
""" import asyncio @@ -9,24 +11,26 @@ import json import logging import uuid -from typing import AsyncIterable +from typing import AsyncIterable, Union import websockets from websockets.client import WebSocketClientProtocol from websockets.exceptions import ConnectionClosed from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse +from ....types.tools import ToolResult, ToolSpec, ToolUse from ..types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, BidirectionalConnectionEndEvent, BidirectionalConnectionStartEvent, BidirectionalStreamEvent, + ImageInputEvent, + TextInputEvent, TextOutputEvent, VoiceActivityEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -55,63 +59,115 @@ } -class OpenAIRealtimeSession(BidirectionalModelSession): - """OpenAI Realtime API session for real-time audio/text streaming. +class OpenAIRealtimeBidirectionalModel(BidirectionalModel): + """Unified OpenAI Realtime API implementation for bidirectional streaming. + Combines model configuration and connection state in a single class. Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, function calling, and event conversion to Strands format. """ - def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: - """Initialize OpenAI Realtime session.""" - self.websocket = websocket + def __init__( + self, + model: str = DEFAULT_MODEL, + api_key: str | None = None, + **config: any + ) -> None: + """Initialize OpenAI Realtime bidirectional model. + + Args: + model: OpenAI model identifier (default: gpt-realtime). + api_key: OpenAI API key for authentication. + **config: Additional configuration (organization, project, session params). 
+ """ + # Model configuration + self.model = model + self.api_key = api_key self.config = config - self.session_id = str(uuid.uuid4()) - self._active = True - self._event_queue = asyncio.Queue() + import os + if not self.api_key: + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") + + # Connection state (initialized in connect()) + self.websocket = None + self.session_id = None + self._active = False + + self._event_queue = None self._response_task = None self._function_call_buffer = {} - logger.debug("OpenAI Realtime session initialized: %s", self.session_id) - - def _require_active(self) -> bool: - """Check if session is active.""" - return self._active - - def _create_text_event(self, text: str, role: str) -> dict[str, any]: - """Create standardized text output event.""" - text_output: TextOutputEvent = {"text": text, "role": role} - return {"textOutput": text_output} - - def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: - """Create standardized voice activity event.""" - voice_activity: VoiceActivityEvent = {"activityType": activity_type} - return {"voiceActivity": voice_activity} - - + logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) - async def initialize( + async def connect( self, system_prompt: str | None = None, tools: list[ToolSpec] | None = None, messages: Messages | None = None, + **kwargs, ) -> None: - """Initialize session with configuration.""" + """Establish bidirectional connection to OpenAI Realtime API. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. 
+ """ + logger.info("Creating OpenAI Realtime connection...") + try: + # Initialize connection state + self.session_id = str(uuid.uuid4()) + self._active = True + self._event_queue = asyncio.Queue() + self._function_call_buffer = {} + + # Establish WebSocket connection + url = f"{OPENAI_REALTIME_URL}?model={self.model}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if "organization" in self.config: + headers.append(("OpenAI-Organization", self.config["organization"])) + if "project" in self.config: + headers.append(("OpenAI-Project", self.config["project"])) + + self.websocket = await websockets.connect(url, additional_headers=headers) + logger.info("WebSocket connected successfully") + + # Configure session session_config = self._build_session_config(system_prompt, tools) await self._send_event({"type": "session.update", "session": session_config}) + # Add conversation history if provided if messages: await self._add_conversation_history(messages) + # Start background response processor self._response_task = asyncio.create_task(self._process_responses()) - logger.info("OpenAI Realtime session initialized successfully") + logger.info("OpenAI Realtime connection established") except Exception as e: - logger.error("Error during OpenAI Realtime initialization: %s", e) + logger.error("OpenAI connection error: %s", e) raise + def _require_active(self) -> bool: + """Check if session is active.""" + return self._active + + def _create_text_event(self, text: str, role: str) -> dict[str, any]: + """Create standardized text output event.""" + text_output: TextOutputEvent = {"text": text, "role": role} + return {"textOutput": text_output} + + def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: + """Create standardized voice activity event.""" + voice_activity: VoiceActivityEvent = {"activityType": activity_type} + return {"voiceActivity": voice_activity} + def _build_session_config(self, system_prompt: str | None, tools: 
list[ToolSpec] | None) -> dict: """Build session configuration for OpenAI Realtime API.""" config = DEFAULT_SESSION_CONFIG.copy() @@ -201,11 +257,11 @@ async def _process_responses(self) -> None: self._active = False logger.debug("OpenAI Realtime response processor stopped") - async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive OpenAI events and convert to Strands format.""" connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.session_id, - "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, + "metadata": {"provider": "openai_realtime", "model": self.model}, } yield {"BidirectionalConnectionStart": connection_start} @@ -366,19 +422,44 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] logger.debug("Unhandled OpenAI event type: %s", event_type) return None - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to OpenAI for processing.""" + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Unified send method for all content types. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). 
+ """ if not self._require_active(): return + try: + if isinstance(content, dict): + # Dispatch based on content structure + if "text" in content and "role" in content: + # TextInputEvent + await self._send_text_content(content["text"]) + elif "audioData" in content: + # AudioInputEvent + await self._send_audio_content(content) + elif "imageData" in content or "image_url" in content: + # ImageInputEvent - not supported by OpenAI Realtime yet + logger.warning("Image input not supported by OpenAI Realtime API") + elif "toolUseId" in content and "status" in content: + # ToolResult + await self._send_tool_result(content) + else: + logger.warning(f"Unknown content type with keys: {content.keys()}") + except Exception as e: + logger.error(f"Error sending content: {e}") + + async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Internal: Send audio content to OpenAI for processing.""" audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) - async def send_text_content(self, text: str) -> None: - """Send text content to OpenAI for processing.""" - if not self._require_active(): - return - + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content to OpenAI for processing.""" item_data = { "type": "message", "role": "user", @@ -387,20 +468,26 @@ async def send_text_content(self, text: str) -> None: await self._send_event({"type": "conversation.item.create", "item": item_data}) await self._send_event({"type": "response.create"}) - async def send_interrupt(self) -> None: - """Send interruption signal to OpenAI.""" - if not self._require_active(): - return - + async def _send_interrupt(self) -> None: + """Internal: Send interruption signal to OpenAI.""" await self._send_event({"type": "response.cancel"}) - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool result back to 
OpenAI.""" - if not self._require_active(): - return + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result back to OpenAI.""" + tool_use_id = tool_result.get("toolUseId") logger.debug("OpenAI tool result send: %s", tool_use_id) - result_text = json.dumps(result) if not isinstance(result, str) else result + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = block["text"] + break + + result_text = json.dumps(result_data) if not isinstance(result_data, str) else result_data item_data = { "type": "function_call_output", @@ -443,60 +530,3 @@ async def _send_event(self, event: dict[str, any]) -> None: raise -class OpenAIRealtimeBidirectionalModel(BidirectionalModel): - """OpenAI Realtime API provider for Strands bidirectional streaming. - - Provides real-time audio/text communication through OpenAI's Realtime API - with WebSocket connections, voice activity detection, and function calling. - """ - - def __init__( - self, - model: str = DEFAULT_MODEL, - api_key: str | None = None, - **config: any - ) -> None: - """Initialize OpenAI Realtime bidirectional model.""" - self.model = model - self.api_key = api_key - self.config = config - - import os - if not self.api_key: - self.api_key = os.getenv("OPENAI_API_KEY") - if not self.api_key: - raise ValueError("OpenAI API key is required. 
Set OPENAI_API_KEY environment variable or pass api_key parameter.") - - logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) - - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create bidirectional connection to OpenAI Realtime API.""" - logger.info("Creating OpenAI Realtime connection...") - - try: - url = f"{OPENAI_REALTIME_URL}?model={self.model}" - - headers = [("Authorization", f"Bearer {self.api_key}")] - if "organization" in self.config: - headers.append(("OpenAI-Organization", self.config["organization"])) - if "project" in self.config: - headers.append(("OpenAI-Project", self.config["project"])) - - websocket = await websockets.connect(url, additional_headers=headers) - logger.info("WebSocket connected successfully") - - session = OpenAIRealtimeSession(websocket, self.config) - await session.initialize(system_prompt, tools, messages) - - logger.info("OpenAI Realtime connection established") - return session - - except Exception as e: - logger.error("OpenAI connection error: %s", e) - raise \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 4b215d74e5..145710c3cb 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -88,6 +88,20 @@ class ImageInputEvent(TypedDict): encoding: Literal["base64", "raw"] +class TextInputEvent(TypedDict): + """Text input event for sending text to the model. + + Used for sending text content through the send() method. + + Attributes: + text: The text content to send to the model. + role: The role of the message sender (typically "user"). 
+ """ + + text: str + role: Role + + class TextOutputEvent(TypedDict): """Text output event from the model during bidirectional streaming. diff --git a/tests/strands/experimental/__init__.py b/tests/strands/experimental/__init__.py index e69de29bb2..ac8db1d744 100644 --- a/tests/strands/experimental/__init__.py +++ b/tests/strands/experimental/__init__.py @@ -0,0 +1 @@ +"""Experimental features tests.""" diff --git a/tests/strands/experimental/bidirectional_streaming/__init__.py b/tests/strands/experimental/bidirectional_streaming/__init__.py new file mode 100644 index 0000000000..ea37091cc1 --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/__init__.py @@ -0,0 +1 @@ +"""Bidirectional streaming tests.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/__init__.py b/tests/strands/experimental/bidirectional_streaming/models/__init__.py new file mode 100644 index 0000000000..ea9fbb2d01 --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/models/__init__.py @@ -0,0 +1 @@ +"""Bidirectional streaming model tests.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py new file mode 100644 index 0000000000..a5baaa522b --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -0,0 +1,500 @@ +"""Unit tests for Gemini Live bidirectional streaming model. 
+ +Tests the unified GeminiLiveBidirectionalModel interface including: +- Model initialization and configuration +- Connection establishment +- Unified send() method with different content types +- Event receiving and conversion +- Connection lifecycle management +""" + +import unittest.mock +import uuid + +import pytest +from google import genai +from google.genai import types as genai_types + +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel +from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + AudioInputEvent, + ImageInputEvent, + TextInputEvent, +) +from strands.types.tools import ToolResult + + +@pytest.fixture +def mock_genai_client(): + """Mock the Google GenAI client.""" + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client") as mock_client_cls: + mock_client = mock_client_cls.return_value + mock_client.aio = unittest.mock.MagicMock() + + # Mock the live session + mock_live_session = unittest.mock.AsyncMock() + + # Mock the context manager + mock_live_session_cm = unittest.mock.MagicMock() + mock_live_session_cm.__aenter__ = unittest.mock.AsyncMock(return_value=mock_live_session) + mock_live_session_cm.__aexit__ = unittest.mock.AsyncMock(return_value=None) + + # Make connect return the context manager + mock_client.aio.live.connect = unittest.mock.MagicMock(return_value=mock_live_session_cm) + + yield mock_client, mock_live_session, mock_live_session_cm + + +@pytest.fixture +def model_id(): + return "models/gemini-2.0-flash-live-preview-04-09" + + +@pytest.fixture +def api_key(): + return "test-api-key" + + +@pytest.fixture +def model(mock_genai_client, model_id, api_key): + """Create a GeminiLiveBidirectionalModel instance.""" + _ = mock_genai_client + return GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + + +@pytest.fixture +def tool_spec(): + return { + "description": "Calculate mathematical 
expressions", + "name": "calculator", + "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}}, + } + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant" + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +# Initialization Tests + + +def test_init_default_config(mock_genai_client): + """Test model initialization with default configuration.""" + _ = mock_genai_client + + model = GeminiLiveBidirectionalModel() + + assert model.model_id == "models/gemini-2.0-flash-live-preview-04-09" + assert model.api_key is None + assert model._active is False + assert model.live_session is None + + +def test_init_with_api_key(mock_genai_client, model_id, api_key): + """Test model initialization with API key.""" + mock_client, _, _ = mock_genai_client + + model = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + + assert model.model_id == model_id + assert model.api_key == api_key + + # Verify client was created with correct parameters + mock_client_cls = unittest.mock.patch("strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client").start() + GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + mock_client_cls.assert_called() + + +def test_init_with_custom_config(mock_genai_client, model_id): + """Test model initialization with custom configuration.""" + _ = mock_genai_client + + custom_config = {"temperature": 0.7, "top_p": 0.9} + model = GeminiLiveBidirectionalModel(model_id=model_id, **custom_config) + + assert model.config == custom_config + + +# Connection Tests + + +@pytest.mark.asyncio +async def test_connect_basic(mock_genai_client, model): + """Test basic connection establishment.""" + mock_client, mock_live_session, _ = mock_genai_client + + await model.connect() + + assert model._active is True + assert model.session_id is not None + assert model.live_session == mock_live_session + 
mock_client.aio.live.connect.assert_called_once() + + +@pytest.mark.asyncio +async def test_connect_with_system_prompt(mock_genai_client, model, system_prompt): + """Test connection with system prompt.""" + mock_client, _, _ = mock_genai_client + + await model.connect(system_prompt=system_prompt) + + # Verify system prompt was included in config + call_args = mock_client.aio.live.connect.call_args + config = call_args.kwargs.get("config", {}) + assert config.get("system_instruction") == system_prompt + + +@pytest.mark.asyncio +async def test_connect_with_tools(mock_genai_client, model, tool_spec): + """Test connection with tools.""" + mock_client, _, _ = mock_genai_client + + await model.connect(tools=[tool_spec]) + + # Verify tools were formatted and included + call_args = mock_client.aio.live.connect.call_args + config = call_args.kwargs.get("config", {}) + assert "tools" in config + assert len(config["tools"]) > 0 + + +@pytest.mark.asyncio +async def test_connect_with_messages(mock_genai_client, model, messages): + """Test connection with message history.""" + _, mock_live_session, _ = mock_genai_client + + await model.connect(messages=messages) + + # Verify message history was sent + mock_live_session.send_client_content.assert_called() + + +@pytest.mark.asyncio +async def test_connect_error_handling(mock_genai_client, model): + """Test connection error handling.""" + mock_client, _, _ = mock_genai_client + mock_client.aio.live.connect.side_effect = Exception("Connection failed") + + with pytest.raises(Exception, match="Connection failed"): + await model.connect() + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_text_input(mock_genai_client, model): + """Test sending text input through unified send() method.""" + _, mock_live_session, _ = mock_genai_client + await model.connect() + + text_input: TextInputEvent = {"text": "Hello", "role": "user"} + await model.send(text_input) + + # Verify text was sent via send_client_content + 
mock_live_session.send_client_content.assert_called_once()
+    call_args = mock_live_session.send_client_content.call_args
+    content = call_args.kwargs.get("turns")
+    assert content.role == "user"
+    assert content.parts[0].text == "Hello"
+
+
+@pytest.mark.asyncio
+async def test_send_audio_input(mock_genai_client, model):
+    """Test sending audio input through unified send() method."""
+    _, mock_live_session, _ = mock_genai_client
+    await model.connect()
+
+    audio_input: AudioInputEvent = {
+        "audioData": b"audio_bytes",
+        "format": "pcm",
+        "sampleRate": 16000,
+        "channels": 1,
+    }
+    await model.send(audio_input)
+
+    # Verify audio was sent via send_realtime_input
+    mock_live_session.send_realtime_input.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_send_image_input(mock_genai_client, model):
+    """Test sending image input through unified send() method."""
+    _, mock_live_session, _ = mock_genai_client
+    await model.connect()
+
+    image_input: ImageInputEvent = {
+        "imageData": b"image_bytes",
+        "mimeType": "image/jpeg",
+        "encoding": "raw",
+    }
+    await model.send(image_input)
+
+    # Verify image was sent via the live session's generic send() call
+    mock_live_session.send.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_send_tool_result(mock_genai_client, model):
+    """Test sending tool result through unified send() method."""
+    _, mock_live_session, _ = mock_genai_client
+    await model.connect()
+
+    tool_result: ToolResult = {
+        "toolUseId": "tool-123",
+        "status": "success",
+        "content": [{"text": "Result: 42"}],
+    }
+    await model.send(tool_result)
+
+    # Verify tool result was sent via send_tool_response
+    mock_live_session.send_tool_response.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_send_when_inactive(mock_genai_client, model):
+    """Test that send() does nothing when connection is inactive."""
+    _, mock_live_session, _ = mock_genai_client
+
+    # Don't connect, so _active is False
+    text_input: TextInputEvent = {"text": "Hello", "role": "user"}
+    await model.send(text_input)
+
+    #
Verify nothing was sent
+    mock_live_session.send_client_content.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_send_unknown_content_type(mock_genai_client, model):
+    """Test that sending an unknown content type does not raise."""
+    _, _, _ = mock_genai_client
+    await model.connect()
+
+    unknown_content = {"unknown_field": "value"}
+
+    # Should not raise; the logged warning itself is not asserted here
+    await model.send(unknown_content)
+
+
+# Receive Method Tests
+
+
+@pytest.mark.asyncio
+async def test_receive_connection_start_event(mock_genai_client, model, agenerator):
+    """Test that receive() emits connection start event."""
+    _, mock_live_session, _ = mock_genai_client
+    mock_live_session.receive.return_value = agenerator([])
+
+    await model.connect()
+
+    # Pull only the first event from the receive() generator
+    receive_gen = model.receive()
+    first_event = await anext(receive_gen)
+
+    # First event should be connection start
+    assert "BidirectionalConnectionStart" in first_event
+    assert first_event["BidirectionalConnectionStart"]["connectionId"] == model.session_id
+
+    # Close to stop the loop
+    await model.close()
+
+
+@pytest.mark.asyncio
+async def test_receive_connection_end_event(mock_genai_client, model, agenerator):
+    """Test that receive() emits connection end event."""
+    _, mock_live_session, _ = mock_genai_client
+    mock_live_session.receive.return_value = agenerator([])
+
+    await model.connect()
+
+    # Collect events until connection ends
+    events = []
+    async for event in model.receive():
+        events.append(event)
+        # Close after first event to trigger connection end
+        if len(events) == 1:
+            await model.close()
+
+    # Last event should be connection end
+    assert "BidirectionalConnectionEnd" in events[-1]
+
+
+@pytest.mark.asyncio
+async def test_receive_text_output(mock_genai_client, model):
+    """Test receiving text output from model."""
+    _, mock_live_session, _ = mock_genai_client
+
+    mock_message = unittest.mock.Mock()
+    mock_message.text = "Hello from Gemini"
+    mock_message.data = None
+
mock_message.tool_call = None
+    mock_message.server_content = None
+
+    await model.connect()
+
+    # Call the converter directly instead of driving receive()
+    converted_event = model._convert_gemini_live_event(mock_message)
+
+    assert "textOutput" in converted_event
+    assert converted_event["textOutput"]["text"] == "Hello from Gemini"
+    assert converted_event["textOutput"]["role"] == "assistant"
+
+
+@pytest.mark.asyncio
+async def test_receive_audio_output(mock_genai_client, model):
+    """Test receiving audio output from model."""
+    _, mock_live_session, _ = mock_genai_client
+
+    mock_message = unittest.mock.Mock()
+    mock_message.text = None
+    mock_message.data = b"audio_data"
+    mock_message.tool_call = None
+    mock_message.server_content = None
+
+    await model.connect()
+
+    # Call the converter directly instead of driving receive()
+    converted_event = model._convert_gemini_live_event(mock_message)
+
+    assert "audioOutput" in converted_event
+    assert converted_event["audioOutput"]["audioData"] == b"audio_data"
+    assert converted_event["audioOutput"]["format"] == "pcm"
+
+
+@pytest.mark.asyncio
+async def test_receive_tool_call(mock_genai_client, model):
+    """Test receiving tool call from model."""
+    _, mock_live_session, _ = mock_genai_client
+
+    mock_func_call = unittest.mock.Mock()
+    mock_func_call.id = "tool-123"
+    mock_func_call.name = "calculator"
+    mock_func_call.args = {"expression": "2+2"}
+
+    mock_tool_call = unittest.mock.Mock()
+    mock_tool_call.function_calls = [mock_func_call]
+
+    mock_message = unittest.mock.Mock()
+    mock_message.text = None
+    mock_message.data = None
+    mock_message.tool_call = mock_tool_call
+    mock_message.server_content = None
+
+    await model.connect()
+
+    # Call the converter directly instead of driving receive()
+    converted_event = model._convert_gemini_live_event(mock_message)
+
+    assert "toolUse" in converted_event
+    assert converted_event["toolUse"]["toolUseId"] == "tool-123"
+    assert converted_event["toolUse"]["name"] == "calculator"
+
+
+@pytest.mark.asyncio
+async def
test_receive_interruption(mock_genai_client, model):
+    """Test receiving interruption event."""
+    _, mock_live_session, _ = mock_genai_client
+
+    mock_server_content = unittest.mock.Mock()
+    mock_server_content.interrupted = True
+    mock_server_content.input_transcription = None
+    mock_server_content.output_transcription = None
+
+    mock_message = unittest.mock.Mock()
+    mock_message.text = None
+    mock_message.data = None
+    mock_message.tool_call = None
+    mock_message.server_content = mock_server_content
+
+    await model.connect()
+
+    # Call the converter directly instead of driving receive()
+    converted_event = model._convert_gemini_live_event(mock_message)
+
+    assert "interruptionDetected" in converted_event
+    assert converted_event["interruptionDetected"]["reason"] == "user_input"
+
+
+# Close Method Tests
+
+
+@pytest.mark.asyncio
+async def test_close_connection(mock_genai_client, model):
+    """Test that close() deactivates the model and exits the live-session context manager."""
+    _, _, mock_live_session_cm = mock_genai_client
+
+    await model.connect()
+    await model.close()
+
+    assert model._active is False
+    mock_live_session_cm.__aexit__.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_close_when_not_connected(mock_genai_client, model):
+    """Test closing when not connected does nothing."""
+    _, _, mock_live_session_cm = mock_genai_client
+
+    # Deliberately skip connect()
+    await model.close()
+
+    # Should not raise, and __aexit__ should not be called
+    mock_live_session_cm.__aexit__.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_close_error_handling(mock_genai_client, model):
+    """Test that an exception raised while exiting the live session propagates from close()."""
+    _, _, mock_live_session_cm = mock_genai_client
+    mock_live_session_cm.__aexit__.side_effect = Exception("Close failed")
+
+    await model.connect()
+
+    with pytest.raises(Exception, match="Close failed"):
+        await model.close()
+
+
+# Helper Method Tests
+
+
+def test_build_live_config_basic(model):
+    """Test that building a live config without arguments returns a dict."""
+    config = model._build_live_config()
+
+    assert isinstance(config, dict)
+
+
+def
test_build_live_config_with_system_prompt(model, system_prompt):
+    """Test that the system prompt is stored under the "system_instruction" config key."""
+    config = model._build_live_config(system_prompt=system_prompt)
+
+    assert config["system_instruction"] == system_prompt
+
+
+def test_build_live_config_with_tools(model, tool_spec):
+    """Test that building config with tools yields a non-empty "tools" entry."""
+    config = model._build_live_config(tools=[tool_spec])
+
+    assert "tools" in config
+    assert len(config["tools"]) > 0
+
+
+def test_format_tools_for_live_api(model, tool_spec):
+    """Test that one tool spec is formatted into a single genai_types.Tool."""
+    formatted_tools = model._format_tools_for_live_api([tool_spec])
+
+    assert len(formatted_tools) == 1
+    assert isinstance(formatted_tools[0], genai_types.Tool)
+
+
+def test_format_tools_empty_list(model):
+    """Test that formatting an empty tool list returns an empty list."""
+    formatted_tools = model._format_tools_for_live_api([])
+
+    assert formatted_tools == []
diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py
new file mode 100644
index 0000000000..451a98aa2a
--- /dev/null
+++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py
@@ -0,0 +1,551 @@
+"""Unit tests for Nova Sonic bidirectional model implementation.
+
+Tests the unified BidirectionalModel interface implementation for Amazon Nova Sonic,
+covering connection lifecycle, event conversion, audio streaming, and tool execution.
+"""
+
+import asyncio
+import base64
+import json
+import uuid
+from typing import Any, Dict
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
+
+import pytest
+import pytest_asyncio
+
+from strands.experimental.bidirectional_streaming.models.novasonic import (
+    NovaSonicBidirectionalModel,
+)
+from strands.types.tools import ToolResult, ToolSpec
+
+
+# Test fixtures
+@pytest.fixture
+def model_id():
+    """Nova Sonic model identifier."""
+    return "amazon.nova-sonic-v1:0"
+
+
+@pytest.fixture
+def region():
+    """AWS region."""
+    return "us-east-1"
+
+
+@pytest.fixture
+def mock_stream():
+    """Mock Nova Sonic bidirectional stream."""
+    stream = AsyncMock()
+    stream.input_stream = AsyncMock()
+    stream.input_stream.send = AsyncMock()
+    stream.input_stream.close = AsyncMock()
+    stream.await_output = AsyncMock()
+    return stream
+
+
+@pytest.fixture
+def mock_client(mock_stream):
+    """Mock Bedrock Runtime client whose bidirectional-stream call returns mock_stream."""
+    client = AsyncMock()
+    client.invoke_model_with_bidirectional_stream = AsyncMock(return_value=mock_stream)
+    return client
+
+
+@pytest_asyncio.fixture
+async def nova_model(model_id, region):
+    """Create Nova Sonic model instance and close it after the test if still active."""
+    model = NovaSonicBidirectionalModel(model_id=model_id, region=region)
+    yield model
+    # Cleanup: close the model only if a test left it active
+    if model._active:
+        await model.close()
+
+
+# Connection lifecycle tests
+@pytest.mark.asyncio
+async def test_model_initialization(model_id, region):
+    """Test model initialization with configuration."""
+    model = NovaSonicBidirectionalModel(model_id=model_id, region=region)
+
+    assert model.model_id == model_id
+    assert model.region == region
+    assert model.stream is None
+    assert not model._active
+    assert model.prompt_name is None
+
+
+@pytest.mark.asyncio
+async def test_connect_establishes_connection(nova_model, mock_client, mock_stream):
+    """Test that connect() establishes bidirectional connection."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client =
mock_client
+
+        await nova_model.connect(system_prompt="Test system prompt")
+
+        assert nova_model._active
+        assert nova_model.stream == mock_stream
+        assert nova_model.prompt_name is not None
+        assert mock_client.invoke_model_with_bidirectional_stream.called
+
+
+@pytest.mark.asyncio
+async def test_connect_sends_initialization_events(nova_model, mock_client, mock_stream):
+    """Test that connect() sends proper initialization sequence."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client = mock_client
+
+        system_prompt = "You are a helpful assistant"
+        tools = [
+            {
+                "name": "get_weather",
+                "description": "Get weather information",
+                "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})}
+            }
+        ]
+
+        await nova_model.connect(system_prompt=system_prompt, tools=tools)
+
+        # Verify at least three initialization events were sent on the input stream
+        assert mock_stream.input_stream.send.call_count >= 3  # connectionStart, promptStart, system prompt
+
+
+@pytest.mark.asyncio
+async def test_close_cleanup(nova_model, mock_client, mock_stream):
+    """Test that close() deactivates the model and closes the input stream."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client = mock_client
+
+        await nova_model.connect()
+        await nova_model.close()
+
+        assert not nova_model._active
+        assert mock_stream.input_stream.close.called
+
+
+# Event conversion tests
+@pytest.mark.asyncio
+async def test_receive_emits_connection_start(nova_model, mock_client, mock_stream):
+    """Test that receive() emits connection start event."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client = mock_client
+
+        # Replacement for asyncio.wait_for: sleep briefly, mark the model inactive, then time out
+        async def mock_wait_for(*args, **kwargs):
+            await asyncio.sleep(0.1)
+            nova_model._active = False
+            raise asyncio.TimeoutError()
+
+        with patch("asyncio.wait_for", side_effect=mock_wait_for):
+            await nova_model.connect()
+
+            events = []
+
async for event in nova_model.receive():
+                events.append(event)
+
+            # At least two events are expected; only the leading connection-start event is inspected
+            assert len(events) >= 2
+            assert "BidirectionalConnectionStart" in events[0]
+            assert events[0]["BidirectionalConnectionStart"]["connectionId"] == nova_model.prompt_name
+
+
+@pytest.mark.asyncio
+async def test_convert_audio_output_event(nova_model):
+    """Test conversion of Nova Sonic base64 audio output to decoded bytes in standard format."""
+    audio_bytes = b"test audio data"
+    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+
+    nova_event = {
+        "audioOutput": {
+            "content": audio_base64
+        }
+    }
+
+    result = nova_model._convert_nova_event(nova_event)
+
+    assert result is not None
+    assert "audioOutput" in result
+    assert result["audioOutput"]["audioData"] == audio_bytes
+    assert result["audioOutput"]["format"] == "pcm"
+    assert result["audioOutput"]["sampleRate"] == 24000
+
+
+@pytest.mark.asyncio
+async def test_convert_text_output_event(nova_model):
+    """Test conversion of Nova Sonic text output to standard format (role is lower-cased)."""
+    nova_event = {
+        "textOutput": {
+            "content": "Hello, world!",
+            "role": "ASSISTANT"
+        }
+    }
+
+    result = nova_model._convert_nova_event(nova_event)
+
+    assert result is not None
+    assert "textOutput" in result
+    assert result["textOutput"]["text"] == "Hello, world!"
+    assert result["textOutput"]["role"] == "assistant"
+
+
+@pytest.mark.asyncio
+async def test_convert_tool_use_event(nova_model):
+    """Test conversion of Nova Sonic tool use (JSON-string content) to standard format."""
+    tool_input = {"location": "Seattle"}
+    nova_event = {
+        "toolUse": {
+            "toolUseId": "tool-123",
+            "toolName": "get_weather",
+            "content": json.dumps(tool_input)
+        }
+    }
+
+    result = nova_model._convert_nova_event(nova_event)
+
+    assert result is not None
+    assert "toolUse" in result
+    assert result["toolUse"]["toolUseId"] == "tool-123"
+    assert result["toolUse"]["name"] == "get_weather"
+    assert result["toolUse"]["input"] == tool_input
+
+
+@pytest.mark.asyncio
+async def test_convert_interruption_event(nova_model):
+    """Test conversion of a Nova Sonic INTERRUPTED stop reason to standard format."""
+    nova_event = {
+        "stopReason": "INTERRUPTED"
+    }
+
+    result = nova_model._convert_nova_event(nova_event)
+
+    assert result is not None
+    assert "interruptionDetected" in result
+    assert result["interruptionDetected"]["reason"] == "user_input"
+
+
+@pytest.mark.asyncio
+async def test_convert_usage_metrics_event(nova_model):
+    """Test conversion of Nova Sonic usage event (token totals and speech tokens) to standard format."""
+    nova_event = {
+        "usageEvent": {
+            "totalTokens": 100,
+            "totalInputTokens": 40,
+            "totalOutputTokens": 60,
+            "details": {
+                "total": {
+                    "output": {
+                        "speechTokens": 30
+                    }
+                }
+            }
+        }
+    }
+
+    result = nova_model._convert_nova_event(nova_event)
+
+    assert result is not None
+    assert "usageMetrics" in result
+    assert result["usageMetrics"]["totalTokens"] == 100
+    assert result["usageMetrics"]["inputTokens"] == 40
+    assert result["usageMetrics"]["outputTokens"] == 60
+    assert result["usageMetrics"]["audioTokens"] == 30
+
+
+@pytest.mark.asyncio
+async def test_convert_content_start_tracks_role(nova_model):
+    """Test that contentStart events track role for subsequent text output."""
+    nova_event = {
+        "contentStart": {
+            "role": "USER"
+        }
+    }
+
+    result =
nova_model._convert_nova_event(nova_event)
+
+    # contentStart yields no converted event (None) but records the role on the model
+    assert result is None
+    assert nova_model._current_role == "USER"
+
+
+# Send method tests
+@pytest.mark.asyncio
+async def test_send_text_content(nova_model, mock_client, mock_stream):
+    """Test sending text content through unified send() method."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client = mock_client
+
+        await nova_model.connect()
+
+        text_event = {
+            "text": "Hello, Nova!",
+            "role": "user"
+        }
+
+        await nova_model.send(text_event)
+
+        # Should send contentStart, textInput, and contentEnd
+        assert mock_stream.input_stream.send.call_count >= 3
+
+
+@pytest.mark.asyncio
+async def test_send_audio_content(nova_model, mock_client, mock_stream):
+    """Test sending audio content through unified send() method."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client = mock_client
+
+        await nova_model.connect()
+
+        audio_event = {
+            "audioData": b"audio data",
+            "format": "pcm",
+            "sampleRate": 16000,
+            "channels": 1
+        }
+
+        await nova_model.send(audio_event)
+
+        # Should mark the audio connection active and send on the input stream
+        assert nova_model.audio_connection_active
+        assert mock_stream.input_stream.send.called
+
+
+@pytest.mark.asyncio
+async def test_send_tool_result(nova_model, mock_client, mock_stream):
+    """Test sending tool result through unified send() method."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client = mock_client
+
+        await nova_model.connect()
+
+        tool_result = {
+            "toolUseId": "tool-123",
+            "status": "success",
+            "content": [{"text": "Weather is sunny"}]
+        }
+
+        await nova_model.send(tool_result)
+
+        # Should send contentStart, toolResult, and contentEnd
+        assert mock_stream.input_stream.send.call_count >= 3
+
+
+@pytest.mark.asyncio
+async def test_send_image_content_not_supported(nova_model, mock_client,
mock_stream, caplog):
+    """Test that image content logs warning (not supported by Nova Sonic)."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client = mock_client
+
+        await nova_model.connect()
+
+        image_event = {
+            "imageData": b"image data",
+            "mimeType": "image/jpeg"
+        }
+
+        await nova_model.send(image_event)
+
+        # Some captured log record must mention "not supported" (case-insensitive)
+        assert any("not supported" in record.message.lower() for record in caplog.records)
+
+
+# Audio streaming tests
+@pytest.mark.asyncio
+async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream):
+    """Test audio connection start and end lifecycle."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client = mock_client
+
+        await nova_model.connect()
+
+        # Start audio connection
+        await nova_model._start_audio_connection()
+        assert nova_model.audio_connection_active
+
+        # End audio connection
+        await nova_model._end_audio_input()
+        assert not nova_model.audio_connection_active
+
+
+@pytest.mark.asyncio
+async def test_silence_detection_ends_audio(nova_model, mock_client, mock_stream):
+    """Test that silence detection automatically ends audio input."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client = mock_client
+        nova_model.silence_threshold = 0.1  # Short threshold for testing
+
+        await nova_model.connect()
+
+        # Send audio to start connection
+        audio_event = {
+            "audioData": b"audio data",
+            "format": "pcm",
+            "sampleRate": 16000,
+            "channels": 1
+        }
+
+        await nova_model.send(audio_event)
+        assert nova_model.audio_connection_active
+
+        # Sleep past the 0.1 s silence threshold set above
+        await asyncio.sleep(0.2)
+
+        # Audio connection should be ended
+        assert not nova_model.audio_connection_active
+
+
+# Tool configuration tests
+@pytest.mark.asyncio
+async def test_build_tool_configuration(nova_model):
+    """Test building tool configuration from tool specs."""
+    tools = [
+        {
+            "name": "get_weather",
+            "description": "Get weather information",
+            "inputSchema": {
+                "json": json.dumps({
+                    "type": "object",
+                    "properties": {
+                        "location": {"type": "string"}
+                    }
+                })
+            }
+        }
+    ]
+
+    tool_config = nova_model._build_tool_configuration(tools)
+
+    assert len(tool_config) == 1
+    assert tool_config[0]["toolSpec"]["name"] == "get_weather"
+    assert tool_config[0]["toolSpec"]["description"] == "Get weather information"
+    assert "inputSchema" in tool_config[0]["toolSpec"]
+
+
+# Event template tests
+@pytest.mark.asyncio
+async def test_get_connection_start_event(nova_model):
+    """Test that the connection start event is JSON with a sessionStart inference configuration."""
+    event_json = nova_model._get_connection_start_event()
+    event = json.loads(event_json)
+
+    assert "event" in event
+    assert "sessionStart" in event["event"]
+    assert "inferenceConfiguration" in event["event"]["sessionStart"]
+
+
+@pytest.mark.asyncio
+async def test_get_prompt_start_event(nova_model):
+    """Test prompt start event generation with an empty tool list."""
+    nova_model.prompt_name = "test-prompt"
+
+    event_json = nova_model._get_prompt_start_event([])
+    event = json.loads(event_json)
+
+    assert "event" in event
+    assert "promptStart" in event["event"]
+    assert event["event"]["promptStart"]["promptName"] == "test-prompt"
+
+
+@pytest.mark.asyncio
+async def test_get_text_input_event(nova_model):
+    """Test that the text input event is JSON carrying the given text as content."""
+    nova_model.prompt_name = "test-prompt"
+    content_name = "test-content"
+
+    event_json = nova_model._get_text_input_event(content_name, "Hello")
+    event = json.loads(event_json)
+
+    assert "event" in event
+    assert "textInput" in event["event"]
+    assert event["event"]["textInput"]["content"] == "Hello"
+
+
+@pytest.mark.asyncio
+async def test_get_tool_result_event(nova_model):
+    """Test tool result event generation."""
+    nova_model.prompt_name = "test-prompt"
+    content_name = "test-content"
+    result = {"result": "Success"}
+
+    event_json = nova_model._get_tool_result_event(content_name,
result)
+    event = json.loads(event_json)
+
+    assert "event" in event
+    assert "toolResult" in event["event"]
+    assert json.loads(event["event"]["toolResult"]["content"]) == result
+
+
+# Error handling tests
+@pytest.mark.asyncio
+async def test_send_when_inactive(nova_model):
+    """Test that send() handles inactive connection gracefully."""
+    text_event = {
+        "text": "Hello",
+        "role": "user"
+    }
+
+    # Should not raise error when inactive (connect() was never called)
+    await nova_model.send(text_event)
+
+
+@pytest.mark.asyncio
+async def test_close_when_already_closed(nova_model):
+    """Test that close() handles already closed connection."""
+    # Should not raise error when already inactive
+    await nova_model.close()
+    await nova_model.close()  # Second call should be safe
+
+
+@pytest.mark.asyncio
+async def test_response_processor_handles_errors(nova_model, mock_client, mock_stream):
+    """Test that response processor handles errors gracefully."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client = mock_client
+
+        # Make await_output raise so the response processor hits an error
+        async def mock_error(*args, **kwargs):
+            raise Exception("Test error")
+
+        mock_stream.await_output.side_effect = mock_error
+
+        await nova_model.connect()
+
+        # Wait a bit for response processor to handle error
+        await asyncio.sleep(0.1)
+
+        # Should still be able to close cleanly
+        await nova_model.close()
+
+
+# Integration-style tests
+@pytest.mark.asyncio
+async def test_full_conversation_flow(nova_model, mock_client, mock_stream):
+    """Test a complete conversation flow with text, audio, and a tool result."""
+    with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock):
+        nova_model._client = mock_client
+
+        # Connect
+        await nova_model.connect(system_prompt="You are helpful")
+
+        # Send text
+        await nova_model.send({"text": "Hello", "role": "user"})
+
+        # Send audio
+        await nova_model.send({
+            "audioData": b"audio",
+            "format": "pcm",
+            "sampleRate": 16000,
+            "channels": 1
+        })
+
+        # Send tool
result
+        await nova_model.send({
+            "toolUseId": "tool-1",
+            "status": "success",
+            "content": [{"text": "Result"}]
+        })
+
+        # Close
+        await nova_model.close()
+
+        # Verify the model is inactive after close()
+        assert not nova_model._active
diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py
new file mode 100644
index 0000000000..be69929cdd
--- /dev/null
+++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py
@@ -0,0 +1,625 @@
+"""Unit tests for OpenAI Realtime bidirectional streaming model.
+
+Tests the unified OpenAIRealtimeBidirectionalModel interface including:
+- Model initialization and configuration
+- Connection establishment with WebSocket
+- Unified send() method with different content types
+- Event receiving and conversion
+- Connection lifecycle management
+- Background task management
+"""
+
+import asyncio
+import base64
+import json
+import unittest.mock
+
+import pytest
+
+from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel
+from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import (
+    AudioInputEvent,
+    ImageInputEvent,
+    TextInputEvent,
+)
+from strands.types.tools import ToolResult
+
+
+@pytest.fixture
+def mock_websocket():
+    """Mock WebSocket connection with async send() and close()."""
+    mock_ws = unittest.mock.AsyncMock()
+    mock_ws.send = unittest.mock.AsyncMock()
+    mock_ws.close = unittest.mock.AsyncMock()
+    return mock_ws
+
+
+@pytest.fixture
+def mock_websockets_connect(mock_websocket):
+    """Patch websockets.connect to return mock_websocket; yields (mock_connect, mock_websocket)."""
+    async def async_connect(*args, **kwargs):
+        return mock_websocket
+
+    with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.websockets.connect") as mock_connect:
+        mock_connect.side_effect = async_connect
+        yield mock_connect, mock_websocket
+
+
+@pytest.fixture
+def model_name():
+
return "gpt-realtime"
+
+
+@pytest.fixture
+def api_key():
+    return "test-api-key"
+
+
+@pytest.fixture
+def model(api_key, model_name):
+    """Create an OpenAIRealtimeBidirectionalModel instance."""
+    return OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key)
+
+
+@pytest.fixture
+def tool_spec():
+    return {
+        "description": "Calculate mathematical expressions",
+        "name": "calculator",
+        "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}},
+    }
+
+
+@pytest.fixture
+def system_prompt():
+    return "You are a helpful assistant"
+
+
+@pytest.fixture
+def messages():
+    return [{"role": "user", "content": [{"text": "Hello"}]}]
+
+
+# Initialization Tests
+
+
+def test_init_default_config():
+    """Test model initialization with default configuration."""
+    model = OpenAIRealtimeBidirectionalModel(api_key="test-key")
+
+    assert model.model == "gpt-realtime"
+    assert model.api_key == "test-key"
+    assert model._active is False
+    assert model.websocket is None
+
+
+def test_init_with_api_key(api_key, model_name):
+    """Test model initialization with API key."""
+    model = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key)
+
+    assert model.model == model_name
+    assert model.api_key == api_key
+
+
+def test_init_with_custom_config(model_name, api_key):
+    """Test that extra keyword arguments are stored as the model's config."""
+    custom_config = {"organization": "org-123", "project": "proj-456"}
+    model = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key, **custom_config)
+
+    assert model.config == custom_config
+
+
+def test_init_without_api_key_raises():
+    """Test that initialization without an API key (argument or environment) raises ValueError."""
+    with unittest.mock.patch.dict("os.environ", {}, clear=True):
+        with pytest.raises(ValueError, match="OpenAI API key is required"):
+            OpenAIRealtimeBidirectionalModel()
+
+
+def test_init_with_env_api_key():
+    """Test initialization with API key from environment."""
+    with
unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}):
+        model = OpenAIRealtimeBidirectionalModel()
+        assert model.api_key == "env-key"
+
+
+# Connection Tests
+
+
+@pytest.mark.asyncio
+async def test_connect_basic(mock_websockets_connect, model):
+    """Test basic connection establishment."""
+    mock_connect, mock_ws = mock_websockets_connect
+
+    await model.connect()
+
+    assert model._active is True
+    assert model.session_id is not None
+    assert model.websocket == mock_ws
+    assert model._event_queue is not None
+    mock_connect.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_connect_with_system_prompt(mock_websockets_connect, model, system_prompt):
+    """Test connection with system prompt."""
+    _, mock_ws = mock_websockets_connect
+
+    await model.connect(system_prompt=system_prompt)
+
+    # Find the first session.update message sent over the websocket
+    calls = mock_ws.send.call_args_list
+    session_update_call = None
+    for call in calls:
+        message = json.loads(call[0][0])
+        if message.get("type") == "session.update":
+            session_update_call = message
+            break
+
+    assert session_update_call is not None
+    assert session_update_call["session"]["instructions"] == system_prompt
+
+
+@pytest.mark.asyncio
+async def test_connect_with_tools(mock_websockets_connect, model, tool_spec):
+    """Test connection with tools."""
+    _, mock_ws = mock_websockets_connect
+
+    await model.connect(tools=[tool_spec])
+
+    # Verify the first session.update message carries a tools entry
+    calls = mock_ws.send.call_args_list
+    session_update_call = None
+    for call in calls:
+        message = json.loads(call[0][0])
+        if message.get("type") == "session.update":
+            session_update_call = message
+            break
+
+    assert session_update_call is not None
+    assert "tools" in session_update_call["session"]
+
+
+@pytest.mark.asyncio
+async def test_connect_with_messages(mock_websockets_connect, model, messages):
+    """Test connection with message history."""
+    _, mock_ws = mock_websockets_connect
+
+    await
model.connect(messages=messages)
+
+    # Verify conversation items were created (each sent frame is decoded only once)
+    calls = mock_ws.send.call_args_list
+    sent_messages = [json.loads(call[0][0]) for call in calls]
+    item_create_calls = [
+        message for message in sent_messages if message.get("type") == "conversation.item.create"
+    ]
+
+    assert len(item_create_calls) > 0
+
+
+@pytest.mark.asyncio
+async def test_connect_error_handling(mock_websockets_connect, model):
+    """Test that an exception raised by websockets.connect propagates from connect()."""
+    mock_connect, _ = mock_websockets_connect
+    mock_connect.side_effect = Exception("Connection failed")
+
+    with pytest.raises(Exception, match="Connection failed"):
+        await model.connect()
+
+
+@pytest.mark.asyncio
+async def test_connect_with_organization_header(mock_websockets_connect, api_key):
+    """Test connection includes organization header."""
+    mock_connect, _ = mock_websockets_connect
+
+    model = OpenAIRealtimeBidirectionalModel(
+        api_key=api_key,
+        organization="org-123"
+    )
+    await model.connect()
+
+    # Verify exactly one OpenAI-Organization header was passed via additional_headers
+    call_kwargs = mock_connect.call_args.kwargs
+    headers = call_kwargs.get("additional_headers", [])
+    org_header = [h for h in headers if h[0] == "OpenAI-Organization"]
+    assert len(org_header) == 1
+    assert org_header[0][1] == "org-123"
+
+
+# Send Method Tests
+
+
+@pytest.mark.asyncio
+async def test_send_text_input(mock_websockets_connect, model):
+    """Test sending text input through unified send() method."""
+    _, mock_ws = mock_websockets_connect
+    await model.connect()
+
+    text_input: TextInputEvent = {"text": "Hello", "role": "user"}
+    await model.send(text_input)
+
+    # Verify conversation.item.create and response.create were sent
+    calls = mock_ws.send.call_args_list
+    messages = [json.loads(call[0][0]) for call in calls]
+
+    item_create = [m for m in messages if m.get("type") == "conversation.item.create"]
+    response_create = [m for m in messages if m.get("type") == "response.create"]
+
+    assert len(item_create) > 0
+    assert len(response_create) > 0
+
+
+@pytest.mark.asyncio
+async def
test_send_audio_input(mock_websockets_connect, model): + """Test sending audio input through unified send() method.""" + _, mock_ws = mock_websockets_connect + await model.connect() + + audio_input: AudioInputEvent = { + "audioData": b"audio_bytes", + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + } + await model.send(audio_input) + + # Verify input_audio_buffer.append was sent + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + + audio_append = [m for m in messages if m.get("type") == "input_audio_buffer.append"] + assert len(audio_append) > 0 + + # Verify audio was base64 encoded + assert "audio" in audio_append[0] + decoded = base64.b64decode(audio_append[0]["audio"]) + assert decoded == b"audio_bytes" + + +@pytest.mark.asyncio +async def test_send_image_input(mock_websockets_connect, model): + """Test sending image input logs warning (not supported).""" + _, mock_ws = mock_websockets_connect + await model.connect() + + image_input: ImageInputEvent = { + "imageData": b"image_bytes", + "mimeType": "image/jpeg", + "encoding": "raw", + } + + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + await model.send(image_input) + mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") + + +@pytest.mark.asyncio +async def test_send_tool_result(mock_websockets_connect, model): + """Test sending tool result through unified send() method.""" + _, mock_ws = mock_websockets_connect + await model.connect() + + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Result: 42"}], + } + await model.send(tool_result) + + # Verify function_call_output was created + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + assert len(item_create) > 0 + + # Verify it's 
a function_call_output + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "tool-123" + + +@pytest.mark.asyncio +async def test_send_when_inactive(mock_websockets_connect, model): + """Test that send() does nothing when connection is inactive.""" + _, mock_ws = mock_websockets_connect + + # Don't connect, so _active is False + text_input: TextInputEvent = {"text": "Hello", "role": "user"} + await model.send(text_input) + + # Verify nothing was sent + mock_ws.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_send_unknown_content_type(mock_websockets_connect, model): + """Test sending unknown content type logs warning.""" + _, _ = mock_websockets_connect + await model.connect() + + unknown_content = {"unknown_field": "value"} + + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + await model.send(unknown_content) + # Should log warning about unknown content + assert mock_logger.warning.called + + +# Receive Method Tests + + +@pytest.mark.asyncio +async def test_receive_connection_start_event(mock_websockets_connect, model): + """Test that receive() emits connection start event.""" + _, _ = mock_websockets_connect + + await model.connect() + + # Get first event + receive_gen = model.receive() + first_event = await anext(receive_gen) + + # First event should be connection start + assert "BidirectionalConnectionStart" in first_event + assert first_event["BidirectionalConnectionStart"]["connectionId"] == model.session_id + + # Close to stop the loop + await model.close() + + +@pytest.mark.asyncio +async def test_receive_connection_end_event(mock_websockets_connect, model): + """Test that receive() emits connection end event.""" + _, _ = mock_websockets_connect + + await model.connect() + + # Collect events until connection ends + events = [] + async for event in model.receive(): + events.append(event) + # Close after first 
event to trigger connection end + if len(events) == 1: + await model.close() + + # Last event should be connection end + assert "BidirectionalConnectionEnd" in events[-1] + + +@pytest.mark.asyncio +async def test_receive_audio_output(mock_websockets_connect, model): + """Test receiving audio output from model.""" + _, _ = mock_websockets_connect + await model.connect() + + # Create mock OpenAI event + openai_event = { + "type": "response.output_audio.delta", + "delta": base64.b64encode(b"audio_data").decode() + } + + # Test conversion directly + converted_event = model._convert_openai_event(openai_event) + + assert "audioOutput" in converted_event + assert converted_event["audioOutput"]["audioData"] == b"audio_data" + assert converted_event["audioOutput"]["format"] == "pcm" + + +@pytest.mark.asyncio +async def test_receive_text_output(mock_websockets_connect, model): + """Test receiving text output from model.""" + _, _ = mock_websockets_connect + await model.connect() + + # Create mock OpenAI event + openai_event = { + "type": "response.output_text.delta", + "delta": "Hello from OpenAI" + } + + # Test conversion directly + converted_event = model._convert_openai_event(openai_event) + + assert "textOutput" in converted_event + assert converted_event["textOutput"]["text"] == "Hello from OpenAI" + assert converted_event["textOutput"]["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_receive_function_call(mock_websockets_connect, model): + """Test receiving function call from model.""" + _, _ = mock_websockets_connect + await model.connect() + + # Simulate function call sequence + # First: output_item.added with function name + item_added = { + "type": "response.output_item.added", + "item": { + "type": "function_call", + "call_id": "call-123", + "name": "calculator" + } + } + model._convert_openai_event(item_added) + + # Second: function_call_arguments.delta + args_delta = { + "type": "response.function_call_arguments.delta", + "call_id": "call-123", + 
"delta": '{"expression": "2+2"}' + } + model._convert_openai_event(args_delta) + + # Third: function_call_arguments.done + args_done = { + "type": "response.function_call_arguments.done", + "call_id": "call-123" + } + converted_event = model._convert_openai_event(args_done) + + assert "toolUse" in converted_event + assert converted_event["toolUse"]["toolUseId"] == "call-123" + assert converted_event["toolUse"]["name"] == "calculator" + assert converted_event["toolUse"]["input"]["expression"] == "2+2" + + +@pytest.mark.asyncio +async def test_receive_voice_activity(mock_websockets_connect, model): + """Test receiving voice activity events.""" + _, _ = mock_websockets_connect + await model.connect() + + # Test speech started + speech_started = { + "type": "input_audio_buffer.speech_started" + } + converted_event = model._convert_openai_event(speech_started) + + assert "voiceActivity" in converted_event + assert converted_event["voiceActivity"]["activityType"] == "speech_started" + + +# Close Method Tests + + +@pytest.mark.asyncio +async def test_close_connection(mock_websockets_connect, model): + """Test closing connection.""" + _, mock_ws = mock_websockets_connect + + await model.connect() + await model.close() + + assert model._active is False + mock_ws.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_close_when_not_connected(mock_websockets_connect, model): + """Test closing when not connected does nothing.""" + _, mock_ws = mock_websockets_connect + + # Don't connect + await model.close() + + # Should not raise, and close should not be called + mock_ws.close.assert_not_called() + + +@pytest.mark.asyncio +async def test_close_error_handling(mock_websockets_connect, model): + """Test close error handling.""" + _, mock_ws = mock_websockets_connect + mock_ws.close.side_effect = Exception("Close failed") + + await model.connect() + + # Should not raise, just log warning + await model.close() + assert model._active is False + + +@pytest.mark.asyncio 
+async def test_close_cancels_response_task(mock_websockets_connect, model): + """Test that close cancels the background response task.""" + _, _ = mock_websockets_connect + + await model.connect() + + # Verify response task is running + assert model._response_task is not None + assert not model._response_task.done() + + await model.close() + + # Task should be cancelled + assert model._response_task.cancelled() or model._response_task.done() + + +# Helper Method Tests + + +def test_build_session_config_basic(model): + """Test building basic session config.""" + config = model._build_session_config(None, None) + + assert isinstance(config, dict) + assert "instructions" in config + assert "audio" in config + + +def test_build_session_config_with_system_prompt(model, system_prompt): + """Test building config with system prompt.""" + config = model._build_session_config(system_prompt, None) + + assert config["instructions"] == system_prompt + + +def test_build_session_config_with_tools(model, tool_spec): + """Test building config with tools.""" + config = model._build_session_config(None, [tool_spec]) + + assert "tools" in config + assert len(config["tools"]) > 0 + + +def test_convert_tools_to_openai_format(model, tool_spec): + """Test tool conversion to OpenAI format.""" + openai_tools = model._convert_tools_to_openai_format([tool_spec]) + + assert len(openai_tools) == 1 + assert openai_tools[0]["type"] == "function" + assert openai_tools[0]["name"] == "calculator" + assert openai_tools[0]["description"] == "Calculate mathematical expressions" + + +def test_convert_tools_empty_list(model): + """Test converting empty tool list.""" + openai_tools = model._convert_tools_to_openai_format([]) + + assert openai_tools == [] + + +@pytest.mark.asyncio +async def test_send_event(mock_websockets_connect, model): + """Test sending event to WebSocket.""" + _, mock_ws = mock_websockets_connect + await model.connect() + + test_event = {"type": "test.event", "data": "test"} + await 
model._send_event(test_event) + + # Verify event was sent as JSON + calls = mock_ws.send.call_args_list + last_call = calls[-1] + sent_message = json.loads(last_call[0][0]) + + assert sent_message == test_event + + +def test_require_active(model): + """Test _require_active method.""" + assert model._require_active() is False + + model._active = True + assert model._require_active() is True + + +def test_create_text_event(model): + """Test creating text event.""" + event = model._create_text_event("Hello", "user") + + assert "textOutput" in event + assert event["textOutput"]["text"] == "Hello" + assert event["textOutput"]["role"] == "user" + + +def test_create_voice_activity_event(model): + """Test creating voice activity event.""" + event = model._create_voice_activity_event("speech_started") + + assert "voiceActivity" in event + assert event["voiceActivity"]["activityType"] == "speech_started" From 261e25fad104455447f8e4715820c16c8882456d Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 14:05:38 +0100 Subject: [PATCH 02/11] fix: update bidirectional model docstrings --- .../models/bidirectional_model.py | 65 ++++++++++++------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 3af05e1138..75a4ab5f03 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -1,13 +1,15 @@ -"""Unified bidirectional streaming interface. +"""Bidirectional streaming model interface. -Single layer combining model and session abstractions for simpler implementation. -Providers implement this directly without separate model/session classes. +Defines the abstract interface for models that support real-time bidirectional +communication with persistent connections. 
Unlike traditional request-response +models, bidirectional models maintain an open connection for streaming audio, +text, and tool interactions. Features: -- Unified model interface (no separate session class) -- Real-time bidirectional communication +- Persistent connection management with connect/close lifecycle +- Real-time bidirectional communication (send and receive simultaneously) - Provider-agnostic event normalization -- Tool execution integration +- Support for audio, text, image, and tool result streaming """ import abc @@ -27,10 +29,11 @@ class BidirectionalModel(abc.ABC): - """Unified interface for bidirectional streaming models. + """Abstract base class for bidirectional streaming models. - Combines model configuration and session communication in a single abstraction. - Providers implement this directly without separate model/session classes. + This interface defines the contract for models that support persistent streaming + connections with real-time audio and text communication. Implementations handle + provider-specific protocols while exposing a standardized event-based API. """ @abc.abstractmethod @@ -41,48 +44,60 @@ async def connect( messages: Messages | None = None, **kwargs, ) -> None: - """Establish bidirectional connection with the model. + """Establish a persistent streaming connection with the model. - Initializes the connection state and prepares for real-time communication. - This replaces the old create_bidirectional_connection pattern. + Opens a bidirectional connection that remains active for real-time communication. + The connection supports concurrent sending and receiving of events until explicitly + closed. Must be called before any send() or receive() operations. Args: - system_prompt: System instructions for the model. - tools: List of tools available to the model. - messages: Conversation history to initialize with. + system_prompt: System instructions to configure model behavior. 
+ tools: Tool specifications that the model can invoke during the conversation. + messages: Initial conversation history to provide context. **kwargs: Provider-specific configuration options. """ raise NotImplementedError @abc.abstractmethod async def close(self) -> None: - """Close connection and cleanup resources. + """Close the streaming connection and release resources. - Terminates the active connection and releases any held resources. + Terminates the active bidirectional connection and cleans up any associated + resources such as network connections, buffers, or background tasks. After + calling close(), the model instance cannot be used until connect() is called again. """ raise NotImplementedError @abc.abstractmethod async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive events from the model in standardized format. + """Receive streaming events from the model. - Yields provider-agnostic events that can be processed uniformly - by the event loop. Converts provider-specific events to common format. + Continuously yields events from the model as they arrive over the connection. + Events are normalized to a provider-agnostic format for uniform processing. + This method should be called in a loop or async task to process model responses. + + The stream continues until the connection is closed or an error occurs. Yields: - BidirectionalStreamEvent: Standardized event dictionaries. + BidirectionalStreamEvent: Standardized event dictionaries containing + audio output, text responses, tool calls, or control signals. """ raise NotImplementedError @abc.abstractmethod async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: - """Send structured content to the model. + """Send content to the model over the active connection. - Unified method for sending all types of content. Implementations should - dispatch to appropriate internal handlers based on content type. 
+ Transmits user input or tool results to the model during an active streaming + session. Supports multiple content types including text, audio, images, and + tool execution results. Can be called multiple times during a conversation. Args: - content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). + content: The content to send. Must be one of: + - TextInputEvent: Text message from the user + - ImageInputEvent: Image data for visual understanding + - AudioInputEvent: Audio data for speech input + - ToolResult: Result from a tool execution Example: await model.send(TextInputEvent(text="Hello", role="user")) From a9d8c88376ffe2cf8ab43217fc71a0bdafb96d80 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 14:23:37 +0100 Subject: [PATCH 03/11] fix: remove base session references --- .../bidirectional_streaming/__init__.py | 7 ++-- .../bidirectional_streaming/agent/agent.py | 8 ++--- .../event_loop/bidirectional_event_loop.py | 35 +++++++++---------- .../models/__init__.py | 3 +- .../models/bidirectional_model.py | 4 --- 5 files changed, 24 insertions(+), 33 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index e31bc670e1..d855ba0388 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -3,8 +3,8 @@ # Main components - Primary user interface from .agent.agent import BidirectionalAgent -# Unified model interface (for custom implementations) -from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession +# Model interface (for custom implementations) +from .models.bidirectional_model import BidirectionalModel # Model providers - What users need to create models from .models.gemini_live import GeminiLiveBidirectionalModel @@ -44,7 +44,6 @@ "VoiceActivityEvent", "UsageMetricsEvent", - # Unified model interface + 
# Model interface "BidirectionalModel", - "BidirectionalModelSession", # Backwards compatibility alias ] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 62528d4722..c9d7292b88 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -379,15 +379,15 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> Non self.messages.append({"role": "user", "content": input_data}) logger.debug("Text sent: %d characters", len(input_data)) - # Create TextInputEvent for unified send() + # Create TextInputEvent for send() text_event = {"text": input_data, "role": "user"} - await self._session.model_session.send(text_event) + await self._session.model.send(text_event) elif isinstance(input_data, dict) and "audioData" in input_data: # Handle audio input - already in AudioInputEvent format - await self._session.model_session.send(input_data) + await self._session.model.send(input_data) elif isinstance(input_data, dict) and "imageData" in input_data: # Handle image input - already in ImageInputEvent format - await self._session.model_session.send(input_data) + await self._session.model.send(input_data) else: raise ValueError( "Input must be either a string (text), AudioInputEvent " diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 521ebc0dd5..d1d6e90b32 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -37,14 +37,14 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. 
""" - def __init__(self, model_session: BidirectionalModel, agent: "BidirectionalAgent") -> None: - """Initialize session with model and agent reference. + def __init__(self, model: BidirectionalModel, agent: "BidirectionalAgent") -> None: + """Initialize connection with model and agent reference. Args: - model_session: Bidirectional model instance (unified interface). + model: Bidirectional model instance. agent: BidirectionalAgent instance for tool registry access. """ - self.model_session = model_session + self.model = model self.agent = agent self.active = True @@ -78,16 +78,13 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec """ logger.debug("Starting bidirectional session - initializing model connection") - # Connect to model using unified interface + # Connect to model await agent.model.connect( system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages ) - - # Use the model directly (unified interface - no separate session) - model_session = agent.model - # Create session wrapper for background processing - session = BidirectionalConnection(model_session=model_session, agent=agent) + # Create connection wrapper for background processing + session = BidirectionalConnection(model=agent.model, agent=agent) # Start concurrent background processors IMMEDIATELY after session creation # This is critical - Nova Sonic needs response processing during initialization @@ -138,9 +135,9 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non if all_tasks: await asyncio.gather(*all_tasks, return_exceptions=True) - # Close model session - await session.model_session.close() - logger.debug("Session closed") + # Close model connection + await session.model.close() + logger.debug("Connection closed") async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: @@ -256,11 +253,11 @@ async def _process_model_events(session: 
BidirectionalConnection) -> None: events to standardized formats, and manages interruption detection. Args: - session: BidirectionalConnection containing model session. + session: BidirectionalConnection containing model. """ logger.debug("Model events processor started") try: - async for provider_event in session.model_session.receive(): + async for provider_event in session.model.receive(): if not session.active: break @@ -437,8 +434,8 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_result = tool_event.tool_result tool_use_id = tool_result.get("toolUseId") - # Send result through unified send() method - await session.model_session.send(tool_result) + # Send result through send() method + await session.model.send(tool_result) logger.debug("Tool result sent: %s", tool_use_id) # Handle streaming events if needed later @@ -474,10 +471,10 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: "content": [{"text": f"Error: {str(e)}"}] } try: - await session.model_session.send(error_result) + await session.model.send(error_result) logger.debug("Error result sent: %s", tool_id) except Exception: logger.error("Failed to send error result: %s", tool_id) - pass # Session might be closed + pass # Connection might be closed diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index e2745310c5..12fe6c2715 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,13 +1,12 @@ """Bidirectional model interfaces and implementations.""" -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .bidirectional_model import BidirectionalModel from .gemini_live import GeminiLiveBidirectionalModel from .novasonic import NovaSonicBidirectionalModel from .openai import 
OpenAIRealtimeBidirectionalModel __all__ = [ "BidirectionalModel", - "BidirectionalModelSession", # Backwards compatibility alias "GeminiLiveBidirectionalModel", "NovaSonicBidirectionalModel", "OpenAIRealtimeBidirectionalModel", diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 75a4ab5f03..5b7091dcdc 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -105,7 +105,3 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE await model.send(ToolResult(toolUseId="123", status="success", ...)) """ raise NotImplementedError - - -# Backwards compatibility alias - will be removed in future version -BidirectionalModelSession = BidirectionalModel From 17686d489679f4a3a61b7819841195bba5b02e6d Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 14:38:20 +0100 Subject: [PATCH 04/11] feat: throw exceptions when connect is called on already active connection --- .../bidirectional_streaming/models/gemini_live.py | 11 +++++++---- .../bidirectional_streaming/models/novasonic.py | 3 +++ .../bidirectional_streaming/models/openai.py | 3 +++ .../models/test_gemini_live.py | 13 +++++++++++++ .../models/test_novasonic.py | 14 ++++++++++++++ .../models/test_openai_realtime.py | 13 +++++++++++++ 6 files changed, 53 insertions(+), 4 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 578de5a2b6..dabd1174ba 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -1,11 +1,11 @@ """Gemini Live API bidirectional model provider using official Google GenAI SDK. 
-Implements the unified BidirectionalModel interface for Google's Gemini Live API using the +Implements the BidirectionalModel interface for Google's Gemini Live API using the official Google GenAI SDK for simplified and robust WebSocket communication. Key improvements over custom WebSocket implementation: - Uses official google-genai SDK with native Live API support -- Unified model interface (no separate session class) +- Simplified session management with client.aio.live.connect() - Built-in tool integration and event handling - Automatic WebSocket connection management and error handling - Native support for audio/text streaming and interruption @@ -45,7 +45,7 @@ class GeminiLiveBidirectionalModel(BidirectionalModel): - """Unified Gemini Live API implementation using official Google GenAI SDK. + """Gemini Live API implementation using official Google GenAI SDK. Combines model configuration and connection state in a single class. Provides a clean interface to Gemini Live API using the official SDK, @@ -101,6 +101,9 @@ async def connect( messages: Conversation history to initialize with. **kwargs: Additional configuration options. """ + if self._active: + raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + try: # Initialize connection state self.session_id = str(uuid.uuid4()) @@ -277,7 +280,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic return None async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: - """Unified send method for all content types. + """Unified send method for all content types. Sends the given inputs to Google Live API Dispatches to appropriate internal handler based on content type. 
diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 62b53a127b..ee1bcb5737 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -138,6 +138,9 @@ async def connect( messages: Conversation history to initialize with. **kwargs: Additional configuration options. """ + if self._active: + raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + logger.debug("Nova connection create - starting") try: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 0208ee162d..b62d4fa025 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -117,6 +117,9 @@ async def connect( messages: Conversation history to initialize with. **kwargs: Additional configuration options. """ + if self._active: + raise RuntimeError("Connection already active. 
Close the existing connection before creating a new one.") + logger.info("Creating OpenAI Realtime connection...") try: diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index a5baaa522b..de8fcfd56f 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -185,6 +185,19 @@ async def test_connect_error_handling(mock_genai_client, model): await model.connect() +@pytest.mark.asyncio +async def test_connect_when_already_active(mock_genai_client, model): + """Test that connect() raises exception when already active.""" + mock_client, _, _ = mock_genai_client + + # First connection + await model.connect() + + # Second connection attempt should raise + with pytest.raises(RuntimeError, match="Connection already active"): + await model.connect() + + # Send Method Tests diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 451a98aa2a..59c762b3e2 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -110,6 +110,20 @@ async def test_connect_sends_initialization_events(nova_model, mock_client, mock assert mock_stream.input_stream.send.call_count >= 3 # connectionStart, promptStart, system prompt +@pytest.mark.asyncio +async def test_connect_when_already_active(nova_model, mock_client, mock_stream): + """Test that connect() raises exception when already active.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + + # First connection + await nova_model.connect() + + # Second connection attempt should raise + with pytest.raises(RuntimeError, 
match="Connection already active"): + await nova_model.connect() + + @pytest.mark.asyncio async def test_close_cleanup(nova_model, mock_client, mock_stream): """Test that close() properly cleans up resources.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index be69929cdd..6183765ae1 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -207,6 +207,19 @@ async def test_connect_error_handling(mock_websockets_connect, model): await model.connect() +@pytest.mark.asyncio +async def test_connect_when_already_active(mock_websockets_connect, model): + """Test that connect() raises exception when already active.""" + mock_connect, _ = mock_websockets_connect + + # First connection + await model.connect() + + # Second connection attempt should raise + with pytest.raises(RuntimeError, match="Connection already active"): + await model.connect() + + @pytest.mark.asyncio async def test_connect_with_organization_header(mock_websockets_connect, api_key): """Test connection includes organization header.""" From c2f88f75404a5807b308ff6fa31d1ed0b85c6d8f Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 14:50:42 +0100 Subject: [PATCH 05/11] feat: Add explicit init params for gemini and openai to free kwargs --- .../models/gemini_live.py | 18 +++++++------ .../models/novasonic.py | 10 ++++--- .../bidirectional_streaming/models/openai.py | 26 ++++++++++++------- .../models/test_gemini_live.py | 6 ++--- .../models/test_openai_realtime.py | 11 +++++--- 5 files changed, 45 insertions(+), 26 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index dabd1174ba..639328c64b 100644 --- 
a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -56,19 +56,21 @@ def __init__( self, model_id: str = "models/gemini-2.0-flash-live-preview-04-09", api_key: Optional[str] = None, - **config + live_config: Optional[Dict[str, Any]] = None, + **kwargs ): """Initialize Gemini Live API bidirectional model. Args: model_id: Gemini Live model identifier. api_key: Google AI API key for authentication. - **config: Additional configuration. + live_config: Gemini Live API configuration parameters (e.g., response_modalities, speech_config). + **kwargs: Reserved for future parameters. """ # Model configuration self.model_id = model_id self.api_key = api_key - self.config = config + self.live_config = live_config or {} # Create Gemini client with proper API version client_kwargs = {} @@ -423,15 +425,15 @@ def _build_live_config( ) -> Dict[str, Any]: """Build LiveConnectConfig for the official SDK. - Simply passes through all config parameters from params, allowing users + Simply passes through all config parameters from live_config, allowing users to configure any Gemini Live API parameter directly. 
""" - # Start with user config from params + # Start with user-provided live_config config_dict = {} - if "params" in self.config: - config_dict.update(self.config["params"]) + if self.live_config: + config_dict.update(self.live_config) - # Override with any kwargs + # Override with any kwargs from connect() config_dict.update(kwargs) # Add system instruction if provided diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index ee1bcb5737..5436b5ae71 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -88,18 +88,22 @@ class NovaSonicBidirectionalModel(BidirectionalModel): tool execution patterns while providing the standard BidirectionalModel interface. """ - def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config: any) -> None: + def __init__( + self, + model_id: str = "amazon.nova-sonic-v1:0", + region: str = "us-east-1", + **kwargs + ) -> None: """Initialize Nova Sonic bidirectional model. Args: model_id: Nova Sonic model identifier. region: AWS region. - **config: Additional configuration. + **kwargs: Reserved for future parameters. 
""" # Model configuration self.model_id = model_id self.region = region - self.config = config self._client = None # Connection state (initialized in connect()) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index b62d4fa025..8322eef4bc 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -71,19 +71,27 @@ def __init__( self, model: str = DEFAULT_MODEL, api_key: str | None = None, - **config: any + organization: str | None = None, + project: str | None = None, + session_config: dict[str, any] | None = None, + **kwargs ) -> None: """Initialize OpenAI Realtime bidirectional model. Args: model: OpenAI model identifier (default: gpt-realtime). api_key: OpenAI API key for authentication. - **config: Additional configuration (organization, project, session params). + organization: OpenAI organization ID for API requests. + project: OpenAI project ID for API requests. + session_config: Session configuration parameters (e.g., voice, turn_detection, modalities). + **kwargs: Reserved for future parameters. 
""" # Model configuration self.model = model self.api_key = api_key - self.config = config + self.organization = organization + self.project = project + self.session_config = session_config or {} import os if not self.api_key: @@ -133,10 +141,10 @@ async def connect( url = f"{OPENAI_REALTIME_URL}?model={self.model}" headers = [("Authorization", f"Bearer {self.api_key}")] - if "organization" in self.config: - headers.append(("OpenAI-Organization", self.config["organization"])) - if "project" in self.config: - headers.append(("OpenAI-Project", self.config["project"])) + if self.organization: + headers.append(("OpenAI-Organization", self.organization)) + if self.project: + headers.append(("OpenAI-Project", self.project)) self.websocket = await websockets.connect(url, additional_headers=headers) logger.info("WebSocket connected successfully") @@ -181,14 +189,14 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] if tools: config["tools"] = self._convert_tools_to_openai_format(tools) - custom_config = self.config.get("session", {}) + # Apply user-provided session configuration supported_params = { "type", "output_modalities", "instructions", "voice", "audio", "tools", "tool_choice", "input_audio_format", "output_audio_format", "input_audio_transcription", "turn_detection" } - for key, value in custom_config.items(): + for key, value in self.session_config.items(): if key in supported_params: config[key] = value else: diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index de8fcfd56f..5dec7ca2da 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -115,10 +115,10 @@ def test_init_with_custom_config(mock_genai_client, model_id): """Test model initialization with custom configuration.""" _ = 
mock_genai_client - custom_config = {"temperature": 0.7, "top_p": 0.9} - model = GeminiLiveBidirectionalModel(model_id=model_id, **custom_config) + live_config = {"temperature": 0.7, "top_p": 0.9} + model = GeminiLiveBidirectionalModel(model_id=model_id, live_config=live_config) - assert model.config == custom_config + assert model.live_config == live_config # Connection Tests diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 6183765ae1..ad0d3993a2 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -103,10 +103,15 @@ def test_init_with_api_key(api_key, model_name): def test_init_with_custom_config(model_name, api_key): """Test model initialization with custom configuration.""" - custom_config = {"organization": "org-123", "project": "proj-456"} - model = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key, **custom_config) + model = OpenAIRealtimeBidirectionalModel( + model=model_name, + api_key=api_key, + organization="org-123", + project="proj-456" + ) - assert model.config == custom_config + assert model.organization == "org-123" + assert model.project == "proj-456" def test_init_without_api_key_raises(): From 55554aac6338cb89be41690dd01cd8ef660020a3 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 15:16:22 +0100 Subject: [PATCH 06/11] test: Consolidate bidi model tests --- .../models/test_gemini_live.py | 424 +++++-------- .../models/test_novasonic.py | 422 +++++-------- .../models/test_openai_realtime.py | 564 ++++++------------ 3 files changed, 481 insertions(+), 929 deletions(-) diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index 
5dec7ca2da..b894509c91 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -2,14 +2,12 @@ Tests the unified GeminiLiveBidirectionalModel interface including: - Model initialization and configuration -- Connection establishment +- Connection establishment and lifecycle - Unified send() method with different content types - Event receiving and conversion -- Connection lifecycle management """ import unittest.mock -import uuid import pytest from google import genai @@ -84,146 +82,121 @@ def messages(): # Initialization Tests -def test_init_default_config(mock_genai_client): - """Test model initialization with default configuration.""" +def test_model_initialization(mock_genai_client, model_id, api_key): + """Test model initialization with various configurations.""" _ = mock_genai_client - model = GeminiLiveBidirectionalModel() + # Test default config + model_default = GeminiLiveBidirectionalModel() + assert model_default.model_id == "models/gemini-2.0-flash-live-preview-04-09" + assert model_default.api_key is None + assert model_default._active is False + assert model_default.live_session is None - assert model.model_id == "models/gemini-2.0-flash-live-preview-04-09" - assert model.api_key is None - assert model._active is False - assert model.live_session is None - - -def test_init_with_api_key(mock_genai_client, model_id, api_key): - """Test model initialization with API key.""" - mock_client, _, _ = mock_genai_client - - model = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) - - assert model.model_id == model_id - assert model.api_key == api_key - - # Verify client was created with correct parameters - mock_client_cls = unittest.mock.patch("strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client").start() - GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) - 
mock_client_cls.assert_called() - - -def test_init_with_custom_config(mock_genai_client, model_id): - """Test model initialization with custom configuration.""" - _ = mock_genai_client + # Test with API key + model_with_key = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + assert model_with_key.model_id == model_id + assert model_with_key.api_key == api_key + # Test with custom config live_config = {"temperature": 0.7, "top_p": 0.9} - model = GeminiLiveBidirectionalModel(model_id=model_id, live_config=live_config) - - assert model.live_config == live_config + model_custom = GeminiLiveBidirectionalModel(model_id=model_id, live_config=live_config) + assert model_custom.live_config == live_config # Connection Tests @pytest.mark.asyncio -async def test_connect_basic(mock_genai_client, model): - """Test basic connection establishment.""" - mock_client, mock_live_session, _ = mock_genai_client +async def test_connection_lifecycle(mock_genai_client, model, system_prompt, tool_spec, messages): + """Test complete connection lifecycle with various configurations.""" + mock_client, mock_live_session, mock_live_session_cm = mock_genai_client + # Test basic connection await model.connect() - assert model._active is True assert model.session_id is not None assert model.live_session == mock_live_session mock_client.aio.live.connect.assert_called_once() - - -@pytest.mark.asyncio -async def test_connect_with_system_prompt(mock_genai_client, model, system_prompt): - """Test connection with system prompt.""" - mock_client, _, _ = mock_genai_client - await model.connect(system_prompt=system_prompt) + # Test close + await model.close() + assert model._active is False + mock_live_session_cm.__aexit__.assert_called_once() - # Verify system prompt was included in config + # Test connection with system prompt + await model.connect(system_prompt=system_prompt) call_args = mock_client.aio.live.connect.call_args config = call_args.kwargs.get("config", {}) assert 
config.get("system_instruction") == system_prompt - - -@pytest.mark.asyncio -async def test_connect_with_tools(mock_genai_client, model, tool_spec): - """Test connection with tools.""" - mock_client, _, _ = mock_genai_client + await model.close() + # Test connection with tools await model.connect(tools=[tool_spec]) - - # Verify tools were formatted and included call_args = mock_client.aio.live.connect.call_args config = call_args.kwargs.get("config", {}) assert "tools" in config assert len(config["tools"]) > 0 - - -@pytest.mark.asyncio -async def test_connect_with_messages(mock_genai_client, model, messages): - """Test connection with message history.""" - _, mock_live_session, _ = mock_genai_client + await model.close() + # Test connection with messages await model.connect(messages=messages) - - # Verify message history was sent mock_live_session.send_client_content.assert_called() + await model.close() @pytest.mark.asyncio -async def test_connect_error_handling(mock_genai_client, model): - """Test connection error handling.""" - mock_client, _, _ = mock_genai_client - mock_client.aio.live.connect.side_effect = Exception("Connection failed") +async def test_connection_edge_cases(mock_genai_client, api_key, model_id): + """Test connection error handling and edge cases.""" + mock_client, _, mock_live_session_cm = mock_genai_client + # Test connection error + model1 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + mock_client.aio.live.connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): - await model.connect() - - -@pytest.mark.asyncio -async def test_connect_when_already_active(mock_genai_client, model): - """Test that connect() raises exception when already active.""" - mock_client, _, _ = mock_genai_client + await model1.connect() - # First connection - await model.connect() + # Reset mock for next tests + mock_client.aio.live.connect.side_effect = None - # Second connection attempt 
should raise + # Test double connection + model2 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + await model2.connect() with pytest.raises(RuntimeError, match="Connection already active"): - await model.connect() + await model2.connect() + await model2.close() + + # Test close when not connected + model3 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + await model3.close() # Should not raise + + # Test close error handling + model4 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + await model4.connect() + mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") + with pytest.raises(Exception, match="Close failed"): + await model4.close() # Send Method Tests @pytest.mark.asyncio -async def test_send_text_input(mock_genai_client, model): - """Test sending text input through unified send() method.""" +async def test_send_all_content_types(mock_genai_client, model): + """Test sending all content types through unified send() method.""" _, mock_live_session, _ = mock_genai_client await model.connect() + # Test text input text_input: TextInputEvent = {"text": "Hello", "role": "user"} await model.send(text_input) - - # Verify text was sent via send_client_content mock_live_session.send_client_content.assert_called_once() call_args = mock_live_session.send_client_content.call_args content = call_args.kwargs.get("turns") assert content.role == "user" assert content.parts[0].text == "Hello" - - -@pytest.mark.asyncio -async def test_send_audio_input(mock_genai_client, model): - """Test sending audio input through unified send() method.""" - _, mock_live_session, _ = mock_genai_client - await model.connect() + # Test audio input audio_input: AudioInputEvent = { "audioData": b"audio_bytes", "format": "pcm", @@ -231,102 +204,59 @@ async def test_send_audio_input(mock_genai_client, model): "channels": 1, } await model.send(audio_input) - - # Verify audio was sent via send_realtime_input 
mock_live_session.send_realtime_input.assert_called_once() - - -@pytest.mark.asyncio -async def test_send_image_input(mock_genai_client, model): - """Test sending image input through unified send() method.""" - _, mock_live_session, _ = mock_genai_client - await model.connect() + # Test image input image_input: ImageInputEvent = { "imageData": b"image_bytes", "mimeType": "image/jpeg", "encoding": "raw", } await model.send(image_input) - - # Verify image was sent mock_live_session.send.assert_called_once() - - -@pytest.mark.asyncio -async def test_send_tool_result(mock_genai_client, model): - """Test sending tool result through unified send() method.""" - _, mock_live_session, _ = mock_genai_client - await model.connect() + # Test tool result tool_result: ToolResult = { "toolUseId": "tool-123", "status": "success", "content": [{"text": "Result: 42"}], } await model.send(tool_result) - - # Verify tool result was sent mock_live_session.send_tool_response.assert_called_once() + + await model.close() @pytest.mark.asyncio -async def test_send_when_inactive(mock_genai_client, model): - """Test that send() does nothing when connection is inactive.""" +async def test_send_edge_cases(mock_genai_client, model): + """Test send() edge cases and error handling.""" _, mock_live_session, _ = mock_genai_client - # Don't connect, so _active is False + # Test send when inactive text_input: TextInputEvent = {"text": "Hello", "role": "user"} await model.send(text_input) - - # Verify nothing was sent mock_live_session.send_client_content.assert_not_called() - - -@pytest.mark.asyncio -async def test_send_unknown_content_type(mock_genai_client, model): - """Test sending unknown content type logs warning.""" - _, _, _ = mock_genai_client - await model.connect() + # Test unknown content type + await model.connect() unknown_content = {"unknown_field": "value"} + await model.send(unknown_content) # Should not raise, just log warning - # Should not raise, just log warning - await 
model.send(unknown_content) + await model.close() # Receive Method Tests @pytest.mark.asyncio -async def test_receive_connection_start_event(mock_genai_client, model, agenerator): - """Test that receive() emits connection start event.""" +async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): + """Test that receive() emits connection start and end events.""" _, mock_live_session, _ = mock_genai_client mock_live_session.receive.return_value = agenerator([]) await model.connect() - # Get first event - receive_gen = model.receive() - first_event = await anext(receive_gen) - - # First event should be connection start - assert "BidirectionalConnectionStart" in first_event - assert first_event["BidirectionalConnectionStart"]["connectionId"] == model.session_id - - # Close to stop the loop - await model.close() - - -@pytest.mark.asyncio -async def test_receive_connection_end_event(mock_genai_client, model, agenerator): - """Test that receive() emits connection end event.""" - _, mock_live_session, _ = mock_genai_client - mock_live_session.receive.return_value = agenerator([]) - - await model.connect() - - # Collect events until connection ends + # Collect events events = [] async for event in model.receive(): events.append(event) @@ -334,57 +264,44 @@ async def test_receive_connection_end_event(mock_genai_client, model, agenerator if len(events) == 1: await model.close() - # Last event should be connection end + # Verify connection start and end + assert len(events) >= 2 + assert "BidirectionalConnectionStart" in events[0] + assert events[0]["BidirectionalConnectionStart"]["connectionId"] == model.session_id assert "BidirectionalConnectionEnd" in events[-1] @pytest.mark.asyncio -async def test_receive_text_output(mock_genai_client, model): - """Test receiving text output from model.""" - _, mock_live_session, _ = mock_genai_client - - mock_message = unittest.mock.Mock() - mock_message.text = "Hello from Gemini" - mock_message.data = None - 
mock_message.tool_call = None - mock_message.server_content = None - - await model.connect() - - # Test the conversion method directly - converted_event = model._convert_gemini_live_event(mock_message) - - assert "textOutput" in converted_event - assert converted_event["textOutput"]["text"] == "Hello from Gemini" - assert converted_event["textOutput"]["role"] == "assistant" - - -@pytest.mark.asyncio -async def test_receive_audio_output(mock_genai_client, model): - """Test receiving audio output from model.""" - _, mock_live_session, _ = mock_genai_client - - mock_message = unittest.mock.Mock() - mock_message.text = None - mock_message.data = b"audio_data" - mock_message.tool_call = None - mock_message.server_content = None - +async def test_event_conversion(mock_genai_client, model): + """Test conversion of all Gemini Live event types to standard format.""" + _, _, _ = mock_genai_client await model.connect() - # Test the conversion method directly - converted_event = model._convert_gemini_live_event(mock_message) - - assert "audioOutput" in converted_event - assert converted_event["audioOutput"]["audioData"] == b"audio_data" - assert converted_event["audioOutput"]["format"] == "pcm" - - -@pytest.mark.asyncio -async def test_receive_tool_call(mock_genai_client, model): - """Test receiving tool call from model.""" - _, mock_live_session, _ = mock_genai_client - + # Test text output + mock_text = unittest.mock.Mock() + mock_text.text = "Hello from Gemini" + mock_text.data = None + mock_text.tool_call = None + mock_text.server_content = None + + text_event = model._convert_gemini_live_event(mock_text) + assert "textOutput" in text_event + assert text_event["textOutput"]["text"] == "Hello from Gemini" + assert text_event["textOutput"]["role"] == "assistant" + + # Test audio output + mock_audio = unittest.mock.Mock() + mock_audio.text = None + mock_audio.data = b"audio_data" + mock_audio.tool_call = None + mock_audio.server_content = None + + audio_event = 
model._convert_gemini_live_event(mock_audio) + assert "audioOutput" in audio_event + assert audio_event["audioOutput"]["audioData"] == b"audio_data" + assert audio_event["audioOutput"]["format"] == "pcm" + + # Test tool call mock_func_call = unittest.mock.Mock() mock_func_call.id = "tool-123" mock_func_call.name = "calculator" @@ -393,121 +310,62 @@ async def test_receive_tool_call(mock_genai_client, model): mock_tool_call = unittest.mock.Mock() mock_tool_call.function_calls = [mock_func_call] - mock_message = unittest.mock.Mock() - mock_message.text = None - mock_message.data = None - mock_message.tool_call = mock_tool_call - mock_message.server_content = None - - await model.connect() + mock_tool = unittest.mock.Mock() + mock_tool.text = None + mock_tool.data = None + mock_tool.tool_call = mock_tool_call + mock_tool.server_content = None - # Test the conversion method directly - converted_event = model._convert_gemini_live_event(mock_message) - - assert "toolUse" in converted_event - assert converted_event["toolUse"]["toolUseId"] == "tool-123" - assert converted_event["toolUse"]["name"] == "calculator" - - -@pytest.mark.asyncio -async def test_receive_interruption(mock_genai_client, model): - """Test receiving interruption event.""" - _, mock_live_session, _ = mock_genai_client + tool_event = model._convert_gemini_live_event(mock_tool) + assert "toolUse" in tool_event + assert tool_event["toolUse"]["toolUseId"] == "tool-123" + assert tool_event["toolUse"]["name"] == "calculator" + # Test interruption mock_server_content = unittest.mock.Mock() mock_server_content.interrupted = True mock_server_content.input_transcription = None mock_server_content.output_transcription = None - mock_message = unittest.mock.Mock() - mock_message.text = None - mock_message.data = None - mock_message.tool_call = None - mock_message.server_content = mock_server_content + mock_interrupt = unittest.mock.Mock() + mock_interrupt.text = None + mock_interrupt.data = None + 
mock_interrupt.tool_call = None + mock_interrupt.server_content = mock_server_content - await model.connect() + interrupt_event = model._convert_gemini_live_event(mock_interrupt) + assert "interruptionDetected" in interrupt_event + assert interrupt_event["interruptionDetected"]["reason"] == "user_input" - # Test the conversion method directly - converted_event = model._convert_gemini_live_event(mock_message) - - assert "interruptionDetected" in converted_event - assert converted_event["interruptionDetected"]["reason"] == "user_input" - - -# Close Method Tests - - -@pytest.mark.asyncio -async def test_close_connection(mock_genai_client, model): - """Test closing connection.""" - _, _, mock_live_session_cm = mock_genai_client - - await model.connect() - await model.close() - - assert model._active is False - mock_live_session_cm.__aexit__.assert_called_once() - - -@pytest.mark.asyncio -async def test_close_when_not_connected(mock_genai_client, model): - """Test closing when not connected does nothing.""" - _, _, mock_live_session_cm = mock_genai_client - - # Don't connect await model.close() - - # Should not raise, and __aexit__ should not be called - mock_live_session_cm.__aexit__.assert_not_called() - - -@pytest.mark.asyncio -async def test_close_error_handling(mock_genai_client, model): - """Test close error handling.""" - _, _, mock_live_session_cm = mock_genai_client - mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - - await model.connect() - - with pytest.raises(Exception, match="Close failed"): - await model.close() # Helper Method Tests -def test_build_live_config_basic(model): - """Test building basic live config.""" - config = model._build_live_config() +def test_config_building(model, system_prompt, tool_spec): + """Test building live config with various options.""" + # Test basic config + config_basic = model._build_live_config() + assert isinstance(config_basic, dict) - assert isinstance(config, dict) - - -def 
test_build_live_config_with_system_prompt(model, system_prompt): - """Test building config with system prompt.""" - config = model._build_live_config(system_prompt=system_prompt) + # Test with system prompt + config_prompt = model._build_live_config(system_prompt=system_prompt) + assert config_prompt["system_instruction"] == system_prompt - assert config["system_instruction"] == system_prompt + # Test with tools + config_tools = model._build_live_config(tools=[tool_spec]) + assert "tools" in config_tools + assert len(config_tools["tools"]) > 0 -def test_build_live_config_with_tools(model, tool_spec): - """Test building config with tools.""" - config = model._build_live_config(tools=[tool_spec]) - - assert "tools" in config - assert len(config["tools"]) > 0 - - -def test_format_tools_for_live_api(model, tool_spec): +def test_tool_formatting(model, tool_spec): """Test tool formatting for Gemini Live API.""" + # Test with tools formatted_tools = model._format_tools_for_live_api([tool_spec]) - assert len(formatted_tools) == 1 assert isinstance(formatted_tools[0], genai_types.Tool) - - -def test_format_tools_empty_list(model): - """Test formatting empty tool list.""" - formatted_tools = model._format_tools_for_live_api([]) - assert formatted_tools == [] + # Test empty list + formatted_empty = model._format_tools_for_live_api([]) + assert formatted_empty == [] diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 59c762b3e2..10066a6938 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -7,9 +7,7 @@ import asyncio import base64 import json -import uuid -from typing import Any, Dict -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import AsyncMock, patch import pytest import pytest_asyncio @@ -17,7 +15,7 
@@ from strands.experimental.bidirectional_streaming.models.novasonic import ( NovaSonicBidirectionalModel, ) -from strands.types.tools import ToolResult, ToolSpec +from strands.types.tools import ToolResult # Test fixtures @@ -62,12 +60,14 @@ async def nova_model(model_id, region): await model.close() -# Connection lifecycle tests +# Initialization and Connection Tests + + @pytest.mark.asyncio async def test_model_initialization(model_id, region): """Test model initialization with configuration.""" model = NovaSonicBidirectionalModel(model_id=model_id, region=region) - + assert model.model_id == model_id assert model.region == region assert model.stream is None @@ -76,26 +76,24 @@ async def test_model_initialization(model_id, region): @pytest.mark.asyncio -async def test_connect_establishes_connection(nova_model, mock_client, mock_stream): - """Test that connect() establishes bidirectional connection.""" +async def test_connection_lifecycle(nova_model, mock_client, mock_stream): + """Test complete connection lifecycle with various configurations.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - + + # Test basic connection await nova_model.connect(system_prompt="Test system prompt") - assert nova_model._active assert nova_model.stream == mock_stream assert nova_model.prompt_name is not None assert mock_client.invoke_model_with_bidirectional_stream.called + # Test close + await nova_model.close() + assert not nova_model._active + assert mock_stream.input_stream.close.called -@pytest.mark.asyncio -async def test_connect_sends_initialization_events(nova_model, mock_client, mock_stream): - """Test that connect() sends proper initialization sequence.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - system_prompt = "You are a helpful assistant" + # Test connection with tools tools = [ { "name": "get_weather", @@ -103,108 +101,147 @@ 
async def test_connect_sends_initialization_events(nova_model, mock_client, mock "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})} } ] - - await nova_model.connect(system_prompt=system_prompt, tools=tools) - - # Verify initialization events were sent - assert mock_stream.input_stream.send.call_count >= 3 # connectionStart, promptStart, system prompt + await nova_model.connect(system_prompt="You are helpful", tools=tools) + # Verify initialization events were sent (connectionStart, promptStart, system prompt) + assert mock_stream.input_stream.send.call_count >= 3 + await nova_model.close() @pytest.mark.asyncio -async def test_connect_when_already_active(nova_model, mock_client, mock_stream): - """Test that connect() raises exception when already active.""" +async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model_id, region): + """Test connection error handling and edge cases.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - - # First connection + + # Test double connection await nova_model.connect() - - # Second connection attempt should raise with pytest.raises(RuntimeError, match="Connection already active"): await nova_model.connect() + await nova_model.close() + + # Test close when already closed + model2 = NovaSonicBidirectionalModel(model_id=model_id, region=region) + await model2.close() # Should not raise + await model2.close() # Second call should also be safe + + +# Send Method Tests @pytest.mark.asyncio -async def test_close_cleanup(nova_model, mock_client, mock_stream): - """Test that close() properly cleans up resources.""" +async def test_send_all_content_types(nova_model, mock_client, mock_stream): + """Test sending all content types through unified send() method.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - + await nova_model.connect() + + # Test text content + 
text_event = {"text": "Hello, Nova!", "role": "user"} + await nova_model.send(text_event) + # Should send contentStart, textInput, and contentEnd + assert mock_stream.input_stream.send.call_count >= 3 + + # Test audio content + audio_event = { + "audioData": b"audio data", + "format": "pcm", + "sampleRate": 16000, + "channels": 1 + } + await nova_model.send(audio_event) + # Should start audio connection and send audio + assert nova_model.audio_connection_active + assert mock_stream.input_stream.send.called + + # Test tool result + tool_result = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Weather is sunny"}] + } + await nova_model.send(tool_result) + # Should send contentStart, toolResult, and contentEnd + assert mock_stream.input_stream.send.called + await nova_model.close() - - assert not nova_model._active - assert mock_stream.input_stream.close.called -# Event conversion tests @pytest.mark.asyncio -async def test_receive_emits_connection_start(nova_model, mock_client, mock_stream): - """Test that receive() emits connection start event.""" +async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): + """Test send() edge cases and error handling.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - + + # Test send when inactive + text_event = {"text": "Hello", "role": "user"} + await nova_model.send(text_event) # Should not raise + + # Test image content (not supported) + await nova_model.connect() + image_event = { + "imageData": b"image data", + "mimeType": "image/jpeg" + } + await nova_model.send(image_event) + # Should log warning about unsupported image input + assert any("not supported" in record.message.lower() for record in caplog.records) + + await nova_model.close() + + +# Receive and Event Conversion Tests + + +@pytest.mark.asyncio +async def test_receive_lifecycle_events(nova_model, mock_client, mock_stream): + """Test that receive() emits 
connection start and end events.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model._client = mock_client + # Setup mock to return no events and then stop async def mock_wait_for(*args, **kwargs): await asyncio.sleep(0.1) nova_model._active = False raise asyncio.TimeoutError() - + with patch("asyncio.wait_for", side_effect=mock_wait_for): await nova_model.connect() - + events = [] async for event in nova_model.receive(): events.append(event) - + # Should have connection start and end assert len(events) >= 2 assert "BidirectionalConnectionStart" in events[0] assert events[0]["BidirectionalConnectionStart"]["connectionId"] == nova_model.prompt_name + assert "BidirectionalConnectionEnd" in events[-1] @pytest.mark.asyncio -async def test_convert_audio_output_event(nova_model): - """Test conversion of Nova Sonic audio output to standard format.""" +async def test_event_conversion(nova_model): + """Test conversion of all Nova Sonic event types to standard format.""" + # Test audio output audio_bytes = b"test audio data" audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") - - nova_event = { - "audioOutput": { - "content": audio_base64 - } - } - + nova_event = {"audioOutput": {"content": audio_base64}} result = nova_model._convert_nova_event(nova_event) - assert result is not None assert "audioOutput" in result assert result["audioOutput"]["audioData"] == audio_bytes assert result["audioOutput"]["format"] == "pcm" assert result["audioOutput"]["sampleRate"] == 24000 - -@pytest.mark.asyncio -async def test_convert_text_output_event(nova_model): - """Test conversion of Nova Sonic text output to standard format.""" - nova_event = { - "textOutput": { - "content": "Hello, world!", - "role": "ASSISTANT" - } - } - + # Test text output + nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}} result = nova_model._convert_nova_event(nova_event) - assert result is not None assert "textOutput" in result assert 
result["textOutput"]["text"] == "Hello, world!" assert result["textOutput"]["role"] == "assistant" - -@pytest.mark.asyncio -async def test_convert_tool_use_event(nova_model): - """Test conversion of Nova Sonic tool use to standard format.""" + # Test tool use tool_input = {"location": "Seattle"} nova_event = { "toolUse": { @@ -213,33 +250,21 @@ async def test_convert_tool_use_event(nova_model): "content": json.dumps(tool_input) } } - result = nova_model._convert_nova_event(nova_event) - assert result is not None assert "toolUse" in result assert result["toolUse"]["toolUseId"] == "tool-123" assert result["toolUse"]["name"] == "get_weather" assert result["toolUse"]["input"] == tool_input - -@pytest.mark.asyncio -async def test_convert_interruption_event(nova_model): - """Test conversion of Nova Sonic interruption to standard format.""" - nova_event = { - "stopReason": "INTERRUPTED" - } - + # Test interruption + nova_event = {"stopReason": "INTERRUPTED"} result = nova_model._convert_nova_event(nova_event) - assert result is not None assert "interruptionDetected" in result assert result["interruptionDetected"]["reason"] == "user_input" - -@pytest.mark.asyncio -async def test_convert_usage_metrics_event(nova_model): - """Test conversion of Nova Sonic usage event to standard format.""" + # Test usage metrics nova_event = { "usageEvent": { "totalTokens": 100, @@ -254,9 +279,7 @@ async def test_convert_usage_metrics_event(nova_model): } } } - result = nova_model._convert_nova_event(nova_event) - assert result is not None assert "usageMetrics" in result assert result["usageMetrics"]["totalTokens"] == 100 @@ -264,131 +287,44 @@ async def test_convert_usage_metrics_event(nova_model): assert result["usageMetrics"]["outputTokens"] == 60 assert result["usageMetrics"]["audioTokens"] == 30 - -@pytest.mark.asyncio -async def test_convert_content_start_tracks_role(nova_model): - """Test that contentStart events track role for subsequent text output.""" - nova_event = { - 
"contentStart": { - "role": "USER" - } - } - + # Test content start tracks role + nova_event = {"contentStart": {"role": "USER"}} result = nova_model._convert_nova_event(nova_event) - - # contentStart doesn't emit an event but stores role - assert result is None + assert result is None # contentStart doesn't emit an event assert nova_model._current_role == "USER" -# Send method tests -@pytest.mark.asyncio -async def test_send_text_content(nova_model, mock_client, mock_stream): - """Test sending text content through unified send() method.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - await nova_model.connect() - - text_event = { - "text": "Hello, Nova!", - "role": "user" - } - - await nova_model.send(text_event) - - # Should send contentStart, textInput, and contentEnd - assert mock_stream.input_stream.send.call_count >= 3 - - -@pytest.mark.asyncio -async def test_send_audio_content(nova_model, mock_client, mock_stream): - """Test sending audio content through unified send() method.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - await nova_model.connect() - - audio_event = { - "audioData": b"audio data", - "format": "pcm", - "sampleRate": 16000, - "channels": 1 - } - - await nova_model.send(audio_event) - - # Should start audio connection and send audio - assert nova_model.audio_connection_active - assert mock_stream.input_stream.send.called - - -@pytest.mark.asyncio -async def test_send_tool_result(nova_model, mock_client, mock_stream): - """Test sending tool result through unified send() method.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - await nova_model.connect() - - tool_result = { - "toolUseId": "tool-123", - "status": "success", - "content": [{"text": "Weather is sunny"}] - } - - await nova_model.send(tool_result) - - # Should send 
contentStart, toolResult, and contentEnd - assert mock_stream.input_stream.send.call_count >= 3 - - -@pytest.mark.asyncio -async def test_send_image_content_not_supported(nova_model, mock_client, mock_stream, caplog): - """Test that image content logs warning (not supported by Nova Sonic).""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - await nova_model.connect() - - image_event = { - "imageData": b"image data", - "mimeType": "image/jpeg" - } - - await nova_model.send(image_event) - - # Should log warning about unsupported image input - assert any("not supported" in record.message.lower() for record in caplog.records) +# Audio Streaming Tests -# Audio streaming tests @pytest.mark.asyncio async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): """Test audio connection start and end lifecycle.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - + await nova_model.connect() - + # Start audio connection await nova_model._start_audio_connection() assert nova_model.audio_connection_active - + # End audio connection await nova_model._end_audio_input() assert not nova_model.audio_connection_active + await nova_model.close() + @pytest.mark.asyncio -async def test_silence_detection_ends_audio(nova_model, mock_client, mock_stream): +async def test_silence_detection(nova_model, mock_client, mock_stream): """Test that silence detection automatically ends audio input.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client nova_model.silence_threshold = 0.1 # Short threshold for testing - + await nova_model.connect() - + # Send audio to start connection audio_event = { "audioData": b"audio data", @@ -396,20 +332,24 @@ async def test_silence_detection_ends_audio(nova_model, mock_client, mock_stream "sampleRate": 16000, "channels": 1 } - + await 
nova_model.send(audio_event) assert nova_model.audio_connection_active - + # Wait for silence detection await asyncio.sleep(0.2) - + # Audio connection should be ended assert not nova_model.audio_connection_active + await nova_model.close() + + +# Helper Method Tests + -# Tool configuration tests @pytest.mark.asyncio -async def test_build_tool_configuration(nova_model): +async def test_tool_configuration(nova_model): """Test building tool configuration from tool specs.""" tools = [ { @@ -425,141 +365,69 @@ async def test_build_tool_configuration(nova_model): } } ] - + tool_config = nova_model._build_tool_configuration(tools) - + assert len(tool_config) == 1 assert tool_config[0]["toolSpec"]["name"] == "get_weather" assert tool_config[0]["toolSpec"]["description"] == "Get weather information" assert "inputSchema" in tool_config[0]["toolSpec"] -# Event template tests @pytest.mark.asyncio -async def test_get_connection_start_event(nova_model): - """Test connection start event generation.""" +async def test_event_templates(nova_model): + """Test event template generation.""" + # Test connection start event event_json = nova_model._get_connection_start_event() event = json.loads(event_json) - assert "event" in event assert "sessionStart" in event["event"] assert "inferenceConfiguration" in event["event"]["sessionStart"] - -@pytest.mark.asyncio -async def test_get_prompt_start_event(nova_model): - """Test prompt start event generation.""" + # Test prompt start event nova_model.prompt_name = "test-prompt" - event_json = nova_model._get_prompt_start_event([]) event = json.loads(event_json) - assert "event" in event assert "promptStart" in event["event"] assert event["event"]["promptStart"]["promptName"] == "test-prompt" - -@pytest.mark.asyncio -async def test_get_text_input_event(nova_model): - """Test text input event generation.""" - nova_model.prompt_name = "test-prompt" + # Test text input event content_name = "test-content" - event_json = 
nova_model._get_text_input_event(content_name, "Hello") event = json.loads(event_json) - assert "event" in event assert "textInput" in event["event"] assert event["event"]["textInput"]["content"] == "Hello" - -@pytest.mark.asyncio -async def test_get_tool_result_event(nova_model): - """Test tool result event generation.""" - nova_model.prompt_name = "test-prompt" - content_name = "test-content" + # Test tool result event result = {"result": "Success"} - event_json = nova_model._get_tool_result_event(content_name, result) event = json.loads(event_json) - assert "event" in event assert "toolResult" in event["event"] assert json.loads(event["event"]["toolResult"]["content"]) == result -# Error handling tests -@pytest.mark.asyncio -async def test_send_when_inactive(nova_model): - """Test that send() handles inactive connection gracefully.""" - text_event = { - "text": "Hello", - "role": "user" - } - - # Should not raise error when inactive - await nova_model.send(text_event) - - -@pytest.mark.asyncio -async def test_close_when_already_closed(nova_model): - """Test that close() handles already closed connection.""" - # Should not raise error when already inactive - await nova_model.close() - await nova_model.close() # Second call should be safe +# Error Handling Tests @pytest.mark.asyncio -async def test_response_processor_handles_errors(nova_model, mock_client, mock_stream): - """Test that response processor handles errors gracefully.""" +async def test_error_handling(nova_model, mock_client, mock_stream): + """Test error handling in various scenarios.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): nova_model._client = mock_client - - # Setup mock to raise error + + # Test response processor handles errors gracefully async def mock_error(*args, **kwargs): raise Exception("Test error") - + mock_stream.await_output.side_effect = mock_error - + await nova_model.connect() - + # Wait a bit for response processor to handle error await 
asyncio.sleep(0.1) - - # Should still be able to close cleanly - await nova_model.close() - -# Integration-style tests -@pytest.mark.asyncio -async def test_full_conversation_flow(nova_model, mock_client, mock_stream): - """Test a complete conversation flow with text and audio.""" - with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client - - # Connect - await nova_model.connect(system_prompt="You are helpful") - - # Send text - await nova_model.send({"text": "Hello", "role": "user"}) - - # Send audio - await nova_model.send({ - "audioData": b"audio", - "format": "pcm", - "sampleRate": 16000, - "channels": 1 - }) - - # Send tool result - await nova_model.send({ - "toolUseId": "tool-1", - "status": "success", - "content": [{"text": "Result"}] - }) - - # Close + # Should still be able to close cleanly await nova_model.close() - - # Verify all operations completed - assert not nova_model._active diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index ad0d3993a2..1209150ba9 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -6,7 +6,6 @@ - Unified send() method with different content types - Event receiving and conversion - Connection lifecycle management -- Background task management """ import asyncio @@ -39,7 +38,7 @@ def mock_websockets_connect(mock_websocket): """Mock websockets.connect function.""" async def async_connect(*args, **kwargs): return mock_websocket - + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.websockets.connect") as mock_connect: mock_connect.side_effect = async_connect yield mock_connect, mock_websocket @@ -83,35 +82,34 @@ def messages(): # Initialization Tests -def test_init_default_config(): 
- """Test model initialization with default configuration.""" - model = OpenAIRealtimeBidirectionalModel(api_key="test-key") - - assert model.model == "gpt-realtime" - assert model.api_key == "test-key" - assert model._active is False - assert model.websocket is None - +def test_model_initialization(api_key, model_name): + """Test model initialization with various configurations.""" + # Test default config + model_default = OpenAIRealtimeBidirectionalModel(api_key="test-key") + assert model_default.model == "gpt-realtime" + assert model_default.api_key == "test-key" + assert model_default._active is False + assert model_default.websocket is None -def test_init_with_api_key(api_key, model_name): - """Test model initialization with API key.""" - model = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) - - assert model.model == model_name - assert model.api_key == api_key + # Test with custom model + model_custom = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + assert model_custom.model == model_name + assert model_custom.api_key == api_key - -def test_init_with_custom_config(model_name, api_key): - """Test model initialization with custom configuration.""" - model = OpenAIRealtimeBidirectionalModel( + # Test with organization and project + model_org = OpenAIRealtimeBidirectionalModel( model=model_name, api_key=api_key, organization="org-123", project="proj-456" ) - - assert model.organization == "org-123" - assert model.project == "proj-456" + assert model_org.organization == "org-123" + assert model_org.project == "proj-456" + + # Test with env API key + with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): + model_env = OpenAIRealtimeBidirectionalModel() + assert model_env.api_key == "env-key" def test_init_without_api_key_raises(): @@ -121,158 +119,123 @@ def test_init_without_api_key_raises(): OpenAIRealtimeBidirectionalModel() -def test_init_with_env_api_key(): - """Test initialization with API key 
from environment.""" - with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): - model = OpenAIRealtimeBidirectionalModel() - assert model.api_key == "env-key" - - # Connection Tests @pytest.mark.asyncio -async def test_connect_basic(mock_websockets_connect, model): - """Test basic connection establishment.""" +async def test_connection_lifecycle(mock_websockets_connect, model, system_prompt, tool_spec, messages): + """Test complete connection lifecycle with various configurations.""" mock_connect, mock_ws = mock_websockets_connect - + + # Test basic connection await model.connect() - assert model._active is True assert model.session_id is not None assert model.websocket == mock_ws assert model._event_queue is not None + assert model._response_task is not None mock_connect.assert_called_once() + # Test close + await model.close() + assert model._active is False + mock_ws.close.assert_called_once() -@pytest.mark.asyncio -async def test_connect_with_system_prompt(mock_websockets_connect, model, system_prompt): - """Test connection with system prompt.""" - _, mock_ws = mock_websockets_connect - + # Test connection with system prompt await model.connect(system_prompt=system_prompt) - - # Verify session.update was sent with system prompt calls = mock_ws.send.call_args_list - session_update_call = None - for call in calls: - message = json.loads(call[0][0]) - if message.get("type") == "session.update": - session_update_call = message - break - - assert session_update_call is not None - assert session_update_call["session"]["instructions"] == system_prompt - + session_update = next( + (json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"), + None + ) + assert session_update is not None + assert system_prompt in session_update["session"]["instructions"] + await model.close() -@pytest.mark.asyncio -async def test_connect_with_tools(mock_websockets_connect, model, tool_spec): - """Test connection with 
tools.""" - _, mock_ws = mock_websockets_connect - + # Test connection with tools await model.connect(tools=[tool_spec]) - - # Verify tools were included in session config calls = mock_ws.send.call_args_list - session_update_call = None - for call in calls: - message = json.loads(call[0][0]) - if message.get("type") == "session.update": - session_update_call = message - break - - assert session_update_call is not None - assert "tools" in session_update_call["session"] - + # Tools are sent in a separate session.update after initial connection + session_updates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"] + assert len(session_updates) > 0 + # Check if any session update has tools + has_tools = any("tools" in update.get("session", {}) for update in session_updates) + assert has_tools + await model.close() -@pytest.mark.asyncio -async def test_connect_with_messages(mock_websockets_connect, model, messages): - """Test connection with message history.""" - _, mock_ws = mock_websockets_connect - + # Test connection with messages await model.connect(messages=messages) - - # Verify conversation items were created calls = mock_ws.send.call_args_list - item_create_calls = [ - json.loads(call[0][0]) for call in calls - if json.loads(call[0][0]).get("type") == "conversation.item.create" - ] - - assert len(item_create_calls) > 0 + item_creates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "conversation.item.create"] + assert len(item_creates) > 0 + await model.close() + + # Test connection with organization header + model_org = OpenAIRealtimeBidirectionalModel(api_key="test-key", organization="org-123") + await model_org.connect() + call_kwargs = mock_connect.call_args.kwargs + headers = call_kwargs.get("additional_headers", []) + org_header = [h for h in headers if h[0] == "OpenAI-Organization"] + assert len(org_header) == 1 + assert org_header[0][1] == "org-123" + await 
model_org.close() @pytest.mark.asyncio -async def test_connect_error_handling(mock_websockets_connect, model): - """Test connection error handling.""" - mock_connect, _ = mock_websockets_connect +async def test_connection_edge_cases(mock_websockets_connect, api_key, model_name): + """Test connection error handling and edge cases.""" + mock_connect, mock_ws = mock_websockets_connect + + # Test connection error + model1 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) mock_connect.side_effect = Exception("Connection failed") - with pytest.raises(Exception, match="Connection failed"): - await model.connect() + await model1.connect() + # Reset mock + async def async_connect(*args, **kwargs): + return mock_ws + mock_connect.side_effect = async_connect -@pytest.mark.asyncio -async def test_connect_when_already_active(mock_websockets_connect, model): - """Test that connect() raises exception when already active.""" - mock_connect, _ = mock_websockets_connect - - # First connection - await model.connect() - - # Second connection attempt should raise + # Test double connection + model2 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + await model2.connect() with pytest.raises(RuntimeError, match="Connection already active"): - await model.connect() + await model2.connect() + await model2.close() + # Test close when not connected + model3 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + await model3.close() # Should not raise -@pytest.mark.asyncio -async def test_connect_with_organization_header(mock_websockets_connect, api_key): - """Test connection includes organization header.""" - mock_connect, _ = mock_websockets_connect - - model = OpenAIRealtimeBidirectionalModel( - api_key=api_key, - organization="org-123" - ) - await model.connect() - - # Verify headers were passed - call_kwargs = mock_connect.call_args.kwargs - headers = call_kwargs.get("additional_headers", []) - org_header = [h for h in headers if 
h[0] == "OpenAI-Organization"] - assert len(org_header) == 1 - assert org_header[0][1] == "org-123" + # Test close error handling (should not raise, just log) + model4 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + await model4.connect() + mock_ws.close.side_effect = Exception("Close failed") + await model4.close() # Should not raise + assert model4._active is False # Send Method Tests @pytest.mark.asyncio -async def test_send_text_input(mock_websockets_connect, model): - """Test sending text input through unified send() method.""" +async def test_send_all_content_types(mock_websockets_connect, model): + """Test sending all content types through unified send() method.""" _, mock_ws = mock_websockets_connect await model.connect() - + + # Test text input text_input: TextInputEvent = {"text": "Hello", "role": "user"} await model.send(text_input) - - # Verify conversation.item.create and response.create were sent calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] - item_create = [m for m in messages if m.get("type") == "conversation.item.create"] response_create = [m for m in messages if m.get("type") == "response.create"] - assert len(item_create) > 0 assert len(response_create) > 0 - -@pytest.mark.asyncio -async def test_send_audio_input(mock_websockets_connect, model): - """Test sending audio input through unified send() method.""" - _, mock_ws = mock_websockets_connect - await model.connect() - + # Test audio input audio_input: AudioInputEvent = { "audioData": b"audio_bytes", "format": "pcm", @@ -280,179 +243,122 @@ async def test_send_audio_input(mock_websockets_connect, model): "channels": 1, } await model.send(audio_input) - - # Verify input_audio_buffer.append was sent calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] - audio_append = [m for m in messages if m.get("type") == "input_audio_buffer.append"] assert len(audio_append) > 0 - - # Verify audio was 
base64 encoded assert "audio" in audio_append[0] decoded = base64.b64decode(audio_append[0]["audio"]) assert decoded == b"audio_bytes" - -@pytest.mark.asyncio -async def test_send_image_input(mock_websockets_connect, model): - """Test sending image input logs warning (not supported).""" - _, mock_ws = mock_websockets_connect - await model.connect() - - image_input: ImageInputEvent = { - "imageData": b"image_bytes", - "mimeType": "image/jpeg", - "encoding": "raw", - } - - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: - await model.send(image_input) - mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") - - -@pytest.mark.asyncio -async def test_send_tool_result(mock_websockets_connect, model): - """Test sending tool result through unified send() method.""" - _, mock_ws = mock_websockets_connect - await model.connect() - + # Test tool result tool_result: ToolResult = { "toolUseId": "tool-123", "status": "success", "content": [{"text": "Result: 42"}], } await model.send(tool_result) - - # Verify function_call_output was created calls = mock_ws.send.call_args_list messages = [json.loads(call[0][0]) for call in calls] - item_create = [m for m in messages if m.get("type") == "conversation.item.create"] assert len(item_create) > 0 - - # Verify it's a function_call_output item = item_create[-1].get("item", {}) assert item.get("type") == "function_call_output" assert item.get("call_id") == "tool-123" + await model.close() + @pytest.mark.asyncio -async def test_send_when_inactive(mock_websockets_connect, model): - """Test that send() does nothing when connection is inactive.""" +async def test_send_edge_cases(mock_websockets_connect, model): + """Test send() edge cases and error handling.""" _, mock_ws = mock_websockets_connect - - # Don't connect, so _active is False + + # Test send when inactive text_input: TextInputEvent = {"text": "Hello", "role": "user"} await 
model.send(text_input) - - # Verify nothing was sent mock_ws.send.assert_not_called() - -@pytest.mark.asyncio -async def test_send_unknown_content_type(mock_websockets_connect, model): - """Test sending unknown content type logs warning.""" - _, _ = mock_websockets_connect + # Test image input (not supported) await model.connect() - + image_input: ImageInputEvent = { + "imageData": b"image_bytes", + "mimeType": "image/jpeg", + "encoding": "raw", + } + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + await model.send(image_input) + mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") + + # Test unknown content type unknown_content = {"unknown_field": "value"} - with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: await model.send(unknown_content) - # Should log warning about unknown content assert mock_logger.warning.called + await model.close() + # Receive Method Tests @pytest.mark.asyncio -async def test_receive_connection_start_event(mock_websockets_connect, model): - """Test that receive() emits connection start event.""" +async def test_receive_lifecycle_events(mock_websockets_connect, model): + """Test that receive() emits connection start and end events.""" _, _ = mock_websockets_connect - + await model.connect() - + # Get first event receive_gen = model.receive() first_event = await anext(receive_gen) - + # First event should be connection start assert "BidirectionalConnectionStart" in first_event assert first_event["BidirectionalConnectionStart"]["connectionId"] == model.session_id - - # Close to stop the loop + + # Close to trigger connection end await model.close() + # Collect remaining events + events = [first_event] + try: + async for event in receive_gen: + events.append(event) + except StopAsyncIteration: + pass -@pytest.mark.asyncio -async def 
test_receive_connection_end_event(mock_websockets_connect, model): - """Test that receive() emits connection end event.""" - _, _ = mock_websockets_connect - - await model.connect() - - # Collect events until connection ends - events = [] - async for event in model.receive(): - events.append(event) - # Close after first event to trigger connection end - if len(events) == 1: - await model.close() - # Last event should be connection end assert "BidirectionalConnectionEnd" in events[-1] @pytest.mark.asyncio -async def test_receive_audio_output(mock_websockets_connect, model): - """Test receiving audio output from model.""" +async def test_event_conversion(mock_websockets_connect, model): + """Test conversion of all OpenAI event types to standard format.""" _, _ = mock_websockets_connect await model.connect() - - # Create mock OpenAI event - openai_event = { + + # Test audio output + audio_event = { "type": "response.output_audio.delta", "delta": base64.b64encode(b"audio_data").decode() } - - # Test conversion directly - converted_event = model._convert_openai_event(openai_event) - - assert "audioOutput" in converted_event - assert converted_event["audioOutput"]["audioData"] == b"audio_data" - assert converted_event["audioOutput"]["format"] == "pcm" - + converted = model._convert_openai_event(audio_event) + assert "audioOutput" in converted + assert converted["audioOutput"]["audioData"] == b"audio_data" + assert converted["audioOutput"]["format"] == "pcm" -@pytest.mark.asyncio -async def test_receive_text_output(mock_websockets_connect, model): - """Test receiving text output from model.""" - _, _ = mock_websockets_connect - await model.connect() - - # Create mock OpenAI event - openai_event = { + # Test text output + text_event = { "type": "response.output_text.delta", "delta": "Hello from OpenAI" } - - # Test conversion directly - converted_event = model._convert_openai_event(openai_event) - - assert "textOutput" in converted_event - assert 
converted_event["textOutput"]["text"] == "Hello from OpenAI" - assert converted_event["textOutput"]["role"] == "assistant" + converted = model._convert_openai_event(text_event) + assert "textOutput" in converted + assert converted["textOutput"]["text"] == "Hello from OpenAI" + assert converted["textOutput"]["role"] == "assistant" - -@pytest.mark.asyncio -async def test_receive_function_call(mock_websockets_connect, model): - """Test receiving function call from model.""" - _, _ = mock_websockets_connect - await model.connect() - - # Simulate function call sequence - # First: output_item.added with function name + # Test function call sequence item_added = { "type": "response.output_item.added", "item": { @@ -462,182 +368,102 @@ async def test_receive_function_call(mock_websockets_connect, model): } } model._convert_openai_event(item_added) - - # Second: function_call_arguments.delta + args_delta = { "type": "response.function_call_arguments.delta", "call_id": "call-123", "delta": '{"expression": "2+2"}' } model._convert_openai_event(args_delta) - - # Third: function_call_arguments.done + args_done = { "type": "response.function_call_arguments.done", "call_id": "call-123" } - converted_event = model._convert_openai_event(args_done) - - assert "toolUse" in converted_event - assert converted_event["toolUse"]["toolUseId"] == "call-123" - assert converted_event["toolUse"]["name"] == "calculator" - assert converted_event["toolUse"]["input"]["expression"] == "2+2" - + converted = model._convert_openai_event(args_done) + assert "toolUse" in converted + assert converted["toolUse"]["toolUseId"] == "call-123" + assert converted["toolUse"]["name"] == "calculator" + assert converted["toolUse"]["input"]["expression"] == "2+2" -@pytest.mark.asyncio -async def test_receive_voice_activity(mock_websockets_connect, model): - """Test receiving voice activity events.""" - _, _ = mock_websockets_connect - await model.connect() - - # Test speech started + # Test voice activity 
speech_started = { "type": "input_audio_buffer.speech_started" } - converted_event = model._convert_openai_event(speech_started) - - assert "voiceActivity" in converted_event - assert converted_event["voiceActivity"]["activityType"] == "speech_started" + converted = model._convert_openai_event(speech_started) + assert "voiceActivity" in converted + assert converted["voiceActivity"]["activityType"] == "speech_started" - -# Close Method Tests - - -@pytest.mark.asyncio -async def test_close_connection(mock_websockets_connect, model): - """Test closing connection.""" - _, mock_ws = mock_websockets_connect - - await model.connect() await model.close() - - assert model._active is False - mock_ws.close.assert_called_once() - - -@pytest.mark.asyncio -async def test_close_when_not_connected(mock_websockets_connect, model): - """Test closing when not connected does nothing.""" - _, mock_ws = mock_websockets_connect - - # Don't connect - await model.close() - - # Should not raise, and close should not be called - mock_ws.close.assert_not_called() - - -@pytest.mark.asyncio -async def test_close_error_handling(mock_websockets_connect, model): - """Test close error handling.""" - _, mock_ws = mock_websockets_connect - mock_ws.close.side_effect = Exception("Close failed") - - await model.connect() - - # Should not raise, just log warning - await model.close() - assert model._active is False - - -@pytest.mark.asyncio -async def test_close_cancels_response_task(mock_websockets_connect, model): - """Test that close cancels the background response task.""" - _, _ = mock_websockets_connect - - await model.connect() - - # Verify response task is running - assert model._response_task is not None - assert not model._response_task.done() - - await model.close() - - # Task should be cancelled - assert model._response_task.cancelled() or model._response_task.done() # Helper Method Tests -def test_build_session_config_basic(model): - """Test building basic session config.""" - config = 
model._build_session_config(None, None) - - assert isinstance(config, dict) - assert "instructions" in config - assert "audio" in config +def test_config_building(model, system_prompt, tool_spec): + """Test building session config with various options.""" + # Test basic config + config_basic = model._build_session_config(None, None) + assert isinstance(config_basic, dict) + assert "instructions" in config_basic + assert "audio" in config_basic + # Test with system prompt + config_prompt = model._build_session_config(system_prompt, None) + assert config_prompt["instructions"] == system_prompt -def test_build_session_config_with_system_prompt(model, system_prompt): - """Test building config with system prompt.""" - config = model._build_session_config(system_prompt, None) - - assert config["instructions"] == system_prompt + # Test with tools + config_tools = model._build_session_config(None, [tool_spec]) + assert "tools" in config_tools + assert len(config_tools["tools"]) > 0 -def test_build_session_config_with_tools(model, tool_spec): - """Test building config with tools.""" - config = model._build_session_config(None, [tool_spec]) - - assert "tools" in config - assert len(config["tools"]) > 0 - - -def test_convert_tools_to_openai_format(model, tool_spec): +def test_tool_conversion(model, tool_spec): """Test tool conversion to OpenAI format.""" + # Test with tools openai_tools = model._convert_tools_to_openai_format([tool_spec]) - assert len(openai_tools) == 1 assert openai_tools[0]["type"] == "function" assert openai_tools[0]["name"] == "calculator" assert openai_tools[0]["description"] == "Calculate mathematical expressions" + # Test empty list + openai_empty = model._convert_tools_to_openai_format([]) + assert openai_empty == [] + + +def test_helper_methods(model): + """Test various helper methods.""" + # Test _require_active + assert model._require_active() is False + model._active = True + assert model._require_active() is True + model._active = False + + # 
Test _create_text_event + text_event = model._create_text_event("Hello", "user") + assert "textOutput" in text_event + assert text_event["textOutput"]["text"] == "Hello" + assert text_event["textOutput"]["role"] == "user" -def test_convert_tools_empty_list(model): - """Test converting empty tool list.""" - openai_tools = model._convert_tools_to_openai_format([]) - - assert openai_tools == [] + # Test _create_voice_activity_event + voice_event = model._create_voice_activity_event("speech_started") + assert "voiceActivity" in voice_event + assert voice_event["voiceActivity"]["activityType"] == "speech_started" @pytest.mark.asyncio -async def test_send_event(mock_websockets_connect, model): - """Test sending event to WebSocket.""" +async def test_send_event_helper(mock_websockets_connect, model): + """Test _send_event helper method.""" _, mock_ws = mock_websockets_connect await model.connect() - + test_event = {"type": "test.event", "data": "test"} await model._send_event(test_event) - - # Verify event was sent as JSON + calls = mock_ws.send.call_args_list last_call = calls[-1] sent_message = json.loads(last_call[0][0]) - assert sent_message == test_event - -def test_require_active(model): - """Test _require_active method.""" - assert model._require_active() is False - - model._active = True - assert model._require_active() is True - - -def test_create_text_event(model): - """Test creating text event.""" - event = model._create_text_event("Hello", "user") - - assert "textOutput" in event - assert event["textOutput"]["text"] == "Hello" - assert event["textOutput"]["role"] == "user" - - -def test_create_voice_activity_event(model): - """Test creating voice activity event.""" - event = model._create_voice_activity_event("speech_started") - - assert "voiceActivity" in event - assert event["voiceActivity"]["activityType"] == "speech_started" + await model.close() From 28fe471b099dfd149d08d4e8dd0e567e8916017b Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 
2025 15:18:12 +0100 Subject: [PATCH 07/11] fix: update comments --- .../bidirectional_streaming/models/novasonic.py | 8 +++----- .../experimental/bidirectional_streaming/models/openai.py | 6 ++---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 5436b5ae71..b9c5060ba7 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -1,11 +1,9 @@ """Nova Sonic bidirectional model provider for real-time streaming conversations. -Implements the unified BidirectionalModel interface for Amazon's Nova Sonic, handling the +Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the complex event sequencing and audio processing required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. -Unified model interface - combines configuration and connection state in single class. - Nova Sonic specifics: - Hierarchical event sequences: connectionStart → promptStart → content streaming - Base64-encoded audio format with hex encoding @@ -81,7 +79,7 @@ class NovaSonicBidirectionalModel(BidirectionalModel): - """Unified Nova Sonic implementation for bidirectional streaming. + """Nova Sonic implementation for bidirectional streaming. Combines model configuration and connection state in a single class. Manages Nova Sonic's complex event sequencing, audio format conversion, and @@ -305,7 +303,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: yield {"BidirectionalConnectionEnd": connection_end} async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: - """Unified send method for all content types. + """Unified send method for all content types. Sends the given content to Nova Sonic. 
Dispatches to appropriate internal handler based on content type. diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 8322eef4bc..16f3ac4a30 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -2,8 +2,6 @@ Provides real-time audio and text communication through OpenAI's Realtime API with WebSocket connections, voice activity detection, and function calling. - -Unified model interface - combines configuration and connection state in single class. """ import asyncio @@ -60,7 +58,7 @@ class OpenAIRealtimeBidirectionalModel(BidirectionalModel): - """Unified OpenAI Realtime API implementation for bidirectional streaming. + """OpenAI Realtime API implementation for bidirectional streaming. Combines model configuration and connection state in a single class. Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, @@ -434,7 +432,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] return None async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: - """Unified send method for all content types. + """Unified send method for all content types. Sends the given content to OpenAI. Dispatches to appropriate internal handler based on content type. 
From 60eb493319256f80a1a50d62ba1411219c719d1d Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 30 Oct 2025 15:38:10 +0100 Subject: [PATCH 08/11] fix: move import to top --- .../experimental/bidirectional_streaming/models/openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 16f3ac4a30..a542ec894e 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -8,6 +8,7 @@ import base64 import json import logging +import os import uuid from typing import AsyncIterable, Union @@ -91,7 +92,6 @@ def __init__( self.project = project self.session_config = session_config or {} - import os if not self.api_key: self.api_key = os.getenv("OPENAI_API_KEY") if not self.api_key: From 6529187994558056c91f41cecef053abac68c955 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 31 Oct 2025 13:51:47 +0100 Subject: [PATCH 09/11] fix: use protocol and improve _active handling --- .../models/bidirectional_model.py | 19 +++++++------------ .../models/gemini_live.py | 1 + .../models/novasonic.py | 1 + .../bidirectional_streaming/models/openai.py | 1 + 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 5b7091dcdc..05fb19e0f9 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -12,9 +12,8 @@ - Support for audio, text, image, and tool result streaming """ -import abc import logging -from typing import AsyncIterable, Union +from typing import AsyncIterable, Protocol, Union from ....types.content import Messages from 
....types.tools import ToolResult, ToolSpec @@ -28,15 +27,14 @@ logger = logging.getLogger(__name__) -class BidirectionalModel(abc.ABC): - """Abstract base class for bidirectional streaming models. +class BidirectionalModel(Protocol): + """Protocol for bidirectional streaming models. This interface defines the contract for models that support persistent streaming connections with real-time audio and text communication. Implementations handle provider-specific protocols while exposing a standardized event-based API. """ - @abc.abstractmethod async def connect( self, system_prompt: str | None = None, @@ -56,9 +54,8 @@ async def connect( messages: Initial conversation history to provide context. **kwargs: Provider-specific configuration options. """ - raise NotImplementedError + ... - @abc.abstractmethod async def close(self) -> None: """Close the streaming connection and release resources. @@ -66,9 +63,8 @@ async def close(self) -> None: resources such as network connections, buffers, or background tasks. After calling close(), the model instance cannot be used until connect() is called again. """ - raise NotImplementedError + ... - @abc.abstractmethod async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive streaming events from the model. @@ -82,9 +78,8 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: BidirectionalStreamEvent: Standardized event dictionaries containing audio output, text responses, tool calls, or control signals. """ - raise NotImplementedError + ... - @abc.abstractmethod async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: """Send content to the model over the active connection. @@ -104,4 +99,4 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE await model.send(AudioInputEvent(audioData=bytes, format="pcm", ...)) await model.send(ToolResult(toolUseId="123", status="success", ...)) """ - raise NotImplementedError + ... 
diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 639328c64b..cef8135eb0 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -128,6 +128,7 @@ async def connect( await self._send_message_history(messages) except Exception as e: + self._active = False logger.error("Error connecting to Gemini Live: %s", e) raise diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index b9c5060ba7..ddb0540f6e 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -182,6 +182,7 @@ async def connect( logger.info("Nova Sonic connection established successfully") except Exception as e: + self._active = False logger.error("Nova connection create error: %s", str(e)) raise diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index a542ec894e..e64508db78 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -160,6 +160,7 @@ async def connect( logger.info("OpenAI Realtime connection established") except Exception as e: + self._active = False logger.error("OpenAI connection error: %s", e) raise From 990d905a27d64734ed24ca36a09b78c951e0ce39 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 31 Oct 2025 14:03:17 +0100 Subject: [PATCH 10/11] refactor: simplify model names --- .../bidirectional_streaming/__init__.py | 12 ++++---- .../models/__init__.py | 12 ++++---- .../models/gemini_live.py | 2 +- .../models/novasonic.py | 4 +-- .../bidirectional_streaming/models/openai.py | 2 
+- .../tests/test_bidi_novasonic.py | 6 ++-- .../tests/test_bidi_openai.py | 4 +-- .../tests/test_gemini_live.py | 4 +-- .../models/test_gemini_live.py | 22 +++++++-------- .../models/test_novasonic.py | 8 +++--- .../models/test_openai_realtime.py | 28 +++++++++---------- 11 files changed, 51 insertions(+), 53 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index d855ba0388..caee4715a2 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -7,9 +7,9 @@ from .models.bidirectional_model import BidirectionalModel # Model providers - What users need to create models -from .models.gemini_live import GeminiLiveBidirectionalModel -from .models.novasonic import NovaSonicBidirectionalModel -from .models.openai import OpenAIRealtimeBidirectionalModel +from .models.gemini_live import GeminiLiveModel +from .models.novasonic import NovaSonicModel +from .models.openai import OpenAIRealtimeModel # Event types - For type hints and event handling from .types.bidirectional_streaming import ( @@ -29,9 +29,9 @@ "BidirectionalAgent", # Model providers - "GeminiLiveBidirectionalModel", - "NovaSonicBidirectionalModel", - "OpenAIRealtimeBidirectionalModel", + "GeminiLiveModel", + "NovaSonicModel", + "OpenAIRealtimeModel", # Event types "AudioInputEvent", diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 12fe6c2715..5b0d506877 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,13 +1,13 @@ """Bidirectional model interfaces and implementations.""" from .bidirectional_model import BidirectionalModel -from .gemini_live import GeminiLiveBidirectionalModel -from .novasonic import 
NovaSonicBidirectionalModel -from .openai import OpenAIRealtimeBidirectionalModel +from .gemini_live import GeminiLiveModel +from .novasonic import NovaSonicModel +from .openai import OpenAIRealtimeModel __all__ = [ "BidirectionalModel", - "GeminiLiveBidirectionalModel", - "NovaSonicBidirectionalModel", - "OpenAIRealtimeBidirectionalModel", + "GeminiLiveModel", + "NovaSonicModel", + "OpenAIRealtimeModel", ] diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index cef8135eb0..9f0cfe6c08 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -44,7 +44,7 @@ GEMINI_CHANNELS = 1 -class GeminiLiveBidirectionalModel(BidirectionalModel): +class GeminiLiveModel(BidirectionalModel): """Gemini Live API implementation using official Google GenAI SDK. Combines model configuration and connection state in a single class. diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index ddb0540f6e..c9e5805db5 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -78,7 +78,7 @@ RESPONSE_TIMEOUT = 1.0 -class NovaSonicBidirectionalModel(BidirectionalModel): +class NovaSonicModel(BidirectionalModel): """Nova Sonic implementation for bidirectional streaming. Combines model configuration and connection state in a single class. 
@@ -111,7 +111,6 @@ def __init__( # Nova Sonic requires unique content names self.audio_content_name = None - self.text_content_name = None # Audio connection state self.audio_connection_active = False @@ -154,7 +153,6 @@ async def connect( self.prompt_name = str(uuid.uuid4()) self._active = True self.audio_content_name = str(uuid.uuid4()) - self.text_content_name = str(uuid.uuid4()) self._event_queue = asyncio.Queue() # Start Nova Sonic bidirectional stream diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index e64508db78..0810b7b217 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -58,7 +58,7 @@ } -class OpenAIRealtimeBidirectionalModel(BidirectionalModel): +class OpenAIRealtimeModel(BidirectionalModel): """OpenAI Realtime API implementation for bidirectional streaming. Combines model configuration and connection state in a single class. 
diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py index 8c3ae3b4c4..b0a41f20d4 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py @@ -17,7 +17,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicModel def test_direct_tools(): @@ -30,7 +30,7 @@ def test_direct_tools(): return try: - model = NovaSonicBidirectionalModel() + model = NovaSonicModel() agent = BidirectionalAgent(model=model, tools=[calculator]) # Test calculator @@ -185,7 +185,7 @@ async def main(duration=180): print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") # Initialize model and agent - model = NovaSonicBidirectionalModel(region="us-east-1") + model = NovaSonicModel(region="us-east-1") agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") await agent.start() diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py index 660040f3ed..90e82c2bc9 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -14,7 +14,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel +from 
strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel async def play(context): @@ -205,7 +205,7 @@ async def main(): return False # Create OpenAI model - model = OpenAIRealtimeBidirectionalModel( + model = OpenAIRealtimeModel( model="gpt-4o-realtime-preview", api_key=api_key, session={ diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py index 4469e819a6..23e97bd5d2 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -38,7 +38,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel # Configure logging - debug only for Gemini Live, info for everything else logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -301,7 +301,7 @@ async def main(duration=180): # Initialize Gemini Live model with proper configuration logger.info("Initializing Gemini Live model with API key") - model = GeminiLiveBidirectionalModel( + model = GeminiLiveModel( model_id="gemini-2.5-flash-native-audio-preview-09-2025", api_key=api_key, params={ diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py index b894509c91..8c0a61b4b5 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -1,6 +1,6 @@ """Unit tests for Gemini Live bidirectional streaming model. 
-Tests the unified GeminiLiveBidirectionalModel interface including: +Tests the unified GeminiLiveModel interface including: - Model initialization and configuration - Connection establishment and lifecycle - Unified send() method with different content types @@ -13,7 +13,7 @@ from google import genai from google.genai import types as genai_types -from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( AudioInputEvent, ImageInputEvent, @@ -55,9 +55,9 @@ def api_key(): @pytest.fixture def model(mock_genai_client, model_id, api_key): - """Create a GeminiLiveBidirectionalModel instance.""" + """Create a GeminiLiveModel instance.""" _ = mock_genai_client - return GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + return GeminiLiveModel(model_id=model_id, api_key=api_key) @pytest.fixture @@ -87,20 +87,20 @@ def test_model_initialization(mock_genai_client, model_id, api_key): _ = mock_genai_client # Test default config - model_default = GeminiLiveBidirectionalModel() + model_default = GeminiLiveModel() assert model_default.model_id == "models/gemini-2.0-flash-live-preview-04-09" assert model_default.api_key is None assert model_default._active is False assert model_default.live_session is None # Test with API key - model_with_key = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + model_with_key = GeminiLiveModel(model_id=model_id, api_key=api_key) assert model_with_key.model_id == model_id assert model_with_key.api_key == api_key # Test with custom config live_config = {"temperature": 0.7, "top_p": 0.9} - model_custom = GeminiLiveBidirectionalModel(model_id=model_id, live_config=live_config) + model_custom = GeminiLiveModel(model_id=model_id, live_config=live_config) assert model_custom.live_config == 
live_config @@ -151,7 +151,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): mock_client, _, mock_live_session_cm = mock_genai_client # Test connection error - model1 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + model1 = GeminiLiveModel(model_id=model_id, api_key=api_key) mock_client.aio.live.connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): await model1.connect() @@ -160,18 +160,18 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): mock_client.aio.live.connect.side_effect = None # Test double connection - model2 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + model2 = GeminiLiveModel(model_id=model_id, api_key=api_key) await model2.connect() with pytest.raises(RuntimeError, match="Connection already active"): await model2.connect() await model2.close() # Test close when not connected - model3 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + model3 = GeminiLiveModel(model_id=model_id, api_key=api_key) await model3.close() # Should not raise # Test close error handling - model4 = GeminiLiveBidirectionalModel(model_id=model_id, api_key=api_key) + model4 = GeminiLiveModel(model_id=model_id, api_key=api_key) await model4.connect() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") with pytest.raises(Exception, match="Close failed"): diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 10066a6938..5601e23b8b 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -13,7 +13,7 @@ import pytest_asyncio from strands.experimental.bidirectional_streaming.models.novasonic import ( - NovaSonicBidirectionalModel, + NovaSonicModel, ) from 
strands.types.tools import ToolResult @@ -53,7 +53,7 @@ def mock_client(mock_stream): @pytest_asyncio.fixture async def nova_model(model_id, region): """Create Nova Sonic model instance.""" - model = NovaSonicBidirectionalModel(model_id=model_id, region=region) + model = NovaSonicModel(model_id=model_id, region=region) yield model # Cleanup if model._active: @@ -66,7 +66,7 @@ async def nova_model(model_id, region): @pytest.mark.asyncio async def test_model_initialization(model_id, region): """Test model initialization with configuration.""" - model = NovaSonicBidirectionalModel(model_id=model_id, region=region) + model = NovaSonicModel(model_id=model_id, region=region) assert model.model_id == model_id assert model.region == region @@ -120,7 +120,7 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model await nova_model.close() # Test close when already closed - model2 = NovaSonicBidirectionalModel(model_id=model_id, region=region) + model2 = NovaSonicModel(model_id=model_id, region=region) await model2.close() # Should not raise await model2.close() # Second call should also be safe diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 1209150ba9..388fc95cc6 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -1,6 +1,6 @@ """Unit tests for OpenAI Realtime bidirectional streaming model. 
-Tests the unified OpenAIRealtimeBidirectionalModel interface including: +Tests the unified OpenAIRealtimeModel interface including: - Model initialization and configuration - Connection establishment with WebSocket - Unified send() method with different content types @@ -15,7 +15,7 @@ import pytest -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( AudioInputEvent, ImageInputEvent, @@ -56,8 +56,8 @@ def api_key(): @pytest.fixture def model(api_key, model_name): - """Create an OpenAIRealtimeBidirectionalModel instance.""" - return OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + """Create an OpenAIRealtimeModel instance.""" + return OpenAIRealtimeModel(model=model_name, api_key=api_key) @pytest.fixture @@ -85,19 +85,19 @@ def messages(): def test_model_initialization(api_key, model_name): """Test model initialization with various configurations.""" # Test default config - model_default = OpenAIRealtimeBidirectionalModel(api_key="test-key") + model_default = OpenAIRealtimeModel(api_key="test-key") assert model_default.model == "gpt-realtime" assert model_default.api_key == "test-key" assert model_default._active is False assert model_default.websocket is None # Test with custom model - model_custom = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + model_custom = OpenAIRealtimeModel(model=model_name, api_key=api_key) assert model_custom.model == model_name assert model_custom.api_key == api_key # Test with organization and project - model_org = OpenAIRealtimeBidirectionalModel( + model_org = OpenAIRealtimeModel( model=model_name, api_key=api_key, organization="org-123", @@ -108,7 +108,7 @@ def test_model_initialization(api_key, model_name): # Test with env API key with 
unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): - model_env = OpenAIRealtimeBidirectionalModel() + model_env = OpenAIRealtimeModel() assert model_env.api_key == "env-key" @@ -116,7 +116,7 @@ def test_init_without_api_key_raises(): """Test that initialization without API key raises error.""" with unittest.mock.patch.dict("os.environ", {}, clear=True): with pytest.raises(ValueError, match="OpenAI API key is required"): - OpenAIRealtimeBidirectionalModel() + OpenAIRealtimeModel() # Connection Tests @@ -171,7 +171,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp await model.close() # Test connection with organization header - model_org = OpenAIRealtimeBidirectionalModel(api_key="test-key", organization="org-123") + model_org = OpenAIRealtimeModel(api_key="test-key", organization="org-123") await model_org.connect() call_kwargs = mock_connect.call_args.kwargs headers = call_kwargs.get("additional_headers", []) @@ -187,7 +187,7 @@ async def test_connection_edge_cases(mock_websockets_connect, api_key, model_nam mock_connect, mock_ws = mock_websockets_connect # Test connection error - model1 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + model1 = OpenAIRealtimeModel(model=model_name, api_key=api_key) mock_connect.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): await model1.connect() @@ -198,18 +198,18 @@ async def async_connect(*args, **kwargs): mock_connect.side_effect = async_connect # Test double connection - model2 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + model2 = OpenAIRealtimeModel(model=model_name, api_key=api_key) await model2.connect() with pytest.raises(RuntimeError, match="Connection already active"): await model2.connect() await model2.close() # Test close when not connected - model3 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + model3 = 
OpenAIRealtimeModel(model=model_name, api_key=api_key) await model3.close() # Should not raise # Test close error handling (should not raise, just log) - model4 = OpenAIRealtimeBidirectionalModel(model=model_name, api_key=api_key) + model4 = OpenAIRealtimeModel(model=model_name, api_key=api_key) await model4.connect() mock_ws.close.side_effect = Exception("Close failed") await model4.close() # Should not raise From 231f1718d8533fe78d4f98bca11a7cff73235084 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 4 Nov 2025 15:27:07 +0300 Subject: [PATCH 11/11] fix: address comments in pr --- .../event_loop/bidirectional_event_loop.py | 6 +-- .../models/gemini_live.py | 11 ++--- .../models/novasonic.py | 41 ++++++++++--------- .../bidirectional_streaming/models/openai.py | 3 +- .../models/test_novasonic.py | 26 ++++++------ .../models/test_openai_realtime.py | 2 +- 6 files changed, 46 insertions(+), 43 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index d1d6e90b32..38d92aea8f 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -473,8 +473,8 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: try: await session.model.send(error_result) logger.debug("Error result sent: %s", tool_id) - except Exception: - logger.error("Failed to send error result: %s", tool_id) - pass # Connection might be closed + except Exception as send_error: + logger.error("Failed to send error result: %s - %s", tool_id, str(send_error)) + raise # Propagate exception since this is experimental code diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 
9f0cfe6c08..ffff98cf10 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -84,7 +84,7 @@ def __init__( # Connection state (initialized in connect()) self.live_session = None - self.live_session_cm = None + self.live_session_context_manager = None self.session_id = None self._active = False @@ -115,13 +115,13 @@ async def connect( live_config = self._build_live_config(system_prompt, tools, **kwargs) # Create the context manager - self.live_session_cm = self.client.aio.live.connect( + self.live_session_context_manager = self.client.aio.live.connect( model=self.model_id, config=live_config ) # Enter the context manager - self.live_session = await self.live_session_cm.__aenter__() + self.live_session = await self.live_session_context_manager.__aenter__() # Send initial message history if provided if messages: @@ -312,6 +312,7 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE logger.warning(f"Unknown content type with keys: {content.keys()}") except Exception as e: logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: """Internal: Send audio content using Gemini Live API. 
@@ -412,8 +413,8 @@ async def close(self) -> None: try: # Exit the context manager properly - if self.live_session_cm: - await self.live_session_cm.__aexit__(None, None, None) + if self.live_session_context_manager: + await self.live_session_context_manager.__aexit__(None, None, None) except Exception as e: logger.error("Error closing Gemini Live connection: %s", e) raise diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index c9e5805db5..e4c0d15650 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -102,11 +102,11 @@ def __init__( # Model configuration self.model_id = model_id self.region = region - self._client = None + self.client = None # Connection state (initialized in connect()) self.stream = None - self.prompt_name = None + self.session_id = None self._active = False # Nova Sonic requires unique content names @@ -146,17 +146,17 @@ async def connect( try: # Initialize client if needed - if not self._client: + if not self.client: await self._initialize_client() # Initialize connection state - self.prompt_name = str(uuid.uuid4()) + self.session_id = str(uuid.uuid4()) self._active = True self.audio_content_name = str(uuid.uuid4()) self._event_queue = asyncio.Queue() # Start Nova Sonic bidirectional stream - self.stream = await self._client.invoke_model_with_bidirectional_stream( + self.stream = await self.client.invoke_model_with_bidirectional_stream( InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) ) @@ -165,7 +165,7 @@ async def connect( logger.error("Stream is None") raise ValueError("Stream cannot be None") - logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) + logger.debug("Nova Sonic connection initialized with session: %s", self.session_id) # Send initialization events system_prompt = 
system_prompt or "You are a helpful assistant. Keep responses brief." @@ -269,7 +269,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: # Emit connection start event to Strands event system connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.prompt_name, + "connectionId": self.session_id, "metadata": {"provider": "nova_sonic", "model_id": self.model_id}, } yield {"BidirectionalConnectionStart": connection_start} @@ -295,7 +295,7 @@ async def receive(self) -> AsyncIterable[dict[str, any]]: finally: # Emit connection end event when exiting connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.prompt_name, + "connectionId": self.session_id, "reason": "connection_complete", "metadata": {"provider": "nova_sonic"}, } @@ -331,6 +331,7 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE logger.warning(f"Unknown content type with keys: {content.keys()}") except Exception as e: logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code async def _start_audio_connection(self) -> None: """Internal: Start audio input connection (call once before sending audio chunks).""" @@ -343,7 +344,7 @@ async def _start_audio_connection(self) -> None: { "event": { "contentStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": self.audio_content_name, "type": "AUDIO", "interactive": True, @@ -376,7 +377,7 @@ async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: { "event": { "audioInput": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": self.audio_content_name, "content": nova_audio_data, } @@ -409,7 +410,7 @@ async def _end_audio_input(self) -> None: logger.debug("Nova audio connection end") audio_content_end = json.dumps( - {"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": self.audio_content_name}}} + {"event": {"contentEnd": 
{"promptName": self.session_id, "contentName": self.audio_content_name}}} ) await self._send_nova_event(audio_content_end) @@ -434,7 +435,7 @@ async def _send_interrupt(self) -> None: { "event": { "audioInput": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": self.audio_content_name, "stopReason": "INTERRUPTED", } @@ -600,7 +601,7 @@ def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: prompt_start_event = { "event": { "promptStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "textOutputConfiguration": NOVA_TEXT_CONFIG, "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, } @@ -644,7 +645,7 @@ def _get_text_content_start_event(self, content_name: str, role: str = "USER") - { "event": { "contentStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": content_name, "type": "TEXT", "role": role, @@ -661,7 +662,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> { "event": { "contentStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": content_name, "interactive": False, "type": "TOOL", @@ -679,7 +680,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> def _get_text_input_event(self, content_name: str, text: str) -> str: """Generate text input event.""" return json.dumps( - {"event": {"textInput": {"promptName": self.prompt_name, "contentName": content_name, "content": text}}} + {"event": {"textInput": {"promptName": self.session_id, "contentName": content_name, "content": text}}} ) def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: @@ -688,7 +689,7 @@ def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> s { "event": { "toolResult": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": content_name, "content": json.dumps(result), } @@ -698,11 +699,11 @@ def 
_get_tool_result_event(self, content_name: str, result: dict[str, any]) -> s def _get_content_end_event(self, content_name: str) -> str: """Generate content end event.""" - return json.dumps({"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": content_name}}}) + return json.dumps({"event": {"contentEnd": {"promptName": self.session_id, "contentName": content_name}}}) def _get_prompt_end_event(self) -> str: """Generate prompt end event.""" - return json.dumps({"event": {"promptEnd": {"promptName": self.prompt_name}}}) + return json.dumps({"event": {"promptEnd": {"promptName": self.session_id}}}) def _get_connection_end_event(self) -> str: """Generate connection end event.""" @@ -733,7 +734,7 @@ async def _initialize_client(self) -> None: auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, ) - self._client = BedrockRuntimeClient(config=config) + self.client = BedrockRuntimeClient(config=config) logger.debug("Nova Sonic client initialized") except ImportError as e: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 0810b7b217..4bf43b563d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -144,7 +144,7 @@ async def connect( if self.project: headers.append(("OpenAI-Project", self.project)) - self.websocket = await websockets.connect(url, additional_headers=headers) + self.websocket = await websockets.connect(url, extra_headers=headers) logger.info("WebSocket connected successfully") # Configure session @@ -462,6 +462,7 @@ async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputE logger.warning(f"Unknown content type with keys: {content.keys()}") except Exception as e: logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code async def _send_audio_content(self, 
audio_input: AudioInputEvent) -> None: """Internal: Send audio content to OpenAI for processing.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py index 5601e23b8b..7265bfacde 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -72,20 +72,20 @@ async def test_model_initialization(model_id, region): assert model.region == region assert model.stream is None assert not model._active - assert model.prompt_name is None + assert model.session_id is None @pytest.mark.asyncio async def test_connection_lifecycle(nova_model, mock_client, mock_stream): """Test complete connection lifecycle with various configurations.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client # Test basic connection await nova_model.connect(system_prompt="Test system prompt") assert nova_model._active assert nova_model.stream == mock_stream - assert nova_model.prompt_name is not None + assert nova_model.session_id is not None assert mock_client.invoke_model_with_bidirectional_stream.called # Test close @@ -111,7 +111,7 @@ async def test_connection_lifecycle(nova_model, mock_client, mock_stream): async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model_id, region): """Test connection error handling and edge cases.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client # Test double connection await nova_model.connect() @@ -132,7 +132,7 @@ async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model async def test_send_all_content_types(nova_model, mock_client, mock_stream): """Test sending all content types through unified send() 
method.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client await nova_model.connect() @@ -171,7 +171,7 @@ async def test_send_all_content_types(nova_model, mock_client, mock_stream): async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): """Test send() edge cases and error handling.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client # Test send when inactive text_event = {"text": "Hello", "role": "user"} @@ -197,7 +197,7 @@ async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): async def test_receive_lifecycle_events(nova_model, mock_client, mock_stream): """Test that receive() emits connection start and end events.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client # Setup mock to return no events and then stop async def mock_wait_for(*args, **kwargs): @@ -215,7 +215,7 @@ async def mock_wait_for(*args, **kwargs): # Should have connection start and end assert len(events) >= 2 assert "BidirectionalConnectionStart" in events[0] - assert events[0]["BidirectionalConnectionStart"]["connectionId"] == nova_model.prompt_name + assert events[0]["BidirectionalConnectionStart"]["connectionId"] == nova_model.session_id assert "BidirectionalConnectionEnd" in events[-1] @@ -301,7 +301,7 @@ async def test_event_conversion(nova_model): async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): """Test audio connection start and end lifecycle.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client await nova_model.connect() @@ -320,7 +320,7 @@ async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): async def 
test_silence_detection(nova_model, mock_client, mock_stream): """Test that silence detection automatically ends audio input.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client nova_model.silence_threshold = 0.1 # Short threshold for testing await nova_model.connect() @@ -385,12 +385,12 @@ async def test_event_templates(nova_model): assert "inferenceConfiguration" in event["event"]["sessionStart"] # Test prompt start event - nova_model.prompt_name = "test-prompt" + nova_model.session_id = "test-session" event_json = nova_model._get_prompt_start_event([]) event = json.loads(event_json) assert "event" in event assert "promptStart" in event["event"] - assert event["event"]["promptStart"]["promptName"] == "test-prompt" + assert event["event"]["promptStart"]["promptName"] == "test-session" # Test text input event content_name = "test-content" @@ -416,7 +416,7 @@ async def test_event_templates(nova_model): async def test_error_handling(nova_model, mock_client, mock_stream): """Test error handling in various scenarios.""" with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): - nova_model._client = mock_client + nova_model.client = mock_client # Test response processor handles errors gracefully async def mock_error(*args, **kwargs): diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py index 388fc95cc6..1c0b949b05 100644 --- a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -174,7 +174,7 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp model_org = OpenAIRealtimeModel(api_key="test-key", organization="org-123") await model_org.connect() call_kwargs = 
mock_connect.call_args.kwargs - headers = call_kwargs.get("additional_headers", []) + headers = call_kwargs.get("extra_headers", []) org_header = [h for h in headers if h[0] == "OpenAI-Organization"] assert len(org_header) == 1 assert org_header[0][1] == "org-123"