diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 3c47dd957..caee4715a 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -3,19 +3,22 @@ # Main components - Primary user interface from .agent.agent import BidirectionalAgent -# Advanced interfaces (for custom implementations) -from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession +# Model interface (for custom implementations) +from .models.bidirectional_model import BidirectionalModel # Model providers - What users need to create models -from .models.novasonic import NovaSonicBidirectionalModel -from .models.openai import OpenAIRealtimeBidirectionalModel +from .models.gemini_live import GeminiLiveModel +from .models.novasonic import NovaSonicModel +from .models.openai import OpenAIRealtimeModel # Event types - For type hints and event handling from .types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, BidirectionalStreamEvent, + ImageInputEvent, InterruptionDetectedEvent, + TextInputEvent, TextOutputEvent, UsageMetricsEvent, VoiceActivityEvent, @@ -26,12 +29,15 @@ "BidirectionalAgent", # Model providers - "NovaSonicBidirectionalModel", - "OpenAIRealtimeBidirectionalModel", + "GeminiLiveModel", + "NovaSonicModel", + "OpenAIRealtimeModel", # Event types "AudioInputEvent", - "AudioOutputEvent", + "AudioOutputEvent", + "ImageInputEvent", + "TextInputEvent", "TextOutputEvent", "InterruptionDetectedEvent", "BidirectionalStreamEvent", @@ -40,5 +46,4 @@ # Model interface "BidirectionalModel", - "BidirectionalModelSession", ] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 4e8adfe7c..e61d938d5 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -379,13 +379,15 @@ async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> Non self.messages.append({"role": "user", "content": input_data}) logger.debug("Text sent: %d characters", len(input_data)) - await self._session.model_session.send_text_content(input_data) + # Create TextInputEvent for send() + text_event = {"text": input_data, "role": "user"} + await self._session.model.send(text_event) elif isinstance(input_data, dict) and "audioData" in input_data: - # Handle audio input - await self._session.model_session.send_audio_content(input_data) + # Handle audio input - already in AudioInputEvent format + await self._session.model.send(input_data) elif isinstance(input_data, dict) and "imageData" in input_data: - # Handle image input (ImageInputEvent) - await self._session.model_session.send_image_content(input_data) + # Handle image input - already in ImageInputEvent format + await self._session.model.send(input_data) else: raise ValueError( "Input must be either a string (text), AudioInputEvent " @@ -419,7 +421,9 @@ async def interrupt(self) -> None: ValueError: If no active session. """ self._validate_active_session() - await self._session.model_session.send_interrupt() + # Interruption is now handled internally by models through audio/event processing + # No explicit interrupt method needed in unified interface + logger.debug("Interrupt requested - handled by model's audio processing") async def end(self) -> None: """End the conversation session and cleanup all resources. diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index bbf5fb425..38d92aea8 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -21,7 +21,7 @@ from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse -from ..models.bidirectional_model import BidirectionalModelSession +from ..models.bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -37,14 +37,14 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. """ - def __init__(self, model_session: BidirectionalModelSession, agent: "BidirectionalAgent") -> None: - """Initialize session with model session and agent reference. + def __init__(self, model: BidirectionalModel, agent: "BidirectionalAgent") -> None: + """Initialize connection with model and agent reference. Args: - model_session: Provider-specific bidirectional model session. + model: Bidirectional model instance. agent: BidirectionalAgent instance for tool registry access. """ - self.model_session = model_session + self.model = model self.agent = agent self.active = True @@ -76,15 +76,15 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec Returns: BidirectionalConnection: Active session with background tasks running. """ - logger.debug("Starting bidirectional session - initializing model session") + logger.debug("Starting bidirectional session - initializing model connection") - # Create provider-specific session - model_session = await agent.model.create_bidirectional_connection( + # Connect to model + await agent.model.connect( system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages ) - # Create session wrapper for background processing - session = BidirectionalConnection(model_session=model_session, agent=agent) + # Create connection wrapper for background processing + session = BidirectionalConnection(model=agent.model, agent=agent) # Start concurrent background processors IMMEDIATELY after session creation # This is critical - Nova Sonic needs response processing during initialization @@ -135,9 +135,9 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non if all_tasks: await asyncio.gather(*all_tasks, return_exceptions=True) - # Close model session - await session.model_session.close() - logger.debug("Session closed") + # Close model connection + await session.model.close() + logger.debug("Connection closed") async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: @@ -253,11 +253,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: events to standardized formats, and manages interruption detection. Args: - session: BidirectionalConnection containing model session. + session: BidirectionalConnection containing model. """ logger.debug("Model events processor started") try: - async for provider_event in session.model_session.receive_events(): + async for provider_event in session.model.receive(): if not session.active: break @@ -434,8 +434,8 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_result = tool_event.tool_result tool_use_id = tool_result.get("toolUseId") - # Send result through provider-specific session - await session.model_session.send_tool_result(tool_use_id, tool_result) + # Send result through send() method + await session.model.send(tool_result) logger.debug("Tool result sent: %s", tool_use_id) # Handle streaming events if needed later @@ -471,10 +471,10 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: "content": [{"text": f"Error: {str(e)}"}] } try: - await session.model_session.send_tool_result(tool_id, error_result) + await session.model.send(error_result) logger.debug("Error result sent: %s", tool_id) - except Exception: - logger.error("Failed to send error result: %s", tool_id) - pass # Session might be closed + except Exception as send_error: + logger.error("Failed to send error result: %s - %s", tool_id, str(send_error)) + raise # Propagate exception since this is experimental code diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index c5287d15d..5b0d50687 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,17 +1,13 @@ """Bidirectional model interfaces and implementations.""" -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession -from .gemini_live import GeminiLiveBidirectionalModel, GeminiLiveSession -from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession -from .openai import OpenAIRealtimeBidirectionalModel, OpenAIRealtimeSession +from .bidirectional_model import BidirectionalModel +from .gemini_live import GeminiLiveModel +from .novasonic import NovaSonicModel +from .openai import OpenAIRealtimeModel __all__ = [ "BidirectionalModel", - "BidirectionalModelSession", - "GeminiLiveBidirectionalModel", - "GeminiLiveSession", - "NovaSonicBidirectionalModel", - "NovaSonicSession", - "OpenAIRealtimeBidirectionalModel", - "OpenAIRealtimeSession", + "GeminiLiveModel", + "NovaSonicModel", + "OpenAIRealtimeModel", ] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 42485561b..05fb19e0f 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -1,113 +1,102 @@ -"""Bidirectional model interface for real-time streaming conversations. +"""Bidirectional streaming model interface. -Defines the interface for models that support bidirectional streaming capabilities. -Provides abstractions for different model providers with connection-based communication -patterns that support real-time audio and text interaction. +Defines the abstract interface for models that support real-time bidirectional +communication with persistent connections. Unlike traditional request-response +models, bidirectional models maintain an open connection for streaming audio, +text, and tool interactions. Features: -- connection-based persistent connections -- Real-time bidirectional communication +- Persistent connection management with connect/close lifecycle +- Real-time bidirectional communication (send and receive simultaneously) - Provider-agnostic event normalization -- Tool execution integration +- Support for audio, text, image, and tool result streaming """ -import abc import logging -from typing import AsyncIterable +from typing import AsyncIterable, Protocol, Union from ....types.content import Messages -from ....types.tools import ToolSpec -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent +from ....types.tools import ToolResult, ToolSpec +from ..types.bidirectional_streaming import ( + AudioInputEvent, + BidirectionalStreamEvent, + ImageInputEvent, + TextInputEvent, +) logger = logging.getLogger(__name__) -class BidirectionalModelSession(abc.ABC): - """Abstract interface for model-specific bidirectional communication connections. +class BidirectionalModel(Protocol): + """Protocol for bidirectional streaming models. - Defines the contract for managing persistent streaming connections with individual - model providers, handling audio/text input, receiving events, and managing - tool execution results. + This interface defines the contract for models that support persistent streaming + connections with real-time audio and text communication. Implementations handle + provider-specific protocols while exposing a standardized event-based API. """ - @abc.abstractmethod - async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive events from the model in standardized format. - - Converts provider-specific events to a common format that can be - processed uniformly by the event loop. + async def connect( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> None: + """Establish a persistent streaming connection with the model. + + Opens a bidirectional connection that remains active for real-time communication. + The connection supports concurrent sending and receiving of events until explicitly + closed. Must be called before any send() or receive() operations. + + Args: + system_prompt: System instructions to configure model behavior. + tools: Tool specifications that the model can invoke during the conversation. + messages: Initial conversation history to provide context. + **kwargs: Provider-specific configuration options. """ - raise NotImplementedError + ... - @abc.abstractmethod - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to the model during an active connection. + async def close(self) -> None: + """Close the streaming connection and release resources. - Handles audio encoding and provider-specific formatting while presenting - a simple AudioInputEvent interface. - """ - raise NotImplementedError - - # TODO: remove with interface unification - async def send_image_content(self, image_input: ImageInputEvent) -> None: - """Send image content to the model during an active connection. - - Handles image encoding and provider-specific formatting while presenting - a simple ImageInputEvent interface. - """ - raise NotImplementedError - - @abc.abstractmethod - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content to the model during ongoing generation. - - Allows natural interruption and follow-up questions without requiring - connection restart. + Terminates the active bidirectional connection and cleans up any associated + resources such as network connections, buffers, or background tasks. After + calling close(), the model instance cannot be used until connect() is called again. """ - raise NotImplementedError + ... - @abc.abstractmethod - async def send_interrupt(self) -> None: - """Send interruption signal to stop generation immediately. + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive streaming events from the model. - Enables responsive conversational experiences where users can - naturally interrupt during model responses. - """ - raise NotImplementedError + Continuously yields events from the model as they arrive over the connection. + Events are normalized to a provider-agnostic format for uniform processing. + This method should be called in a loop or async task to process model responses. - @abc.abstractmethod - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool execution result to the model. + The stream continues until the connection is closed or an error occurs. - Formats and sends tool results according to the provider's specific protocol. - Handles both successful results and error cases through the result dictionary. + Yields: + BidirectionalStreamEvent: Standardized event dictionaries containing + audio output, text responses, tool calls, or control signals. """ - raise NotImplementedError - - @abc.abstractmethod - async def close(self) -> None: - """Close the connection and cleanup resources.""" - raise NotImplementedError - - -class BidirectionalModel(abc.ABC): - """Interface for models that support bidirectional streaming. - - Defines the contract for creating persistent streaming connections that support - real-time audio and text communication with AI models. - """ - - @abc.abstractmethod - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create a bidirectional connection with the model. - - Establishes a persistent connection for real-time communication while - abstracting provider-specific initialization requirements. + ... + + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Send content to the model over the active connection. + + Transmits user input or tool results to the model during an active streaming + session. Supports multiple content types including text, audio, images, and + tool execution results. Can be called multiple times during a conversation. + + Args: + content: The content to send. Must be one of: + - TextInputEvent: Text message from the user + - ImageInputEvent: Image data for visual understanding + - AudioInputEvent: Audio data for speech input + - ToolResult: Result from a tool execution + + Example: + await model.send(TextInputEvent(text="Hello", role="user")) + await model.send(AudioInputEvent(audioData=bytes, format="pcm", ...)) + await model.send(ToolResult(toolUseId="123", status="success", ...)) """ - raise NotImplementedError + ... diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 64c4d7348..ffff98cf1 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -15,14 +15,14 @@ import base64 import logging import uuid -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import Any, AsyncIterable, Dict, List, Optional, Union from google import genai from google.genai import types as genai_types from google.genai.types import LiveServerMessage, LiveServerContent from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse +from ....types.tools import ToolResult, ToolSpec, ToolUse from ..types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, @@ -30,10 +30,11 @@ BidirectionalConnectionStartEvent, ImageInputEvent, InterruptionDetectedEvent, + TextInputEvent, TextOutputEvent, TranscriptEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -43,62 +44,92 @@ GEMINI_CHANNELS = 1 -class GeminiLiveSession(BidirectionalModelSession): - """Gemini Live API session using official Google GenAI SDK. +class GeminiLiveModel(BidirectionalModel): + """Gemini Live API implementation using official Google GenAI SDK. + Combines model configuration and connection state in a single class. Provides a clean interface to Gemini Live API using the official SDK, eliminating custom WebSocket handling and providing robust error handling. """ - def __init__(self, client: genai.Client, model_id: str, config: Dict[str, Any]): - """Initialize Gemini Live API session. + def __init__( + self, + model_id: str = "models/gemini-2.0-flash-live-preview-04-09", + api_key: Optional[str] = None, + live_config: Optional[Dict[str, Any]] = None, + **kwargs + ): + """Initialize Gemini Live API bidirectional model. Args: - client: Gemini client instance - model_id: Model identifier - config: Model configuration including live config + model_id: Gemini Live model identifier. + api_key: Google AI API key for authentication. + live_config: Gemini Live API configuration parameters (e.g., response_modalities, speech_config). + **kwargs: Reserved for future parameters. """ - self.client = client + # Model configuration self.model_id = model_id - self.config = config - self.session_id = str(uuid.uuid4()) - self._active = True - self.live_session = None - self.live_session_cm = None + self.api_key = api_key + self.live_config = live_config or {} - + # Create Gemini client with proper API version + client_kwargs = {} + if api_key: + client_kwargs["api_key"] = api_key + + # Use v1alpha for Live API as it has better model support + client_kwargs["http_options"] = {"api_version": "v1alpha"} + + self.client = genai.Client(**client_kwargs) + + # Connection state (initialized in connect()) + self.live_session = None + self.live_session_context_manager = None + self.session_id = None + self._active = False - async def initialize( + async def connect( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None + messages: Optional[Messages] = None, + **kwargs ) -> None: - """Initialize Gemini Live API session by creating the connection.""" + """Establish bidirectional connection with Gemini Live API. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + if self._active: + raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") try: - # Build live config - live_config = self.config.get("live_config") + # Initialize connection state + self.session_id = str(uuid.uuid4()) + self._active = True - if live_config is None: - raise ValueError("live_config is required but not found in session config") + # Build live config + live_config = self._build_live_config(system_prompt, tools, **kwargs) # Create the context manager - self.live_session_cm = self.client.aio.live.connect( + self.live_session_context_manager = self.client.aio.live.connect( model=self.model_id, config=live_config ) # Enter the context manager - self.live_session = await self.live_session_cm.__aenter__() + self.live_session = await self.live_session_context_manager.__aenter__() # Send initial message history if provided if messages: await self._send_message_history(messages) - except Exception as e: - logger.error("Error initializing Gemini Live session: %s", e) + self._active = False + logger.error("Error connecting to Gemini Live: %s", e) raise async def _send_message_history(self, messages: Messages) -> None: @@ -125,13 +156,13 @@ async def _send_message_history(self, messages: Messages) -> None: content = genai_types.Content(role=role, parts=content_parts) await self.live_session.send_client_content(turns=content) - async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + async def receive(self) -> AsyncIterable[Dict[str, Any]]: """Receive Gemini Live API events and convert to provider-agnostic format.""" # Emit connection start event connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.session_id, - "metadata": {"provider": "gemini_live", "model_id": self.config.get("model_id")} + "metadata": {"provider": "gemini_live", "model_id": self.model_id} } yield {"BidirectionalConnectionStart": connection_start} @@ -251,15 +282,44 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) return None - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content using Gemini Live API. + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Unified send method for all content types. Sends the given inputs to Google Live API - Gemini Live expects continuous audio streaming via send_realtime_input. - This automatically triggers VAD and can interrupt ongoing responses. + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). """ if not self._active: return + try: + if isinstance(content, dict): + # Dispatch based on content structure + if "text" in content and "role" in content: + # TextInputEvent + await self._send_text_content(content["text"]) + elif "audioData" in content: + # AudioInputEvent + await self._send_audio_content(content) + elif "imageData" in content or "image_url" in content: + # ImageInputEvent + await self._send_image_content(content) + elif "toolUseId" in content and "status" in content: + # ToolResult + await self._send_tool_result(content) + else: + logger.warning(f"Unknown content type with keys: {content.keys()}") + except Exception as e: + logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code + + async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Internal: Send audio content using Gemini Live API. + + Gemini Live expects continuous audio streaming via send_realtime_input. + This automatically triggers VAD and can interrupt ongoing responses. + """ try: # Create audio blob for the SDK audio_blob = genai_types.Blob( @@ -273,18 +333,15 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: except Exception as e: logger.error("Error sending audio content: %s", e) - async def send_image_content(self, image_input: ImageInputEvent) -> None: - """Send image content using Gemini Live API. + async def _send_image_content(self, image_input: ImageInputEvent) -> None: + """Internal: Send image content using Gemini Live API. Sends image frames following the same pattern as the GitHub example. Images are sent as base64-encoded data with MIME type. """ - if not self._active: - return - try: # Prepare the message based on encoding - if image_input["encoding"] == "base64": + if image_input.get("encoding") == "base64": # Data is already base64 encoded if isinstance(image_input["imageData"], bytes): data_str = image_input["imageData"].decode() @@ -306,11 +363,8 @@ async def send_image_content(self, image_input: ImageInputEvent) -> None: except Exception as e: logger.error("Error sending image content: %s", e) - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content using Gemini Live API.""" - if not self._active: - return - + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Gemini Live API.""" try: # Create content with text content = genai_types.Content( @@ -324,36 +378,25 @@ async def send_text_content(self, text: str, **kwargs) -> None: except Exception as e: logger.error("Error sending text content: %s", e) - async def send_interrupt(self) -> None: - """Send interruption signal to Gemini Live API. - - Gemini Live uses automatic VAD-based interruption. When new audio input - is detected, it automatically interrupts the ongoing generation. - We don't need to send explicit interrupt signals like Nova Sonic. - """ - if not self._active: - return - + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Gemini Live API.""" try: - # Gemini Live handles interruption automatically through VAD - # When new audio input is sent via send_realtime_input, it automatically - # interrupts any ongoing generation. No explicit interrupt signal needed. - logger.debug("Interrupt requested - Gemini Live handles this automatically via VAD") + tool_use_id = tool_result.get("toolUseId") + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = {"result": block["text"]} + break - except Exception as e: - logger.error("Error in interrupt handling: %s", e) - - async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: - """Send tool result using Gemini Live API.""" - if not self._active: - return - - try: # Create function response func_response = genai_types.FunctionResponse( id=tool_use_id, name=tool_use_id, # Gemini uses name as identifier - response=result + response=result_data ) # Send tool response @@ -361,11 +404,6 @@ async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> No except Exception as e: logger.error("Error sending tool result: %s", e) - async def send_tool_error(self, tool_use_id: str, error: str) -> None: - """Send tool error using Gemini Live API.""" - error_result = {"error": error} - await self.send_tool_result(tool_use_id, error_result) - async def close(self) -> None: """Close Gemini Live API connection.""" if not self._active: @@ -375,72 +413,10 @@ async def close(self) -> None: try: # Exit the context manager properly - if self.live_session_cm: - await self.live_session_cm.__aexit__(None, None, None) - except Exception as e: - logger.error("Error closing Gemini Live session: %s", e) - raise - - -class GeminiLiveBidirectionalModel(BidirectionalModel): - """Gemini Live API model implementation using official Google GenAI SDK. - - Provides access to Google's Gemini Live API through the bidirectional - streaming interface, using the official SDK for robust and simple integration. - """ - - def __init__( - self, - model_id: str = "models/gemini-2.0-flash-live-preview-04-09", - api_key: Optional[str] = None, - **config - ): - """Initialize Gemini Live API bidirectional model. - - Args: - model_id: Gemini Live model identifier. - api_key: Google AI API key for authentication. - **config: Additional configuration. - """ - self.model_id = model_id - self.api_key = api_key - self.config = config - - # Create Gemini client with proper API version - client_kwargs = {} - if api_key: - client_kwargs["api_key"] = api_key - - # Use v1alpha for Live API as it has better model support - client_kwargs["http_options"] = {"api_version": "v1alpha"} - - self.client = genai.Client(**client_kwargs) - - async def create_bidirectional_connection( - self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, - **kwargs - ) -> BidirectionalModelSession: - """Create Gemini Live API bidirectional connection using official SDK.""" - - try: - # Build configuration - live_config = self._build_live_config(system_prompt, tools, **kwargs) - - # Create session config - session_config = self._get_session_config() - session_config["live_config"] = live_config - - # Create and initialize session wrapper - session = GeminiLiveSession(self.client, self.model_id, session_config) - await session.initialize(system_prompt, tools, messages) - - return session - + if self.live_session_context_manager: + await self.live_session_context_manager.__aexit__(None, None, None) except Exception as e: - logger.error("Failed to create Gemini Live connection: %s", e) + logger.error("Error closing Gemini Live connection: %s", e) raise def _build_live_config( @@ -451,15 +427,15 @@ def _build_live_config( ) -> Dict[str, Any]: """Build LiveConnectConfig for the official SDK. - Simply passes through all config parameters from params, allowing users + Simply passes through all config parameters from live_config, allowing users to configure any Gemini Live API parameter directly. """ - # Start with user config from params + # Start with user-provided live_config config_dict = {} - if "params" in self.config: - config_dict.update(self.config["params"]) + if self.live_config: + config_dict.update(self.live_config) - # Override with any kwargs + # Override with any kwargs from connect() config_dict.update(kwargs) # Add system instruction if provided @@ -488,12 +464,4 @@ def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_t for tool_spec in tool_specs ], ), - ] - - def _get_session_config(self) -> Dict[str, Any]: - """Get session configuration for Gemini Live API.""" - return { - "model_id": self.model_id, - "params": self.config.get("params"), - **self.config - } \ No newline at end of file + ] \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 134ff73fd..e4c0d1565 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -19,25 +19,31 @@ import time import traceback import uuid -from typing import AsyncIterable +from typing import AsyncIterable, Union from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme -from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk, InvokeModelWithBidirectionalStreamOperationOutput +from aws_sdk_bedrock_runtime.models import ( + BidirectionalInputPayloadPart, + InvokeModelWithBidirectionalStreamInputChunk, + InvokeModelWithBidirectionalStreamOperationOutput, +) from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse +from ....types.tools import ToolResult, ToolSpec, ToolUse from ..types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, BidirectionalConnectionEndEvent, BidirectionalConnectionStartEvent, + ImageInputEvent, InterruptionDetectedEvent, + TextInputEvent, TextOutputEvent, UsageMetricsEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -72,29 +78,39 @@ RESPONSE_TIMEOUT = 1.0 -class NovaSonicSession(BidirectionalModelSession): - """Nova Sonic connection implementation handling the provider's specific protocol. +class NovaSonicModel(BidirectionalModel): + """Nova Sonic implementation for bidirectional streaming. + Combines model configuration and connection state in a single class. Manages Nova Sonic's complex event sequencing, audio format conversion, and - tool execution patterns while providing the standard BidirectionalModelSession - interface. + tool execution patterns while providing the standard BidirectionalModel interface. """ - def __init__(self, stream: InvokeModelWithBidirectionalStreamOperationOutput, config: dict[str, any]) -> None: - """Initialize Nova Sonic connection. + def __init__( + self, + model_id: str = "amazon.nova-sonic-v1:0", + region: str = "us-east-1", + **kwargs + ) -> None: + """Initialize Nova Sonic bidirectional model. Args: - stream: Nova Sonic bidirectional stream operation output from AWS SDK. - config: Model configuration. + model_id: Nova Sonic model identifier. + region: AWS region. + **kwargs: Reserved for future parameters. """ - self.stream = stream - self.config = config - self.prompt_name = str(uuid.uuid4()) - self._active = True + # Model configuration + self.model_id = model_id + self.region = region + self.client = None + + # Connection state (initialized in connect()) + self.stream = None + self.session_id = None + self._active = False # Nova Sonic requires unique content names - self.audio_content_name = str(uuid.uuid4()) - self.text_content_name = str(uuid.uuid4()) + self.audio_content_name = None # Audio connection state self.audio_connection_active = False @@ -102,33 +118,70 @@ def __init__(self, stream: InvokeModelWithBidirectionalStreamOperationOutput, co self.silence_threshold = SILENCE_THRESHOLD self.silence_task = None - # Validate stream - if not stream: - logger.error("Stream is None") - raise ValueError("Stream cannot be None") + # Background task and event queue + self._response_task = None + self._event_queue = None - logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) - async def initialize( + async def connect( self, system_prompt: str | None = None, tools: list[ToolSpec] | None = None, messages: Messages | None = None, + **kwargs, ) -> None: - """Initialize Nova Sonic connection with required protocol sequence.""" + """Establish bidirectional connection to Nova Sonic. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + if self._active: + raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + + logger.debug("Nova connection create - starting") + try: - system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." + # Initialize client if needed + if not self.client: + await self._initialize_client() + + # Initialize connection state + self.session_id = str(uuid.uuid4()) + self._active = True + self.audio_content_name = str(uuid.uuid4()) + self._event_queue = asyncio.Queue() + + # Start Nova Sonic bidirectional stream + self.stream = await self.client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + ) + + # Validate stream + if not self.stream: + logger.error("Stream is None") + raise ValueError("Stream cannot be None") + logger.debug("Nova Sonic connection initialized with session: %s", self.session_id) + + # Send initialization events + system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." init_events = self._build_initialization_events(system_prompt, tools or [], messages) logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) await self._send_initialization_events(init_events) - logger.info("Nova Sonic connection initialized successfully") + # Start background response processor self._response_task = asyncio.create_task(self._process_responses()) + logger.info("Nova Sonic connection established successfully") + except Exception as e: - logger.error("Error during Nova Sonic initialization: %s", e) + self._active = False + logger.error("Nova connection create error: %s", str(e)) raise def _build_initialization_events( @@ -206,7 +259,7 @@ def _log_event_type(self, nova_event: dict[str, any]) -> None: audio_bytes = base64.b64decode(audio_content) logger.debug("Nova audio output: %d bytes", len(audio_bytes)) - async def receive_events(self) -> AsyncIterable[dict[str, any]]: + async def receive(self) -> AsyncIterable[dict[str, any]]: """Receive Nova Sonic events and convert to provider-agnostic format.""" if not self.stream: logger.error("Stream is None") @@ -216,15 +269,11 @@ async def receive_events(self) -> AsyncIterable[dict[str, any]]: # Emit connection start event to Strands event system connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.prompt_name, - "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")}, + "connectionId": self.session_id, + "metadata": {"provider": "nova_sonic", "model_id": self.model_id}, } yield {"BidirectionalConnectionStart": connection_start} - # Initialize event queue if not already done - if not hasattr(self, "_event_queue"): - self._event_queue = asyncio.Queue() - try: while self._active: try: @@ -246,14 +295,46 @@ async def receive_events(self) -> AsyncIterable[dict[str, any]]: finally: # Emit connection end event when exiting connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.prompt_name, + "connectionId": self.session_id, "reason": "connection_complete", "metadata": {"provider": "nova_sonic"}, } yield {"BidirectionalConnectionEnd": connection_end} - async def start_audio_connection(self) -> None: - """Start audio input connection (call once before sending audio chunks).""" + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Unified send method for all content types. Sends the given content to Nova Sonic. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). + """ + if not self._active: + return + + try: + if isinstance(content, dict): + # Dispatch based on content structure + if "text" in content and "role" in content: + # TextInputEvent + await self._send_text_content(content["text"]) + elif "audioData" in content: + # AudioInputEvent + await self._send_audio_content(content) + elif "imageData" in content or "image_url" in content: + # ImageInputEvent - not supported by Nova Sonic + logger.warning("Image input not supported by Nova Sonic") + elif "toolUseId" in content and "status" in content: + # ToolResult + await self._send_tool_result(content) + else: + logger.warning(f"Unknown content type with keys: {content.keys()}") + except Exception as e: + logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code + + async def _start_audio_connection(self) -> None: + """Internal: Start audio input connection (call once before sending audio chunks).""" if self.audio_connection_active: return @@ -263,7 +344,7 @@ async def start_audio_connection(self) -> None: { "event": { "contentStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": self.audio_content_name, "type": "AUDIO", "interactive": True, @@ -277,14 +358,11 @@ async def start_audio_connection(self) -> None: await self._send_nova_event(audio_content_start) self.audio_connection_active = True - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio using Nova Sonic protocol-specific format.""" - if not self._active: - return - + async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Internal: Send audio using Nova Sonic protocol-specific format.""" # Start audio connection if not already active if not self.audio_connection_active: - await self.start_audio_connection() + await self._start_audio_connection() # Update last audio time and cancel any pending silence task self.last_audio_time = time.time() @@ -299,7 +377,7 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: { "event": { "audioInput": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": self.audio_content_name, "content": nova_audio_data, } @@ -313,36 +391,33 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: self.silence_task = asyncio.create_task(self._check_silence()) async def _check_silence(self) -> None: - """Check for silence and automatically end audio connection.""" + """Internal: Check for silence and automatically end audio connection.""" try: await asyncio.sleep(self.silence_threshold) if self.audio_connection_active and self.last_audio_time: elapsed = time.time() - self.last_audio_time if elapsed >= self.silence_threshold: logger.debug("Nova silence detected: %.2f seconds", elapsed) - await self.end_audio_input() + await self._end_audio_input() except asyncio.CancelledError: pass - async def end_audio_input(self) -> None: - """End current audio input connection to trigger Nova Sonic processing.""" + async def _end_audio_input(self) -> None: + """Internal: End current audio input connection to trigger Nova Sonic processing.""" if not self.audio_connection_active: return logger.debug("Nova audio connection end") audio_content_end = json.dumps( - {"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": self.audio_content_name}}} + {"event": {"contentEnd": {"promptName": self.session_id, "contentName": self.audio_content_name}}} ) await self._send_nova_event(audio_content_end) self.audio_connection_active = False - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content using Nova Sonic format.""" - if not self._active: - return - + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Nova Sonic format.""" content_name = str(uuid.uuid4()) events = [ self._get_text_content_start_event(content_name), @@ -353,37 +428,45 @@ async def send_text_content(self, text: str, **kwargs) -> None: for event in events: await self._send_nova_event(event) - async def send_interrupt(self) -> None: - """Send interruption signal to Nova Sonic.""" - if not self._active: - return - + async def _send_interrupt(self) -> None: + """Internal: Send interruption signal to Nova Sonic.""" # Nova Sonic handles interruption through special input events - interrupt_event = { - "event": { - "audioInput": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name, - "stopReason": "INTERRUPTED", + interrupt_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self.session_id, + "contentName": self.audio_content_name, + "stopReason": "INTERRUPTED", + } } } - } + ) await self._send_nova_event(interrupt_event) - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool result using Nova Sonic toolResult format.""" - if not self._active: - return + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Nova Sonic toolResult format.""" + tool_use_id = tool_result.get("toolUseId") logger.debug("Nova tool result send: %s", tool_use_id) + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = {"result": block["text"]} + break + content_name = str(uuid.uuid4()) events = [ self._get_tool_content_start_event(content_name, tool_use_id), - self._get_tool_result_event(content_name, result), + self._get_tool_result_event(content_name, result_data), self._get_content_end_event(content_name), ] - for _i, event in enumerate(events): + for event in events: await self._send_nova_event(event) async def close(self) -> None: @@ -405,7 +488,7 @@ async def close(self) -> None: try: # End audio connection if active if self.audio_connection_active: - await self.end_audio_input() + await self._end_audio_input() # Send cleanup events cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] @@ -518,7 +601,7 @@ def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: prompt_start_event = { "event": { "promptStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "textOutputConfiguration": NOVA_TEXT_CONFIG, "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, } @@ -562,7 +645,7 @@ def _get_text_content_start_event(self, content_name: str, role: str = "USER") - { "event": { "contentStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": content_name, "type": "TEXT", "role": role, @@ -579,7 +662,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> { "event": { "contentStart": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": content_name, "interactive": False, "type": "TOOL", @@ -597,7 +680,7 @@ def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> def _get_text_input_event(self, content_name: str, text: str) -> str: """Generate text input event.""" return json.dumps( - {"event": {"textInput": {"promptName": self.prompt_name, "contentName": content_name, "content": text}}} + {"event": {"textInput": {"promptName": self.session_id, "contentName": content_name, "content": text}}} ) def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: @@ -606,7 +689,7 @@ def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> s { "event": { "toolResult": { - "promptName": self.prompt_name, + "promptName": self.session_id, "contentName": content_name, "content": json.dumps(result), } @@ -616,11 +699,11 @@ def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> s def _get_content_end_event(self, content_name: str) -> str: """Generate content end event.""" - return json.dumps({"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": content_name}}}) + return json.dumps({"event": {"contentEnd": {"promptName": self.session_id, "contentName": content_name}}}) def _get_prompt_end_event(self) -> str: """Generate prompt end event.""" - return json.dumps({"event": {"promptEnd": {"promptName": self.prompt_name}}}) + return json.dumps({"event": {"promptEnd": {"promptName": self.session_id}}}) def _get_connection_end_event(self) -> str: """Generate connection end event.""" @@ -640,60 +723,6 @@ async def _send_nova_event(self, event: str) -> None: logger.error("Event was: %s", event) raise - -class NovaSonicBidirectionalModel(BidirectionalModel): - """Nova Sonic model implementation for bidirectional streaming. - - Provides access to Amazon's Nova Sonic model through the bidirectional - streaming interface, handling AWS authentication and connection management. - """ - - def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config: any) -> None: - """Initialize Nova Sonic bidirectional model. - - Args: - model_id: Nova Sonic model identifier. - region: AWS region. - **config: Additional configuration. - """ - self.model_id = model_id - self.region = region - self.config = config - self._client = None - - logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) - - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create Nova Sonic bidirectional connection.""" - logger.debug("Nova connection create - starting") - - # Initialize client if needed - if not self._client: - await self._initialize_client() - - # Start Nova Sonic bidirectional stream - try: - stream = await self._client.invoke_model_with_bidirectional_stream( - InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - ) - - # Create and initialize connection - connection = NovaSonicSession(stream, self.config) - await connection.initialize(system_prompt, tools, messages) - - logger.debug("Nova connection created") - return connection - except Exception as e: - logger.error("Nova connection create error: %s", str(e)) - logger.error("Failed to create Nova Sonic connection: %s", e) - raise - async def _initialize_client(self) -> None: """Initialize Nova Sonic client.""" try: @@ -705,7 +734,7 @@ async def _initialize_client(self) -> None: auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, ) - self._client = BedrockRuntimeClient(config=config) + self.client = BedrockRuntimeClient(config=config) logger.debug("Nova Sonic client initialized") except ImportError as e: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 7d009b1c7..4bf43b563 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -8,25 +8,28 @@ import base64 import json import logging +import os import uuid -from typing import AsyncIterable +from typing import AsyncIterable, Union import websockets from websockets.client import WebSocketClientProtocol from websockets.exceptions import ConnectionClosed from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse +from ....types.tools import ToolResult, ToolSpec, ToolUse from ..types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, BidirectionalConnectionEndEvent, BidirectionalConnectionStartEvent, BidirectionalStreamEvent, + ImageInputEvent, + TextInputEvent, TextOutputEvent, VoiceActivityEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .bidirectional_model import BidirectionalModel logger = logging.getLogger(__name__) @@ -55,63 +58,126 @@ } -class OpenAIRealtimeSession(BidirectionalModelSession): - """OpenAI Realtime API session for real-time audio/text streaming. +class OpenAIRealtimeModel(BidirectionalModel): + """OpenAI Realtime API implementation for bidirectional streaming. + Combines model configuration and connection state in a single class. Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, function calling, and event conversion to Strands format. """ - def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: - """Initialize OpenAI Realtime session.""" - self.websocket = websocket - self.config = config - self.session_id = str(uuid.uuid4()) - self._active = True - - self._event_queue = asyncio.Queue() + def __init__( + self, + model: str = DEFAULT_MODEL, + api_key: str | None = None, + organization: str | None = None, + project: str | None = None, + session_config: dict[str, any] | None = None, + **kwargs + ) -> None: + """Initialize OpenAI Realtime bidirectional model. + + Args: + model: OpenAI model identifier (default: gpt-realtime). + api_key: OpenAI API key for authentication. + organization: OpenAI organization ID for API requests. + project: OpenAI project ID for API requests. + session_config: Session configuration parameters (e.g., voice, turn_detection, modalities). + **kwargs: Reserved for future parameters. + """ + # Model configuration + self.model = model + self.api_key = api_key + self.organization = organization + self.project = project + self.session_config = session_config or {} + + if not self.api_key: + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") + + # Connection state (initialized in connect()) + self.websocket = None + self.session_id = None + self._active = False + + self._event_queue = None self._response_task = None self._function_call_buffer = {} - logger.debug("OpenAI Realtime session initialized: %s", self.session_id) - - def _require_active(self) -> bool: - """Check if session is active.""" - return self._active - - def _create_text_event(self, text: str, role: str) -> dict[str, any]: - """Create standardized text output event.""" - text_output: TextOutputEvent = {"text": text, "role": role} - return {"textOutput": text_output} - - def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: - """Create standardized voice activity event.""" - voice_activity: VoiceActivityEvent = {"activityType": activity_type} - return {"voiceActivity": voice_activity} - - + logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) - async def initialize( + async def connect( self, system_prompt: str | None = None, tools: list[ToolSpec] | None = None, messages: Messages | None = None, + **kwargs, ) -> None: - """Initialize session with configuration.""" + """Establish bidirectional connection to OpenAI Realtime API. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + if self._active: + raise RuntimeError("Connection already active. Close the existing connection before creating a new one.") + + logger.info("Creating OpenAI Realtime connection...") + try: + # Initialize connection state + self.session_id = str(uuid.uuid4()) + self._active = True + self._event_queue = asyncio.Queue() + self._function_call_buffer = {} + + # Establish WebSocket connection + url = f"{OPENAI_REALTIME_URL}?model={self.model}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if self.organization: + headers.append(("OpenAI-Organization", self.organization)) + if self.project: + headers.append(("OpenAI-Project", self.project)) + + self.websocket = await websockets.connect(url, extra_headers=headers) + logger.info("WebSocket connected successfully") + + # Configure session session_config = self._build_session_config(system_prompt, tools) await self._send_event({"type": "session.update", "session": session_config}) + # Add conversation history if provided if messages: await self._add_conversation_history(messages) + # Start background response processor self._response_task = asyncio.create_task(self._process_responses()) - logger.info("OpenAI Realtime session initialized successfully") + logger.info("OpenAI Realtime connection established") except Exception as e: - logger.error("Error during OpenAI Realtime initialization: %s", e) + self._active = False + logger.error("OpenAI connection error: %s", e) raise + def _require_active(self) -> bool: + """Check if session is active.""" + return self._active + + def _create_text_event(self, text: str, role: str) -> dict[str, any]: + """Create standardized text output event.""" + text_output: TextOutputEvent = {"text": text, "role": role} + return {"textOutput": text_output} + + def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: + """Create standardized voice activity event.""" + voice_activity: VoiceActivityEvent = {"activityType": activity_type} + return {"voiceActivity": voice_activity} + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: """Build session configuration for OpenAI Realtime API.""" config = DEFAULT_SESSION_CONFIG.copy() @@ -122,14 +188,14 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] if tools: config["tools"] = self._convert_tools_to_openai_format(tools) - custom_config = self.config.get("session", {}) + # Apply user-provided session configuration supported_params = { "type", "output_modalities", "instructions", "voice", "audio", "tools", "tool_choice", "input_audio_format", "output_audio_format", "input_audio_transcription", "turn_detection" } - for key, value in custom_config.items(): + for key, value in self.session_config.items(): if key in supported_params: config[key] = value else: @@ -201,11 +267,11 @@ async def _process_responses(self) -> None: self._active = False logger.debug("OpenAI Realtime response processor stopped") - async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive OpenAI events and convert to Strands format.""" connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.session_id, - "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, + "metadata": {"provider": "openai_realtime", "model": self.model}, } yield {"BidirectionalConnectionStart": connection_start} @@ -366,19 +432,45 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] logger.debug("Unhandled OpenAI event type: %s", event_type) return None - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to OpenAI for processing.""" + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Unified send method for all content types. Sends the given content to OpenAI. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (TextInputEvent, ImageInputEvent, AudioInputEvent, or ToolResult). + """ if not self._require_active(): return + try: + if isinstance(content, dict): + # Dispatch based on content structure + if "text" in content and "role" in content: + # TextInputEvent + await self._send_text_content(content["text"]) + elif "audioData" in content: + # AudioInputEvent + await self._send_audio_content(content) + elif "imageData" in content or "image_url" in content: + # ImageInputEvent - not supported by OpenAI Realtime yet + logger.warning("Image input not supported by OpenAI Realtime API") + elif "toolUseId" in content and "status" in content: + # ToolResult + await self._send_tool_result(content) + else: + logger.warning(f"Unknown content type with keys: {content.keys()}") + except Exception as e: + logger.error(f"Error sending content: {e}") + raise # Propagate exception for debugging in experimental code + + async def _send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Internal: Send audio content to OpenAI for processing.""" audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) - async def send_text_content(self, text: str) -> None: - """Send text content to OpenAI for processing.""" - if not self._require_active(): - return - + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content to OpenAI for processing.""" item_data = { "type": "message", "role": "user", @@ -387,20 +479,26 @@ async def send_text_content(self, text: str) -> None: await self._send_event({"type": "conversation.item.create", "item": item_data}) await self._send_event({"type": "response.create"}) - async def send_interrupt(self) -> None: - """Send interruption signal to OpenAI.""" - if not self._require_active(): - return - + async def _send_interrupt(self) -> None: + """Internal: Send interruption signal to OpenAI.""" await self._send_event({"type": "response.cancel"}) - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool result back to OpenAI.""" - if not self._require_active(): - return + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result back to OpenAI.""" + tool_use_id = tool_result.get("toolUseId") logger.debug("OpenAI tool result send: %s", tool_use_id) - result_text = json.dumps(result) if not isinstance(result, str) else result + + # Extract result content + result_data = {} + if "content" in tool_result: + # Extract text from content blocks + for block in tool_result["content"]: + if "text" in block: + result_data = block["text"] + break + + result_text = json.dumps(result_data) if not isinstance(result_data, str) else result_data item_data = { "type": "function_call_output", @@ -443,60 +541,3 @@ async def _send_event(self, event: dict[str, any]) -> None: raise -class OpenAIRealtimeBidirectionalModel(BidirectionalModel): - """OpenAI Realtime API provider for Strands bidirectional streaming. - - Provides real-time audio/text communication through OpenAI's Realtime API - with WebSocket connections, voice activity detection, and function calling. - """ - - def __init__( - self, - model: str = DEFAULT_MODEL, - api_key: str | None = None, - **config: any - ) -> None: - """Initialize OpenAI Realtime bidirectional model.""" - self.model = model - self.api_key = api_key - self.config = config - - import os - if not self.api_key: - self.api_key = os.getenv("OPENAI_API_KEY") - if not self.api_key: - raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") - - logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) - - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create bidirectional connection to OpenAI Realtime API.""" - logger.info("Creating OpenAI Realtime connection...") - - try: - url = f"{OPENAI_REALTIME_URL}?model={self.model}" - - headers = [("Authorization", f"Bearer {self.api_key}")] - if "organization" in self.config: - headers.append(("OpenAI-Organization", self.config["organization"])) - if "project" in self.config: - headers.append(("OpenAI-Project", self.config["project"])) - - websocket = await websockets.connect(url, additional_headers=headers) - logger.info("WebSocket connected successfully") - - session = OpenAIRealtimeSession(websocket, self.config) - await session.initialize(system_prompt, tools, messages) - - logger.info("OpenAI Realtime connection established") - return session - - except Exception as e: - logger.error("OpenAI connection error: %s", e) - raise \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py index 8c3ae3b4c..b0a41f20d 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py @@ -17,7 +17,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicModel def test_direct_tools(): @@ -30,7 +30,7 @@ def test_direct_tools(): return try: - model = NovaSonicBidirectionalModel() + model = NovaSonicModel() agent = BidirectionalAgent(model=model, tools=[calculator]) # Test calculator @@ -185,7 +185,7 @@ async def main(duration=180): print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") # Initialize model and agent - model = NovaSonicBidirectionalModel(region="us-east-1") + model = NovaSonicModel(region="us-east-1") agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") await agent.start() diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py index 660040f3e..90e82c2bc 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -14,7 +14,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel async def play(context): @@ -205,7 +205,7 @@ async def main(): return False # Create OpenAI model - model = OpenAIRealtimeBidirectionalModel( + model = OpenAIRealtimeModel( model="gpt-4o-realtime-preview", api_key=api_key, session={ diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py index 4469e819a..23e97bd5d 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -38,7 +38,7 @@ from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel # Configure logging - debug only for Gemini Live, info for everything else logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -301,7 +301,7 @@ async def main(duration=180): # Initialize Gemini Live model with proper configuration logger.info("Initializing Gemini Live model with API key") - model = GeminiLiveBidirectionalModel( + model = GeminiLiveModel( model_id="gemini-2.5-flash-native-audio-preview-09-2025", api_key=api_key, params={ diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 4b215d74e..145710c3c 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -88,6 +88,20 @@ class ImageInputEvent(TypedDict): encoding: Literal["base64", "raw"] +class TextInputEvent(TypedDict): + """Text input event for sending text to the model. + + Used for sending text content through the send() method. + + Attributes: + text: The text content to send to the model. + role: The role of the message sender (typically "user"). + """ + + text: str + role: Role + + class TextOutputEvent(TypedDict): """Text output event from the model during bidirectional streaming. diff --git a/tests/strands/experimental/__init__.py b/tests/strands/experimental/__init__.py index e69de29bb..ac8db1d74 100644 --- a/tests/strands/experimental/__init__.py +++ b/tests/strands/experimental/__init__.py @@ -0,0 +1 @@ +"""Experimental features tests.""" diff --git a/tests/strands/experimental/bidirectional_streaming/__init__.py b/tests/strands/experimental/bidirectional_streaming/__init__.py new file mode 100644 index 000000000..ea37091cc --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/__init__.py @@ -0,0 +1 @@ +"""Bidirectional streaming tests.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/__init__.py b/tests/strands/experimental/bidirectional_streaming/models/__init__.py new file mode 100644 index 000000000..ea9fbb2d0 --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/models/__init__.py @@ -0,0 +1 @@ +"""Bidirectional streaming model tests.""" diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py new file mode 100644 index 000000000..8c0a61b4b --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/models/test_gemini_live.py @@ -0,0 +1,371 @@ +"""Unit tests for Gemini Live bidirectional streaming model. + +Tests the unified GeminiLiveModel interface including: +- Model initialization and configuration +- Connection establishment and lifecycle +- Unified send() method with different content types +- Event receiving and conversion +""" + +import unittest.mock + +import pytest +from google import genai +from google.genai import types as genai_types + +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveModel +from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + AudioInputEvent, + ImageInputEvent, + TextInputEvent, +) +from strands.types.tools import ToolResult + + +@pytest.fixture +def mock_genai_client(): + """Mock the Google GenAI client.""" + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.gemini_live.genai.Client") as mock_client_cls: + mock_client = mock_client_cls.return_value + mock_client.aio = unittest.mock.MagicMock() + + # Mock the live session + mock_live_session = unittest.mock.AsyncMock() + + # Mock the context manager + mock_live_session_cm = unittest.mock.MagicMock() + mock_live_session_cm.__aenter__ = unittest.mock.AsyncMock(return_value=mock_live_session) + mock_live_session_cm.__aexit__ = unittest.mock.AsyncMock(return_value=None) + + # Make connect return the context manager + mock_client.aio.live.connect = unittest.mock.MagicMock(return_value=mock_live_session_cm) + + yield mock_client, mock_live_session, mock_live_session_cm + + +@pytest.fixture +def model_id(): + return "models/gemini-2.0-flash-live-preview-04-09" + + +@pytest.fixture +def api_key(): + return "test-api-key" + + +@pytest.fixture +def model(mock_genai_client, model_id, api_key): + """Create a GeminiLiveModel instance.""" + _ = mock_genai_client + return GeminiLiveModel(model_id=model_id, api_key=api_key) + + +@pytest.fixture +def tool_spec(): + return { + "description": "Calculate mathematical expressions", + "name": "calculator", + "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}}, + } + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant" + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +# Initialization Tests + + +def test_model_initialization(mock_genai_client, model_id, api_key): + """Test model initialization with various configurations.""" + _ = mock_genai_client + + # Test default config + model_default = GeminiLiveModel() + assert model_default.model_id == "models/gemini-2.0-flash-live-preview-04-09" + assert model_default.api_key is None + assert model_default._active is False + assert model_default.live_session is None + + # Test with API key + model_with_key = GeminiLiveModel(model_id=model_id, api_key=api_key) + assert model_with_key.model_id == model_id + assert model_with_key.api_key == api_key + + # Test with custom config + live_config = {"temperature": 0.7, "top_p": 0.9} + model_custom = GeminiLiveModel(model_id=model_id, live_config=live_config) + assert model_custom.live_config == live_config + + +# Connection Tests + + +@pytest.mark.asyncio +async def test_connection_lifecycle(mock_genai_client, model, system_prompt, tool_spec, messages): + """Test complete connection lifecycle with various configurations.""" + mock_client, mock_live_session, mock_live_session_cm = mock_genai_client + + # Test basic connection + await model.connect() + assert model._active is True + assert model.session_id is not None + assert model.live_session == mock_live_session + mock_client.aio.live.connect.assert_called_once() + + # Test close + await model.close() + assert model._active is False + mock_live_session_cm.__aexit__.assert_called_once() + + # Test connection with system prompt + await model.connect(system_prompt=system_prompt) + call_args = mock_client.aio.live.connect.call_args + config = call_args.kwargs.get("config", {}) + assert config.get("system_instruction") == system_prompt + await model.close() + + # Test connection with tools + await model.connect(tools=[tool_spec]) + call_args = mock_client.aio.live.connect.call_args + config = call_args.kwargs.get("config", {}) + assert "tools" in config + assert len(config["tools"]) > 0 + await model.close() + + # Test connection with messages + await model.connect(messages=messages) + mock_live_session.send_client_content.assert_called() + await model.close() + + +@pytest.mark.asyncio +async def test_connection_edge_cases(mock_genai_client, api_key, model_id): + """Test connection error handling and edge cases.""" + mock_client, _, mock_live_session_cm = mock_genai_client + + # Test connection error + model1 = GeminiLiveModel(model_id=model_id, api_key=api_key) + mock_client.aio.live.connect.side_effect = Exception("Connection failed") + with pytest.raises(Exception, match="Connection failed"): + await model1.connect() + + # Reset mock for next tests + mock_client.aio.live.connect.side_effect = None + + # Test double connection + model2 = GeminiLiveModel(model_id=model_id, api_key=api_key) + await model2.connect() + with pytest.raises(RuntimeError, match="Connection already active"): + await model2.connect() + await model2.close() + + # Test close when not connected + model3 = GeminiLiveModel(model_id=model_id, api_key=api_key) + await model3.close() # Should not raise + + # Test close error handling + model4 = GeminiLiveModel(model_id=model_id, api_key=api_key) + await model4.connect() + mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") + with pytest.raises(Exception, match="Close failed"): + await model4.close() + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_all_content_types(mock_genai_client, model): + """Test sending all content types through unified send() method.""" + _, mock_live_session, _ = mock_genai_client + await model.connect() + + # Test text input + text_input: TextInputEvent = {"text": "Hello", "role": "user"} + await model.send(text_input) + mock_live_session.send_client_content.assert_called_once() + call_args = mock_live_session.send_client_content.call_args + content = call_args.kwargs.get("turns") + assert content.role == "user" + assert content.parts[0].text == "Hello" + + # Test audio input + audio_input: AudioInputEvent = { + "audioData": b"audio_bytes", + "format": "pcm", + "sampleRate": 16000, + "channels": 1, + } + await model.send(audio_input) + mock_live_session.send_realtime_input.assert_called_once() + + # Test image input + image_input: ImageInputEvent = { + "imageData": b"image_bytes", + "mimeType": "image/jpeg", + "encoding": "raw", + } + await model.send(image_input) + mock_live_session.send.assert_called_once() + + # Test tool result + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Result: 42"}], + } + await model.send(tool_result) + mock_live_session.send_tool_response.assert_called_once() + + await model.close() + + +@pytest.mark.asyncio +async def test_send_edge_cases(mock_genai_client, model): + """Test send() edge cases and error handling.""" + _, mock_live_session, _ = mock_genai_client + + # Test send when inactive + text_input: TextInputEvent = {"text": "Hello", "role": "user"} + await model.send(text_input) + mock_live_session.send_client_content.assert_not_called() + + # Test unknown content type + await model.connect() + unknown_content = {"unknown_field": "value"} + await model.send(unknown_content) # Should not raise, just log warning + + await model.close() + + +# Receive Method Tests + + +@pytest.mark.asyncio +async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): + """Test that receive() emits connection start and end events.""" + _, mock_live_session, _ = mock_genai_client + mock_live_session.receive.return_value = agenerator([]) + + await model.connect() + + # Collect events + events = [] + async for event in model.receive(): + events.append(event) + # Close after first event to trigger connection end + if len(events) == 1: + await model.close() + + # Verify connection start and end + assert len(events) >= 2 + assert "BidirectionalConnectionStart" in events[0] + assert events[0]["BidirectionalConnectionStart"]["connectionId"] == model.session_id + assert "BidirectionalConnectionEnd" in events[-1] + + +@pytest.mark.asyncio +async def test_event_conversion(mock_genai_client, model): + """Test conversion of all Gemini Live event types to standard format.""" + _, _, _ = mock_genai_client + await model.connect() + + # Test text output + mock_text = unittest.mock.Mock() + mock_text.text = "Hello from Gemini" + mock_text.data = None + mock_text.tool_call = None + mock_text.server_content = None + + text_event = model._convert_gemini_live_event(mock_text) + assert "textOutput" in text_event + assert text_event["textOutput"]["text"] == "Hello from Gemini" + assert text_event["textOutput"]["role"] == "assistant" + + # Test audio output + mock_audio = unittest.mock.Mock() + mock_audio.text = None + mock_audio.data = b"audio_data" + mock_audio.tool_call = None + mock_audio.server_content = None + + audio_event = model._convert_gemini_live_event(mock_audio) + assert "audioOutput" in audio_event + assert audio_event["audioOutput"]["audioData"] == b"audio_data" + assert audio_event["audioOutput"]["format"] == "pcm" + + # Test tool call + mock_func_call = unittest.mock.Mock() + mock_func_call.id = "tool-123" + mock_func_call.name = "calculator" + mock_func_call.args = {"expression": "2+2"} + + mock_tool_call = unittest.mock.Mock() + mock_tool_call.function_calls = [mock_func_call] + + mock_tool = unittest.mock.Mock() + mock_tool.text = None + mock_tool.data = None + mock_tool.tool_call = mock_tool_call + mock_tool.server_content = None + + tool_event = model._convert_gemini_live_event(mock_tool) + assert "toolUse" in tool_event + assert tool_event["toolUse"]["toolUseId"] == "tool-123" + assert tool_event["toolUse"]["name"] == "calculator" + + # Test interruption + mock_server_content = unittest.mock.Mock() + mock_server_content.interrupted = True + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + + mock_interrupt = unittest.mock.Mock() + mock_interrupt.text = None + mock_interrupt.data = None + mock_interrupt.tool_call = None + mock_interrupt.server_content = mock_server_content + + interrupt_event = model._convert_gemini_live_event(mock_interrupt) + assert "interruptionDetected" in interrupt_event + assert interrupt_event["interruptionDetected"]["reason"] == "user_input" + + await model.close() + + +# Helper Method Tests + + +def test_config_building(model, system_prompt, tool_spec): + """Test building live config with various options.""" + # Test basic config + config_basic = model._build_live_config() + assert isinstance(config_basic, dict) + + # Test with system prompt + config_prompt = model._build_live_config(system_prompt=system_prompt) + assert config_prompt["system_instruction"] == system_prompt + + # Test with tools + config_tools = model._build_live_config(tools=[tool_spec]) + assert "tools" in config_tools + assert len(config_tools["tools"]) > 0 + + +def test_tool_formatting(model, tool_spec): + """Test tool formatting for Gemini Live API.""" + # Test with tools + formatted_tools = model._format_tools_for_live_api([tool_spec]) + assert len(formatted_tools) == 1 + assert isinstance(formatted_tools[0], genai_types.Tool) + + # Test empty list + formatted_empty = model._format_tools_for_live_api([]) + assert formatted_empty == [] diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py new file mode 100644 index 000000000..7265bfacd --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/models/test_novasonic.py @@ -0,0 +1,433 @@ +"""Unit tests for Nova Sonic bidirectional model implementation. + +Tests the unified BidirectionalModel interface implementation for Amazon Nova Sonic, +covering connection lifecycle, event conversion, audio streaming, and tool execution. +""" + +import asyncio +import base64 +import json +from unittest.mock import AsyncMock, patch + +import pytest +import pytest_asyncio + +from strands.experimental.bidirectional_streaming.models.novasonic import ( + NovaSonicModel, +) +from strands.types.tools import ToolResult + + +# Test fixtures +@pytest.fixture +def model_id(): + """Nova Sonic model identifier.""" + return "amazon.nova-sonic-v1:0" + + +@pytest.fixture +def region(): + """AWS region.""" + return "us-east-1" + + +@pytest.fixture +def mock_stream(): + """Mock Nova Sonic bidirectional stream.""" + stream = AsyncMock() + stream.input_stream = AsyncMock() + stream.input_stream.send = AsyncMock() + stream.input_stream.close = AsyncMock() + stream.await_output = AsyncMock() + return stream + + +@pytest.fixture +def mock_client(mock_stream): + """Mock Bedrock Runtime client.""" + client = AsyncMock() + client.invoke_model_with_bidirectional_stream = AsyncMock(return_value=mock_stream) + return client + + +@pytest_asyncio.fixture +async def nova_model(model_id, region): + """Create Nova Sonic model instance.""" + model = NovaSonicModel(model_id=model_id, region=region) + yield model + # Cleanup + if model._active: + await model.close() + + +# Initialization and Connection Tests + + +@pytest.mark.asyncio +async def test_model_initialization(model_id, region): + """Test model initialization with configuration.""" + model = NovaSonicModel(model_id=model_id, region=region) + + assert model.model_id == model_id + assert model.region == region + assert model.stream is None + assert not model._active + assert model.session_id is None + + +@pytest.mark.asyncio +async def test_connection_lifecycle(nova_model, mock_client, mock_stream): + """Test complete connection lifecycle with various configurations.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + # Test basic connection + await nova_model.connect(system_prompt="Test system prompt") + assert nova_model._active + assert nova_model.stream == mock_stream + assert nova_model.session_id is not None + assert mock_client.invoke_model_with_bidirectional_stream.called + + # Test close + await nova_model.close() + assert not nova_model._active + assert mock_stream.input_stream.close.called + + # Test connection with tools + tools = [ + { + "name": "get_weather", + "description": "Get weather information", + "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})} + } + ] + await nova_model.connect(system_prompt="You are helpful", tools=tools) + # Verify initialization events were sent (connectionStart, promptStart, system prompt) + assert mock_stream.input_stream.send.call_count >= 3 + await nova_model.close() + + +@pytest.mark.asyncio +async def test_connection_edge_cases(nova_model, mock_client, mock_stream, model_id, region): + """Test connection error handling and edge cases.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + # Test double connection + await nova_model.connect() + with pytest.raises(RuntimeError, match="Connection already active"): + await nova_model.connect() + await nova_model.close() + + # Test close when already closed + model2 = NovaSonicModel(model_id=model_id, region=region) + await model2.close() # Should not raise + await model2.close() # Second call should also be safe + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_all_content_types(nova_model, mock_client, mock_stream): + """Test sending all content types through unified send() method.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + await nova_model.connect() + + # Test text content + text_event = {"text": "Hello, Nova!", "role": "user"} + await nova_model.send(text_event) + # Should send contentStart, textInput, and contentEnd + assert mock_stream.input_stream.send.call_count >= 3 + + # Test audio content + audio_event = { + "audioData": b"audio data", + "format": "pcm", + "sampleRate": 16000, + "channels": 1 + } + await nova_model.send(audio_event) + # Should start audio connection and send audio + assert nova_model.audio_connection_active + assert mock_stream.input_stream.send.called + + # Test tool result + tool_result = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Weather is sunny"}] + } + await nova_model.send(tool_result) + # Should send contentStart, toolResult, and contentEnd + assert mock_stream.input_stream.send.called + + await nova_model.close() + + +@pytest.mark.asyncio +async def test_send_edge_cases(nova_model, mock_client, mock_stream, caplog): + """Test send() edge cases and error handling.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + # Test send when inactive + text_event = {"text": "Hello", "role": "user"} + await nova_model.send(text_event) # Should not raise + + # Test image content (not supported) + await nova_model.connect() + image_event = { + "imageData": b"image data", + "mimeType": "image/jpeg" + } + await nova_model.send(image_event) + # Should log warning about unsupported image input + assert any("not supported" in record.message.lower() for record in caplog.records) + + await nova_model.close() + + +# Receive and Event Conversion Tests + + +@pytest.mark.asyncio +async def test_receive_lifecycle_events(nova_model, mock_client, mock_stream): + """Test that receive() emits connection start and end events.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + # Setup mock to return no events and then stop + async def mock_wait_for(*args, **kwargs): + await asyncio.sleep(0.1) + nova_model._active = False + raise asyncio.TimeoutError() + + with patch("asyncio.wait_for", side_effect=mock_wait_for): + await nova_model.connect() + + events = [] + async for event in nova_model.receive(): + events.append(event) + + # Should have connection start and end + assert len(events) >= 2 + assert "BidirectionalConnectionStart" in events[0] + assert events[0]["BidirectionalConnectionStart"]["connectionId"] == nova_model.session_id + assert "BidirectionalConnectionEnd" in events[-1] + + +@pytest.mark.asyncio +async def test_event_conversion(nova_model): + """Test conversion of all Nova Sonic event types to standard format.""" + # Test audio output + audio_bytes = b"test audio data" + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + nova_event = {"audioOutput": {"content": audio_base64}} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert "audioOutput" in result + assert result["audioOutput"]["audioData"] == audio_bytes + assert result["audioOutput"]["format"] == "pcm" + assert result["audioOutput"]["sampleRate"] == 24000 + + # Test text output + nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert "textOutput" in result + assert result["textOutput"]["text"] == "Hello, world!" + assert result["textOutput"]["role"] == "assistant" + + # Test tool use + tool_input = {"location": "Seattle"} + nova_event = { + "toolUse": { + "toolUseId": "tool-123", + "toolName": "get_weather", + "content": json.dumps(tool_input) + } + } + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert "toolUse" in result + assert result["toolUse"]["toolUseId"] == "tool-123" + assert result["toolUse"]["name"] == "get_weather" + assert result["toolUse"]["input"] == tool_input + + # Test interruption + nova_event = {"stopReason": "INTERRUPTED"} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert "interruptionDetected" in result + assert result["interruptionDetected"]["reason"] == "user_input" + + # Test usage metrics + nova_event = { + "usageEvent": { + "totalTokens": 100, + "totalInputTokens": 40, + "totalOutputTokens": 60, + "details": { + "total": { + "output": { + "speechTokens": 30 + } + } + } + } + } + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert "usageMetrics" in result + assert result["usageMetrics"]["totalTokens"] == 100 + assert result["usageMetrics"]["inputTokens"] == 40 + assert result["usageMetrics"]["outputTokens"] == 60 + assert result["usageMetrics"]["audioTokens"] == 30 + + # Test content start tracks role + nova_event = {"contentStart": {"role": "USER"}} + result = nova_model._convert_nova_event(nova_event) + assert result is None # contentStart doesn't emit an event + assert nova_model._current_role == "USER" + + +# Audio Streaming Tests + + +@pytest.mark.asyncio +async def test_audio_connection_lifecycle(nova_model, mock_client, mock_stream): + """Test audio connection start and end lifecycle.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + await nova_model.connect() + + # Start audio connection + await nova_model._start_audio_connection() + assert nova_model.audio_connection_active + + # End audio connection + await nova_model._end_audio_input() + assert not nova_model.audio_connection_active + + await nova_model.close() + + +@pytest.mark.asyncio +async def test_silence_detection(nova_model, mock_client, mock_stream): + """Test that silence detection automatically ends audio input.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + nova_model.silence_threshold = 0.1 # Short threshold for testing + + await nova_model.connect() + + # Send audio to start connection + audio_event = { + "audioData": b"audio data", + "format": "pcm", + "sampleRate": 16000, + "channels": 1 + } + + await nova_model.send(audio_event) + assert nova_model.audio_connection_active + + # Wait for silence detection + await asyncio.sleep(0.2) + + # Audio connection should be ended + assert not nova_model.audio_connection_active + + await nova_model.close() + + +# Helper Method Tests + + +@pytest.mark.asyncio +async def test_tool_configuration(nova_model): + """Test building tool configuration from tool specs.""" + tools = [ + { + "name": "get_weather", + "description": "Get weather information", + "inputSchema": { + "json": json.dumps({ + "type": "object", + "properties": { + "location": {"type": "string"} + } + }) + } + } + ] + + tool_config = nova_model._build_tool_configuration(tools) + + assert len(tool_config) == 1 + assert tool_config[0]["toolSpec"]["name"] == "get_weather" + assert tool_config[0]["toolSpec"]["description"] == "Get weather information" + assert "inputSchema" in tool_config[0]["toolSpec"] + + +@pytest.mark.asyncio +async def test_event_templates(nova_model): + """Test event template generation.""" + # Test connection start event + event_json = nova_model._get_connection_start_event() + event = json.loads(event_json) + assert "event" in event + assert "sessionStart" in event["event"] + assert "inferenceConfiguration" in event["event"]["sessionStart"] + + # Test prompt start event + nova_model.session_id = "test-session" + event_json = nova_model._get_prompt_start_event([]) + event = json.loads(event_json) + assert "event" in event + assert "promptStart" in event["event"] + assert event["event"]["promptStart"]["promptName"] == "test-session" + + # Test text input event + content_name = "test-content" + event_json = nova_model._get_text_input_event(content_name, "Hello") + event = json.loads(event_json) + assert "event" in event + assert "textInput" in event["event"] + assert event["event"]["textInput"]["content"] == "Hello" + + # Test tool result event + result = {"result": "Success"} + event_json = nova_model._get_tool_result_event(content_name, result) + event = json.loads(event_json) + assert "event" in event + assert "toolResult" in event["event"] + assert json.loads(event["event"]["toolResult"]["content"]) == result + + +# Error Handling Tests + + +@pytest.mark.asyncio +async def test_error_handling(nova_model, mock_client, mock_stream): + """Test error handling in various scenarios.""" + with patch.object(nova_model, "_initialize_client", new_callable=AsyncMock): + nova_model.client = mock_client + + # Test response processor handles errors gracefully + async def mock_error(*args, **kwargs): + raise Exception("Test error") + + mock_stream.await_output.side_effect = mock_error + + await nova_model.connect() + + # Wait a bit for response processor to handle error + await asyncio.sleep(0.1) + + # Should still be able to close cleanly + await nova_model.close() diff --git a/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py new file mode 100644 index 000000000..1c0b949b0 --- /dev/null +++ b/tests/strands/experimental/bidirectional_streaming/models/test_openai_realtime.py @@ -0,0 +1,469 @@ +"""Unit tests for OpenAI Realtime bidirectional streaming model. + +Tests the unified OpenAIRealtimeModel interface including: +- Model initialization and configuration +- Connection establishment with WebSocket +- Unified send() method with different content types +- Event receiving and conversion +- Connection lifecycle management +""" + +import asyncio +import base64 +import json +import unittest.mock + +import pytest + +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeModel +from strands.experimental.bidirectional_streaming.types.bidirectional_streaming import ( + AudioInputEvent, + ImageInputEvent, + TextInputEvent, +) +from strands.types.tools import ToolResult + + +@pytest.fixture +def mock_websocket(): + """Mock WebSocket connection.""" + mock_ws = unittest.mock.AsyncMock() + mock_ws.send = unittest.mock.AsyncMock() + mock_ws.close = unittest.mock.AsyncMock() + return mock_ws + + +@pytest.fixture +def mock_websockets_connect(mock_websocket): + """Mock websockets.connect function.""" + async def async_connect(*args, **kwargs): + return mock_websocket + + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.websockets.connect") as mock_connect: + mock_connect.side_effect = async_connect + yield mock_connect, mock_websocket + + +@pytest.fixture +def model_name(): + return "gpt-realtime" + + +@pytest.fixture +def api_key(): + return "test-api-key" + + +@pytest.fixture +def model(api_key, model_name): + """Create an OpenAIRealtimeModel instance.""" + return OpenAIRealtimeModel(model=model_name, api_key=api_key) + + +@pytest.fixture +def tool_spec(): + return { + "description": "Calculate mathematical expressions", + "name": "calculator", + "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}}, + } + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant" + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +# Initialization Tests + + +def test_model_initialization(api_key, model_name): + """Test model initialization with various configurations.""" + # Test default config + model_default = OpenAIRealtimeModel(api_key="test-key") + assert model_default.model == "gpt-realtime" + assert model_default.api_key == "test-key" + assert model_default._active is False + assert model_default.websocket is None + + # Test with custom model + model_custom = OpenAIRealtimeModel(model=model_name, api_key=api_key) + assert model_custom.model == model_name + assert model_custom.api_key == api_key + + # Test with organization and project + model_org = OpenAIRealtimeModel( + model=model_name, + api_key=api_key, + organization="org-123", + project="proj-456" + ) + assert model_org.organization == "org-123" + assert model_org.project == "proj-456" + + # Test with env API key + with unittest.mock.patch.dict("os.environ", {"OPENAI_API_KEY": "env-key"}): + model_env = OpenAIRealtimeModel() + assert model_env.api_key == "env-key" + + +def test_init_without_api_key_raises(): + """Test that initialization without API key raises error.""" + with unittest.mock.patch.dict("os.environ", {}, clear=True): + with pytest.raises(ValueError, match="OpenAI API key is required"): + OpenAIRealtimeModel() + + +# Connection Tests + + +@pytest.mark.asyncio +async def test_connection_lifecycle(mock_websockets_connect, model, system_prompt, tool_spec, messages): + """Test complete connection lifecycle with various configurations.""" + mock_connect, mock_ws = mock_websockets_connect + + # Test basic connection + await model.connect() + assert model._active is True + assert model.session_id is not None + assert model.websocket == mock_ws + assert model._event_queue is not None + assert model._response_task is not None + mock_connect.assert_called_once() + + # Test close + await model.close() + assert model._active is False + mock_ws.close.assert_called_once() + + # Test connection with system prompt + await model.connect(system_prompt=system_prompt) + calls = mock_ws.send.call_args_list + session_update = next( + (json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"), + None + ) + assert session_update is not None + assert system_prompt in session_update["session"]["instructions"] + await model.close() + + # Test connection with tools + await model.connect(tools=[tool_spec]) + calls = mock_ws.send.call_args_list + # Tools are sent in a separate session.update after initial connection + session_updates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"] + assert len(session_updates) > 0 + # Check if any session update has tools + has_tools = any("tools" in update.get("session", {}) for update in session_updates) + assert has_tools + await model.close() + + # Test connection with messages + await model.connect(messages=messages) + calls = mock_ws.send.call_args_list + item_creates = [json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "conversation.item.create"] + assert len(item_creates) > 0 + await model.close() + + # Test connection with organization header + model_org = OpenAIRealtimeModel(api_key="test-key", organization="org-123") + await model_org.connect() + call_kwargs = mock_connect.call_args.kwargs + headers = call_kwargs.get("extra_headers", []) + org_header = [h for h in headers if h[0] == "OpenAI-Organization"] + assert len(org_header) == 1 + assert org_header[0][1] == "org-123" + await model_org.close() + + +@pytest.mark.asyncio +async def test_connection_edge_cases(mock_websockets_connect, api_key, model_name): + """Test connection error handling and edge cases.""" + mock_connect, mock_ws = mock_websockets_connect + + # Test connection error + model1 = OpenAIRealtimeModel(model=model_name, api_key=api_key) + mock_connect.side_effect = Exception("Connection failed") + with pytest.raises(Exception, match="Connection failed"): + await model1.connect() + + # Reset mock + async def async_connect(*args, **kwargs): + return mock_ws + mock_connect.side_effect = async_connect + + # Test double connection + model2 = OpenAIRealtimeModel(model=model_name, api_key=api_key) + await model2.connect() + with pytest.raises(RuntimeError, match="Connection already active"): + await model2.connect() + await model2.close() + + # Test close when not connected + model3 = OpenAIRealtimeModel(model=model_name, api_key=api_key) + await model3.close() # Should not raise + + # Test close error handling (should not raise, just log) + model4 = OpenAIRealtimeModel(model=model_name, api_key=api_key) + await model4.connect() + mock_ws.close.side_effect = Exception("Close failed") + await model4.close() # Should not raise + assert model4._active is False + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_all_content_types(mock_websockets_connect, model): + """Test sending all content types through unified send() method.""" + _, mock_ws = mock_websockets_connect + await model.connect() + + # Test text input + text_input: TextInputEvent = {"text": "Hello", "role": "user"} + await model.send(text_input) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + response_create = [m for m in messages if m.get("type") == "response.create"] + assert len(item_create) > 0 + assert len(response_create) > 0 + + # Test audio input + audio_input: AudioInputEvent = { + "audioData": b"audio_bytes", + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + } + await model.send(audio_input) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + audio_append = [m for m in messages if m.get("type") == "input_audio_buffer.append"] + assert len(audio_append) > 0 + assert "audio" in audio_append[0] + decoded = base64.b64decode(audio_append[0]["audio"]) + assert decoded == b"audio_bytes" + + # Test tool result + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Result: 42"}], + } + await model.send(tool_result) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + assert len(item_create) > 0 + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "tool-123" + + await model.close() + + +@pytest.mark.asyncio +async def test_send_edge_cases(mock_websockets_connect, model): + """Test send() edge cases and error handling.""" + _, mock_ws = mock_websockets_connect + + # Test send when inactive + text_input: TextInputEvent = {"text": "Hello", "role": "user"} + await model.send(text_input) + mock_ws.send.assert_not_called() + + # Test image input (not supported) + await model.connect() + image_input: ImageInputEvent = { + "imageData": b"image_bytes", + "mimeType": "image/jpeg", + "encoding": "raw", + } + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + await model.send(image_input) + mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API") + + # Test unknown content type + unknown_content = {"unknown_field": "value"} + with unittest.mock.patch("strands.experimental.bidirectional_streaming.models.openai.logger") as mock_logger: + await model.send(unknown_content) + assert mock_logger.warning.called + + await model.close() + + +# Receive Method Tests + + +@pytest.mark.asyncio +async def test_receive_lifecycle_events(mock_websockets_connect, model): + """Test that receive() emits connection start and end events.""" + _, _ = mock_websockets_connect + + await model.connect() + + # Get first event + receive_gen = model.receive() + first_event = await anext(receive_gen) + + # First event should be connection start + assert "BidirectionalConnectionStart" in first_event + assert first_event["BidirectionalConnectionStart"]["connectionId"] == model.session_id + + # Close to trigger connection end + await model.close() + + # Collect remaining events + events = [first_event] + try: + async for event in receive_gen: + events.append(event) + except StopAsyncIteration: + pass + + # Last event should be connection end + assert "BidirectionalConnectionEnd" in events[-1] + + +@pytest.mark.asyncio +async def test_event_conversion(mock_websockets_connect, model): + """Test conversion of all OpenAI event types to standard format.""" + _, _ = mock_websockets_connect + await model.connect() + + # Test audio output + audio_event = { + "type": "response.output_audio.delta", + "delta": base64.b64encode(b"audio_data").decode() + } + converted = model._convert_openai_event(audio_event) + assert "audioOutput" in converted + assert converted["audioOutput"]["audioData"] == b"audio_data" + assert converted["audioOutput"]["format"] == "pcm" + + # Test text output + text_event = { + "type": "response.output_text.delta", + "delta": "Hello from OpenAI" + } + converted = model._convert_openai_event(text_event) + assert "textOutput" in converted + assert converted["textOutput"]["text"] == "Hello from OpenAI" + assert converted["textOutput"]["role"] == "assistant" + + # Test function call sequence + item_added = { + "type": "response.output_item.added", + "item": { + "type": "function_call", + "call_id": "call-123", + "name": "calculator" + } + } + model._convert_openai_event(item_added) + + args_delta = { + "type": "response.function_call_arguments.delta", + "call_id": "call-123", + "delta": '{"expression": "2+2"}' + } + model._convert_openai_event(args_delta) + + args_done = { + "type": "response.function_call_arguments.done", + "call_id": "call-123" + } + converted = model._convert_openai_event(args_done) + assert "toolUse" in converted + assert converted["toolUse"]["toolUseId"] == "call-123" + assert converted["toolUse"]["name"] == "calculator" + assert converted["toolUse"]["input"]["expression"] == "2+2" + + # Test voice activity + speech_started = { + "type": "input_audio_buffer.speech_started" + } + converted = model._convert_openai_event(speech_started) + assert "voiceActivity" in converted + assert converted["voiceActivity"]["activityType"] == "speech_started" + + await model.close() + + +# Helper Method Tests + + +def test_config_building(model, system_prompt, tool_spec): + """Test building session config with various options.""" + # Test basic config + config_basic = model._build_session_config(None, None) + assert isinstance(config_basic, dict) + assert "instructions" in config_basic + assert "audio" in config_basic + + # Test with system prompt + config_prompt = model._build_session_config(system_prompt, None) + assert config_prompt["instructions"] == system_prompt + + # Test with tools + config_tools = model._build_session_config(None, [tool_spec]) + assert "tools" in config_tools + assert len(config_tools["tools"]) > 0 + + +def test_tool_conversion(model, tool_spec): + """Test tool conversion to OpenAI format.""" + # Test with tools + openai_tools = model._convert_tools_to_openai_format([tool_spec]) + assert len(openai_tools) == 1 + assert openai_tools[0]["type"] == "function" + assert openai_tools[0]["name"] == "calculator" + assert openai_tools[0]["description"] == "Calculate mathematical expressions" + + # Test empty list + openai_empty = model._convert_tools_to_openai_format([]) + assert openai_empty == [] + + +def test_helper_methods(model): + """Test various helper methods.""" + # Test _require_active + assert model._require_active() is False + model._active = True + assert model._require_active() is True + model._active = False + + # Test _create_text_event + text_event = model._create_text_event("Hello", "user") + assert "textOutput" in text_event + assert text_event["textOutput"]["text"] == "Hello" + assert text_event["textOutput"]["role"] == "user" + + # Test _create_voice_activity_event + voice_event = model._create_voice_activity_event("speech_started") + assert "voiceActivity" in voice_event + assert voice_event["voiceActivity"]["activityType"] == "speech_started" + + +@pytest.mark.asyncio +async def test_send_event_helper(mock_websockets_connect, model): + """Test _send_event helper method.""" + _, mock_ws = mock_websockets_connect + await model.connect() + + test_event = {"type": "test.event", "data": "test"} + await model._send_event(test_event) + + calls = mock_ws.send.call_args_list + last_call = calls[-1] + sent_message = json.loads(last_call[0][0]) + assert sent_message == test_event + + await model.close()