diff --git a/pyproject.toml b/pyproject.toml index 0d21a881..ddf90ee0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "elevenlabs" [tool.poetry] name = "elevenlabs" -version = "v2.10.0" +version = "v2.11.0" description = "" readme = "README.md" authors = [] diff --git a/src/elevenlabs/conversational_ai/conversation.py b/src/elevenlabs/conversational_ai/conversation.py index 1e0c0156..62a50ab7 100644 --- a/src/elevenlabs/conversational_ai/conversation.py +++ b/src/elevenlabs/conversational_ai/conversation.py @@ -8,6 +8,7 @@ from enum import Enum from websockets.sync.client import connect, Connection +import websockets from websockets.exceptions import ConnectionClosedOK from ..base_client import BaseElevenLabs @@ -105,6 +106,50 @@ def interrupt(self): pass +class AsyncAudioInterface(ABC): + """AsyncAudioInterface provides an async abstraction for handling audio input and output.""" + + @abstractmethod + async def start(self, input_callback: Callable[[bytes], Awaitable[None]]): + """Starts the audio interface. + + Called one time before the conversation starts. + The `input_callback` should be called regularly with input audio chunks from + the user. The audio should be in 16-bit PCM mono format at 16kHz. Recommended + chunk size is 4000 samples (250 milliseconds). + """ + pass + + @abstractmethod + async def stop(self): + """Stops the audio interface. + + Called one time after the conversation ends. Should clean up any resources + used by the audio interface and stop any audio streams. Do not call the + `input_callback` from `start` after this method is called. + """ + pass + + @abstractmethod + async def output(self, audio: bytes): + """Output audio to the user. + + The `audio` input is in 16-bit PCM mono format at 16kHz. Implementations can + choose to do additional buffering. This method should return quickly and not + block the calling thread. 
+ """ + pass + + @abstractmethod + async def interrupt(self): + """Interruption signal to stop any audio output. + + User has interrupted the agent and all previously buffered audio output should + be stopped. + """ + pass + + class ClientTools: """Handles registration and execution of client-side tools that can be called by the agent. @@ -231,13 +276,167 @@ def __init__( self.user_id = user_id -class Conversation: - client: BaseElevenLabs - agent_id: str - requires_auth: bool - config: ConversationInitiationData +class BaseConversation: + """Base class for conversation implementations with shared parameters and logic.""" + + def __init__( + self, + client: BaseElevenLabs, + agent_id: str, + user_id: Optional[str] = None, + *, + requires_auth: bool, + config: Optional[ConversationInitiationData] = None, + client_tools: Optional[ClientTools] = None, + ): + self.client = client + self.agent_id = agent_id + self.user_id = user_id + self.requires_auth = requires_auth + self.config = config or ConversationInitiationData() + self.client_tools = client_tools or ClientTools() + + self.client_tools.start() + + self._conversation_id = None + self._last_interrupt_id = 0 + + def _get_wss_url(self): + base_ws_url = self.client._client_wrapper.get_environment().wss + return f"{base_ws_url}/v1/convai/conversation?agent_id={self.agent_id}&source=python_sdk&version={__version__}" + + def _get_signed_url(self): + response = self.client.conversational_ai.conversations.get_signed_url(agent_id=self.agent_id) + signed_url = response.signed_url + # Append source and version query parameters to the signed URL + separator = "&" if "?" in signed_url else "?" 
+ return f"{signed_url}{separator}source=python_sdk&version={__version__}" + + def _create_initiation_message(self): + return json.dumps( + { + "type": "conversation_initiation_client_data", + "custom_llm_extra_body": self.config.extra_body, + "conversation_config_override": self.config.conversation_config_override, + "dynamic_variables": self.config.dynamic_variables, + "source_info": { + "source": "python_sdk", + "version": __version__, + }, + **({"user_id": self.config.user_id} if self.config.user_id else {}), + } + ) + + def _handle_message_core(self, message, message_handler): + """Core message handling logic shared between sync and async implementations. + + Args: + message: The parsed message dictionary + message_handler: Handler object with methods for different operations + """ + if message["type"] == "conversation_initiation_metadata": + event = message["conversation_initiation_metadata_event"] + assert self._conversation_id is None + self._conversation_id = event["conversation_id"] + + elif message["type"] == "audio": + event = message["audio_event"] + if int(event["event_id"]) <= self._last_interrupt_id: + return + audio = base64.b64decode(event["audio_base_64"]) + message_handler.handle_audio_output(audio) + + elif message["type"] == "agent_response": + if message_handler.callback_agent_response: + event = message["agent_response_event"] + message_handler.handle_agent_response(event["agent_response"].strip()) + + elif message["type"] == "agent_response_correction": + if message_handler.callback_agent_response_correction: + event = message["agent_response_correction_event"] + message_handler.handle_agent_response_correction( + event["original_agent_response"].strip(), + event["corrected_agent_response"].strip() + ) + + elif message["type"] == "user_transcript": + if message_handler.callback_user_transcript: + event = message["user_transcription_event"] + message_handler.handle_user_transcript(event["user_transcript"].strip()) + + elif message["type"] == 
"interruption": + event = message["interruption_event"] + self._last_interrupt_id = int(event["event_id"]) + message_handler.handle_interruption() + + elif message["type"] == "ping": + event = message["ping_event"] + message_handler.handle_ping(event) + if message_handler.callback_latency_measurement and event["ping_ms"]: + message_handler.handle_latency_measurement(int(event["ping_ms"])) + + elif message["type"] == "client_tool_call": + tool_call = message.get("client_tool_call", {}) + tool_name = tool_call.get("tool_name") + parameters = {"tool_call_id": tool_call["tool_call_id"], **tool_call.get("parameters", {})} + message_handler.handle_client_tool_call(tool_name, parameters) + else: + pass # Ignore all other message types. + + async def _handle_message_core_async(self, message, message_handler): + """Async wrapper for core message handling logic.""" + if message["type"] == "conversation_initiation_metadata": + event = message["conversation_initiation_metadata_event"] + assert self._conversation_id is None + self._conversation_id = event["conversation_id"] + + elif message["type"] == "audio": + event = message["audio_event"] + if int(event["event_id"]) <= self._last_interrupt_id: + return + audio = base64.b64decode(event["audio_base_64"]) + await message_handler.handle_audio_output(audio) + + elif message["type"] == "agent_response": + if message_handler.callback_agent_response: + event = message["agent_response_event"] + await message_handler.handle_agent_response(event["agent_response"].strip()) + + elif message["type"] == "agent_response_correction": + if message_handler.callback_agent_response_correction: + event = message["agent_response_correction_event"] + await message_handler.handle_agent_response_correction( + event["original_agent_response"].strip(), + event["corrected_agent_response"].strip() + ) + + elif message["type"] == "user_transcript": + if message_handler.callback_user_transcript: + event = message["user_transcription_event"] + await 
message_handler.handle_user_transcript(event["user_transcript"].strip()) + + elif message["type"] == "interruption": + event = message["interruption_event"] + self._last_interrupt_id = int(event["event_id"]) + await message_handler.handle_interruption() + + elif message["type"] == "ping": + event = message["ping_event"] + await message_handler.handle_ping(event) + if message_handler.callback_latency_measurement and event["ping_ms"]: + await message_handler.handle_latency_measurement(int(event["ping_ms"])) + + elif message["type"] == "client_tool_call": + tool_call = message.get("client_tool_call", {}) + tool_name = tool_call.get("tool_name") + parameters = {"tool_call_id": tool_call["tool_call_id"], **tool_call.get("parameters", {})} + message_handler.handle_client_tool_call(tool_name, parameters) + else: + pass # Ignore all other message types. + + +class Conversation(BaseConversation): audio_interface: AudioInterface - client_tools: Optional[ClientTools] callback_agent_response: Optional[Callable[[str], None]] callback_agent_response_correction: Optional[Callable[[str, str], None]] callback_user_transcript: Optional[Callable[[str], None]] @@ -246,8 +445,6 @@ class Conversation: _thread: Optional[threading.Thread] _should_stop: threading.Event - _conversation_id: Optional[str] - _last_interrupt_id: int _ws: Optional[Connection] def __init__( @@ -285,26 +482,25 @@ def __init__( callback_latency_measurement: Callback for latency measurements (in milliseconds). 
""" - self.client = client - self.agent_id = agent_id - self.user_id = user_id - self.requires_auth = requires_auth + super().__init__( + client=client, + agent_id=agent_id, + user_id=user_id, + requires_auth=requires_auth, + config=config, + client_tools=client_tools, + ) + self.audio_interface = audio_interface self.callback_agent_response = callback_agent_response - self.config = config or ConversationInitiationData() - self.client_tools = client_tools or ClientTools() self.callback_agent_response_correction = callback_agent_response_correction self.callback_user_transcript = callback_user_transcript self.callback_latency_measurement = callback_latency_measurement self.callback_end_session = callback_end_session - self.client_tools.start() - self._thread = None self._ws: Optional[Connection] = None self._should_stop = threading.Event() - self._conversation_id = None - self._last_interrupt_id = 0 def start_session(self): """Starts the conversation session. @@ -399,21 +595,7 @@ def send_contextual_update(self, text: str): def _run(self, ws_url: str): with connect(ws_url, max_size=16 * 1024 * 1024) as ws: self._ws = ws - ws.send( - json.dumps( - { - "type": "conversation_initiation_client_data", - "custom_llm_extra_body": self.config.extra_body, - "conversation_config_override": self.config.conversation_config_override, - "dynamic_variables": self.config.dynamic_variables, - "source_info": { - "source": "python_sdk", - "version": __version__, - }, - **({"user_id": self.config.user_id} if self.config.user_id else {}), - } - ) - ) + ws.send(self._create_initiation_message()) self._ws = ws def input_callback(audio): @@ -449,67 +631,299 @@ def input_callback(audio): self._ws = None def _handle_message(self, message, ws): - if message["type"] == "conversation_initiation_metadata": - event = message["conversation_initiation_metadata_event"] - assert self._conversation_id is None - self._conversation_id = event["conversation_id"] - - elif message["type"] == "audio": - 
event = message["audio_event"] - if int(event["event_id"]) <= self._last_interrupt_id: - return - audio = base64.b64decode(event["audio_base_64"]) - self.audio_interface.output(audio) - elif message["type"] == "agent_response": - if self.callback_agent_response: - event = message["agent_response_event"] - self.callback_agent_response(event["agent_response"].strip()) - elif message["type"] == "agent_response_correction": - if self.callback_agent_response_correction: - event = message["agent_response_correction_event"] - self.callback_agent_response_correction( - event["original_agent_response"].strip(), event["corrected_agent_response"].strip() - ) - elif message["type"] == "user_transcript": - if self.callback_user_transcript: - event = message["user_transcription_event"] - self.callback_user_transcript(event["user_transcript"].strip()) - elif message["type"] == "interruption": - event = message["interruption_event"] - self._last_interrupt_id = int(event["event_id"]) - self.audio_interface.interrupt() - elif message["type"] == "ping": - event = message["ping_event"] - ws.send( - json.dumps( - { - "type": "pong", - "event_id": event["event_id"], - } + class SyncMessageHandler: + def __init__(self, conversation, ws): + self.conversation = conversation + self.ws = ws + self.callback_agent_response = conversation.callback_agent_response + self.callback_agent_response_correction = conversation.callback_agent_response_correction + self.callback_user_transcript = conversation.callback_user_transcript + self.callback_latency_measurement = conversation.callback_latency_measurement + + def handle_audio_output(self, audio): + self.conversation.audio_interface.output(audio) + + def handle_agent_response(self, response): + self.conversation.callback_agent_response(response) + + def handle_agent_response_correction(self, original, corrected): + self.conversation.callback_agent_response_correction(original, corrected) + + def handle_user_transcript(self, transcript): + 
self.conversation.callback_user_transcript(transcript) + + def handle_interruption(self): + self.conversation.audio_interface.interrupt() + + def handle_ping(self, event): + self.ws.send( + json.dumps( + { + "type": "pong", + "event_id": event["event_id"], + } + ) ) - ) - if self.callback_latency_measurement and event["ping_ms"]: - self.callback_latency_measurement(int(event["ping_ms"])) - elif message["type"] == "client_tool_call": - tool_call = message.get("client_tool_call", {}) - tool_name = tool_call.get("tool_name") - parameters = {"tool_call_id": tool_call["tool_call_id"], **tool_call.get("parameters", {})} + + def handle_latency_measurement(self, latency): + self.conversation.callback_latency_measurement(latency) + + def handle_client_tool_call(self, tool_name, parameters): + def send_response(response): + if not self.conversation._should_stop.is_set(): + self.ws.send(json.dumps(response)) + + self.conversation.client_tools.execute_tool(tool_name, parameters, send_response) + + handler = SyncMessageHandler(self, ws) + self._handle_message_core(message, handler) + + +class AsyncConversation(BaseConversation): + audio_interface: AsyncAudioInterface + callback_agent_response: Optional[Callable[[str], Awaitable[None]]] + callback_agent_response_correction: Optional[Callable[[str, str], Awaitable[None]]] + callback_user_transcript: Optional[Callable[[str], Awaitable[None]]] + callback_latency_measurement: Optional[Callable[[int], Awaitable[None]]] + callback_end_session: Optional[Callable[[], Awaitable[None]]] + + _task: Optional[asyncio.Task] + _should_stop: asyncio.Event + _ws: Optional[websockets.WebSocketClientProtocol] - def send_response(response): - if not self._should_stop.is_set(): - ws.send(json.dumps(response)) + def __init__( + self, + client: BaseElevenLabs, + agent_id: str, + user_id: Optional[str] = None, + *, + requires_auth: bool, + audio_interface: AsyncAudioInterface, + config: Optional[ConversationInitiationData] = None, + client_tools: 
Optional[ClientTools] = None, + callback_agent_response: Optional[Callable[[str], Awaitable[None]]] = None, + callback_agent_response_correction: Optional[Callable[[str, str], Awaitable[None]]] = None, + callback_user_transcript: Optional[Callable[[str], Awaitable[None]]] = None, + callback_latency_measurement: Optional[Callable[[int], Awaitable[None]]] = None, + callback_end_session: Optional[Callable[[], Awaitable[None]]] = None, + ): + """Async Conversational AI session. - self.client_tools.execute_tool(tool_name, parameters, send_response) - else: - pass # Ignore all other message types. + BETA: This API is subject to change without regard to backwards compatibility. - def _get_wss_url(self): - base_ws_url = self.client._client_wrapper.get_environment().wss - return f"{base_ws_url}/v1/convai/conversation?agent_id={self.agent_id}&source=python_sdk&version={__version__}" + Args: + client: The ElevenLabs client to use for the conversation. + agent_id: The ID of the agent to converse with. + user_id: The ID of the user conversing with the agent. + requires_auth: Whether the agent requires authentication. + audio_interface: The async audio interface to use for input and output. + client_tools: The client tools to use for the conversation. + callback_agent_response: Async callback for agent responses. + callback_agent_response_correction: Async callback for agent response corrections. + First argument is the original response (previously given to + callback_agent_response), second argument is the corrected response. + callback_user_transcript: Async callback for user transcripts. + callback_latency_measurement: Async callback for latency measurements (in milliseconds). + callback_end_session: Async callback for when session ends. 
+ """ - def _get_signed_url(self): - response = self.client.conversational_ai.conversations.get_signed_url(agent_id=self.agent_id) - signed_url = response.signed_url - # Append source and version query parameters to the signed URL - separator = "&" if "?" in signed_url else "?" - return f"{signed_url}{separator}source=python_sdk&version={__version__}" + super().__init__( + client=client, + agent_id=agent_id, + user_id=user_id, + requires_auth=requires_auth, + config=config, + client_tools=client_tools, + ) + + self.audio_interface = audio_interface + self.callback_agent_response = callback_agent_response + self.callback_agent_response_correction = callback_agent_response_correction + self.callback_user_transcript = callback_user_transcript + self.callback_latency_measurement = callback_latency_measurement + self.callback_end_session = callback_end_session + + self._task = None + self._ws = None + self._should_stop = asyncio.Event() + + async def start_session(self): + """Starts the conversation session. + + Will run in background task until `end_session` is called. + """ + ws_url = self._get_signed_url() if self.requires_auth else self._get_wss_url() + self._task = asyncio.create_task(self._run(ws_url)) + + async def end_session(self): + """Ends the conversation session and cleans up resources.""" + await self.audio_interface.stop() + self.client_tools.stop() + self._ws = None + self._should_stop.set() + + if self.callback_end_session: + await self.callback_end_session() + + async def wait_for_session_end(self) -> Optional[str]: + """Waits for the conversation session to end. + + You must call `end_session` before calling this method, otherwise it will block. + + Returns the conversation ID, if available. + """ + if not self._task: + raise RuntimeError("Session not started.") + await self._task + return self._conversation_id + + async def send_user_message(self, text: str): + """Send a text message from the user to the agent. 
+ + Args: + text: The text message to send to the agent. + + Raises: + RuntimeError: If the session is not active or websocket is not connected. + """ + if not self._ws: + raise RuntimeError("Session not started or websocket not connected.") + + event = UserMessageClientToOrchestratorEvent(text=text) + try: + await self._ws.send(json.dumps(event.to_dict())) + except Exception as e: + print(f"Error sending user message: {e}") + raise + + async def register_user_activity(self): + """Register user activity to prevent session timeout. + + This sends a ping to the orchestrator to reset the timeout timer. + + Raises: + RuntimeError: If the session is not active or websocket is not connected. + """ + if not self._ws: + raise RuntimeError("Session not started or websocket not connected.") + + event = UserActivityClientToOrchestratorEvent() + try: + await self._ws.send(json.dumps(event.to_dict())) + except Exception as e: + print(f"Error registering user activity: {e}") + raise + + async def send_contextual_update(self, text: str): + """Send a contextual update to the conversation. + + Contextual updates are non-interrupting content that is sent to the server + to update the conversation state without directly prompting the agent. + + Args: + text: The contextual information to send to the conversation. + + Raises: + RuntimeError: If the session is not active or websocket is not connected. 
+ """ + if not self._ws: + raise RuntimeError("Session not started or websocket not connected.") + + event = ContextualUpdateClientToOrchestratorEvent(text=text) + try: + await self._ws.send(json.dumps(event.to_dict())) + except Exception as e: + print(f"Error sending contextual update: {e}") + raise + + async def _run(self, ws_url: str): + async with websockets.connect(ws_url, max_size=16 * 1024 * 1024) as ws: + self._ws = ws + await ws.send(self._create_initiation_message()) + + async def input_callback(audio): + try: + await ws.send( + json.dumps( + { + "user_audio_chunk": base64.b64encode(audio).decode(), + } + ) + ) + except ConnectionClosedOK: + await self.end_session() + except Exception as e: + print(f"Error sending user audio chunk: {e}") + await self.end_session() + + await self.audio_interface.start(input_callback) + + try: + while not self._should_stop.is_set(): + try: + message_str = await asyncio.wait_for(ws.recv(), timeout=0.5) + if self._should_stop.is_set(): + return + message = json.loads(message_str) + await self._handle_message(message, ws) + except asyncio.TimeoutError: + pass + except ConnectionClosedOK: + await self.end_session() + break + except Exception as e: + print(f"Error receiving message: {e}") + await self.end_session() + break + finally: + self._ws = None + + async def _handle_message(self, message, ws): + class AsyncMessageHandler: + def __init__(self, conversation, ws): + self.conversation = conversation + self.ws = ws + self.callback_agent_response = conversation.callback_agent_response + self.callback_agent_response_correction = conversation.callback_agent_response_correction + self.callback_user_transcript = conversation.callback_user_transcript + self.callback_latency_measurement = conversation.callback_latency_measurement + + async def handle_audio_output(self, audio): + await self.conversation.audio_interface.output(audio) + + async def handle_agent_response(self, response): + await 
self.conversation.callback_agent_response(response) + + async def handle_agent_response_correction(self, original, corrected): + await self.conversation.callback_agent_response_correction(original, corrected) + + async def handle_user_transcript(self, transcript): + await self.conversation.callback_user_transcript(transcript) + + async def handle_interruption(self): + await self.conversation.audio_interface.interrupt() + + async def handle_ping(self, event): + await self.ws.send( + json.dumps( + { + "type": "pong", + "event_id": event["event_id"], + } + ) + ) + + async def handle_latency_measurement(self, latency): + await self.conversation.callback_latency_measurement(latency) + + def handle_client_tool_call(self, tool_name, parameters): + def send_response(response): + if not self.conversation._should_stop.is_set(): + asyncio.create_task(self.ws.send(json.dumps(response))) + + self.conversation.client_tools.execute_tool(tool_name, parameters, send_response) + + handler = AsyncMessageHandler(self, ws) + + # Use the shared core message handling logic with async wrapper + await self._handle_message_core_async(message, handler) diff --git a/src/elevenlabs/conversational_ai/default_audio_interface.py b/src/elevenlabs/conversational_ai/default_audio_interface.py index b1660d85..e3bd9ad8 100644 --- a/src/elevenlabs/conversational_ai/default_audio_interface.py +++ b/src/elevenlabs/conversational_ai/default_audio_interface.py @@ -1,8 +1,9 @@ -from typing import Callable +from typing import Callable, Awaitable import queue import threading +import asyncio -from .conversation import AudioInterface +from .conversation import AudioInterface, AsyncAudioInterface class DefaultAudioInterface(AudioInterface): @@ -81,3 +82,92 @@ def _in_callback(self, in_data, frame_count, time_info, status): if self.input_callback: self.input_callback(in_data) return (None, self.pyaudio.paContinue) + + +class AsyncDefaultAudioInterface(AsyncAudioInterface): + INPUT_FRAMES_PER_BUFFER = 4000 # 
250ms @ 16kHz + OUTPUT_FRAMES_PER_BUFFER = 1000 # 62.5ms @ 16kHz + + def __init__(self): + try: + import pyaudio + except ImportError: + raise ImportError("To use AsyncDefaultAudioInterface you must install pyaudio.") + self.pyaudio = pyaudio + + async def start(self, input_callback: Callable[[bytes], Awaitable[None]]): + # Audio input is using callbacks from pyaudio which we adapt to async + self.input_callback = input_callback + + # Audio output is buffered so we can handle interruptions. + # Start a separate task to handle writing to the output stream. + self.output_queue: asyncio.Queue[bytes] = asyncio.Queue() + self.should_stop = asyncio.Event() + + self.p = self.pyaudio.PyAudio() + self.in_stream = self.p.open( + format=self.pyaudio.paInt16, + channels=1, + rate=16000, + input=True, + stream_callback=self._in_callback, + frames_per_buffer=self.INPUT_FRAMES_PER_BUFFER, + start=True, + ) + self.out_stream = self.p.open( + format=self.pyaudio.paInt16, + channels=1, + rate=16000, + output=True, + frames_per_buffer=self.OUTPUT_FRAMES_PER_BUFFER, + start=True, + ) + + # Start the output task + self.output_task = asyncio.create_task(self._output_task()) + + async def stop(self): + self.should_stop.set() + await self.output_task + self.in_stream.stop_stream() + self.in_stream.close() + self.out_stream.close() + self.p.terminate() + + async def output(self, audio: bytes): + await self.output_queue.put(audio) + + async def interrupt(self): + # Clear the output queue to stop any audio that is currently playing. 
+ try: + while True: + try: + _ = self.output_queue.get_nowait() + except asyncio.QueueEmpty: + break + except AttributeError: + # Defensive fallback (presumably for when asyncio.QueueEmpty cannot be resolved): drain by polling empty() + while not self.output_queue.empty(): + try: + _ = self.output_queue.get_nowait() + except: + break + + async def _output_task(self): + while not self.should_stop.is_set(): + try: + audio = await asyncio.wait_for(self.output_queue.get(), timeout=0.25) + self.out_stream.write(audio) + except asyncio.TimeoutError: + pass + + def _in_callback(self, in_data, frame_count, time_info, status): + if self.input_callback: + # Schedule the async callback to run in the event loop + try: + loop = asyncio.get_event_loop() + asyncio.run_coroutine_threadsafe(self.input_callback(in_data), loop) + except RuntimeError: + # No event loop running, ignore + pass + return (None, self.pyaudio.paContinue) diff --git a/tests/test_async_convai.py b/tests/test_async_convai.py new file mode 100644 index 00000000..72de00d2 --- /dev/null +++ b/tests/test_async_convai.py @@ -0,0 +1,366 @@ +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch +import pytest + +from elevenlabs.conversational_ai.conversation import ( + AsyncConversation, + AsyncAudioInterface, + ConversationInitiationData, +) + + +class MockAsyncAudioInterface(AsyncAudioInterface): + async def start(self, input_callback): + print("Async audio interface started") + self.input_callback = input_callback + + async def stop(self): + print("Async audio interface stopped") + + async def output(self, audio): + print(f"Would play audio of length: {len(audio)} bytes") + + async def interrupt(self): + print("Async audio interrupted") + + +# Add test constants and helpers at module level +TEST_CONVERSATION_ID = "test123" +TEST_AGENT_ID = "test_agent" + + +def create_mock_async_websocket(messages=None): + """Helper to create a mock async websocket with predefined responses""" + mock_ws = AsyncMock() + + if messages 
is None: + messages = [ + { + "type": "conversation_initiation_metadata", + "conversation_initiation_metadata_event": {"conversation_id": TEST_CONVERSATION_ID}, + }, + {"type": "agent_response", "agent_response_event": {"agent_response": "Hello there!"}}, + ] + + # Convert messages to JSON strings + json_messages = [json.dumps(msg) for msg in messages] + json_messages.extend(['{"type": "keep_alive"}'] * 10) # Add some keep-alive messages + + # Create an iterator + message_iter = iter(json_messages) + + async def mock_recv(): + try: + return next(message_iter) + except StopIteration: + # Simulate connection close after messages + raise asyncio.TimeoutError() + + mock_ws.recv = mock_recv + return mock_ws + + +@pytest.mark.asyncio +async def test_async_conversation_basic_flow(): + # Mock setup + mock_ws = create_mock_async_websocket() + mock_client = MagicMock() + agent_response_callback = AsyncMock() + test_user_id = "test_user_123" + + # Setup the conversation + config = ConversationInitiationData(user_id=test_user_id) + conversation = AsyncConversation( + client=mock_client, + agent_id=TEST_AGENT_ID, + config=config, + requires_auth=False, + audio_interface=MockAsyncAudioInterface(), + callback_agent_response=agent_response_callback, + ) + + # Run the test + with patch("elevenlabs.conversational_ai.conversation.websockets.connect") as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_ws + + await conversation.start_session() + + # Wait a bit for the callback to be called + await asyncio.sleep(0.1) + + await conversation.end_session() + await conversation.wait_for_session_end() + + # Assertions - check the call was made with the right structure + send_calls = [call[0][0] for call in mock_ws.send.call_args_list] + init_messages = [json.loads(call) for call in send_calls if 'conversation_initiation_client_data' in call] + assert len(init_messages) == 1 + init_message = init_messages[0] + + assert init_message["type"] == 
"conversation_initiation_client_data" + assert init_message["custom_llm_extra_body"] == {} + assert init_message["conversation_config_override"] == {} + assert init_message["dynamic_variables"] == {} + assert init_message["source_info"]["source"] == "python_sdk" + assert "version" in init_message["source_info"] + assert init_message["user_id"] == test_user_id + agent_response_callback.assert_called_once_with("Hello there!") + assert conversation._conversation_id == TEST_CONVERSATION_ID + assert conversation.config.user_id == test_user_id + + +@pytest.mark.asyncio +async def test_async_conversation_with_auth(): + # Mock setup + mock_client = MagicMock() + mock_client.conversational_ai.conversations.get_signed_url.return_value.signed_url = "wss://signed.url" + mock_ws = create_mock_async_websocket( + [ + { + "type": "conversation_initiation_metadata", + "conversation_initiation_metadata_event": {"conversation_id": TEST_CONVERSATION_ID}, + } + ] + ) + + conversation = AsyncConversation( + client=mock_client, + agent_id=TEST_AGENT_ID, + requires_auth=True, + audio_interface=MockAsyncAudioInterface(), + ) + + # Run the test + with patch("elevenlabs.conversational_ai.conversation.websockets.connect") as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_ws + + await conversation.start_session() + await conversation.end_session() + await conversation.wait_for_session_end() + + # Assertions + mock_client.conversational_ai.conversations.get_signed_url.assert_called_once_with(agent_id=TEST_AGENT_ID) + + +@pytest.mark.asyncio +async def test_async_conversation_with_dynamic_variables(): + # Mock setup + mock_ws = create_mock_async_websocket() + mock_client = MagicMock() + agent_response_callback = AsyncMock() + + dynamic_variables = {"name": "angelo"} + config = ConversationInitiationData(dynamic_variables=dynamic_variables) + + # Setup the conversation + conversation = AsyncConversation( + client=mock_client, + config=config, + agent_id=TEST_AGENT_ID, + 
requires_auth=False, + audio_interface=MockAsyncAudioInterface(), + callback_agent_response=agent_response_callback, + ) + + # Run the test + with patch("elevenlabs.conversational_ai.conversation.websockets.connect") as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_ws + + await conversation.start_session() + + # Wait a bit for the callback to be called + await asyncio.sleep(0.1) + + await conversation.end_session() + await conversation.wait_for_session_end() + + # Assertions - check the call was made with the right structure + send_calls = [call[0][0] for call in mock_ws.send.call_args_list] + init_messages = [json.loads(call) for call in send_calls if 'conversation_initiation_client_data' in call] + assert len(init_messages) == 1 + init_message = init_messages[0] + + assert init_message["type"] == "conversation_initiation_client_data" + assert init_message["custom_llm_extra_body"] == {} + assert init_message["conversation_config_override"] == {} + assert init_message["dynamic_variables"] == {"name": "angelo"} + assert init_message["source_info"]["source"] == "python_sdk" + assert "version" in init_message["source_info"] + agent_response_callback.assert_called_once_with("Hello there!") + assert conversation._conversation_id == TEST_CONVERSATION_ID + + +@pytest.mark.asyncio +async def test_async_conversation_with_contextual_update(): + # Mock setup + mock_ws = create_mock_async_websocket([]) + mock_client = MagicMock() + + # Setup the conversation + conversation = AsyncConversation( + client=mock_client, + agent_id=TEST_AGENT_ID, + requires_auth=False, + audio_interface=MockAsyncAudioInterface(), + ) + + # Run the test + with patch("elevenlabs.conversational_ai.conversation.websockets.connect") as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_ws + + await conversation.start_session() + await asyncio.sleep(0.1) + + await conversation.send_contextual_update("User appears to be looking at pricing page") + + # 
Teardown + await conversation.end_session() + await conversation.wait_for_session_end() + + # Assertions + expected = json.dumps({"type": "contextual_update", "text": "User appears to be looking at pricing page"}) + mock_ws.send.assert_any_call(expected) + + +@pytest.mark.asyncio +async def test_async_conversation_send_user_message(): + # Mock setup + mock_ws = create_mock_async_websocket([]) + mock_client = MagicMock() + + # Setup the conversation + conversation = AsyncConversation( + client=mock_client, + agent_id=TEST_AGENT_ID, + requires_auth=False, + audio_interface=MockAsyncAudioInterface(), + ) + + # Run the test + with patch("elevenlabs.conversational_ai.conversation.websockets.connect") as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_ws + + await conversation.start_session() + await asyncio.sleep(0.1) + + await conversation.send_user_message("Hello, how are you?") + + # Teardown + await conversation.end_session() + await conversation.wait_for_session_end() + + # Assertions + expected = json.dumps({"type": "user_message", "text": "Hello, how are you?"}) + mock_ws.send.assert_any_call(expected) + + +@pytest.mark.asyncio +async def test_async_conversation_register_user_activity(): + # Mock setup + mock_ws = create_mock_async_websocket([]) + mock_client = MagicMock() + + # Setup the conversation + conversation = AsyncConversation( + client=mock_client, + agent_id=TEST_AGENT_ID, + requires_auth=False, + audio_interface=MockAsyncAudioInterface(), + ) + + # Run the test + with patch("elevenlabs.conversational_ai.conversation.websockets.connect") as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_ws + + await conversation.start_session() + await asyncio.sleep(0.1) + + await conversation.register_user_activity() + + # Teardown + await conversation.end_session() + await conversation.wait_for_session_end() + + # Assertions + expected = json.dumps({"type": "user_activity"}) + mock_ws.send.assert_any_call(expected) + 
+ +@pytest.mark.asyncio +async def test_async_conversation_callback_flows(): + # Mock setup for testing all callback types + messages = [ + { + "type": "conversation_initiation_metadata", + "conversation_initiation_metadata_event": {"conversation_id": TEST_CONVERSATION_ID}, + }, + {"type": "agent_response", "agent_response_event": {"agent_response": "Hello there!"}}, + { + "type": "agent_response_correction", + "agent_response_correction_event": { + "original_agent_response": "Hello ther!", + "corrected_agent_response": "Hello there!" + } + }, + { + "type": "user_transcript", + "user_transcription_event": {"user_transcript": "Hi, how are you?"} + }, + { + "type": "ping", + "ping_event": {"event_id": "123", "ping_ms": 50} + }, + { + "type": "interruption", + "interruption_event": {"event_id": "456"} + }, + { + "type": "audio", + "audio_event": {"event_id": "789", "audio_base_64": "dGVzdA=="} # "test" in base64 + } + ] + + mock_ws = create_mock_async_websocket(messages) + mock_client = MagicMock() + + # Setup callbacks + agent_response_callback = AsyncMock() + agent_response_correction_callback = AsyncMock() + user_transcript_callback = AsyncMock() + latency_measurement_callback = AsyncMock() + end_session_callback = AsyncMock() + + # Setup the conversation + conversation = AsyncConversation( + client=mock_client, + agent_id=TEST_AGENT_ID, + requires_auth=False, + audio_interface=MockAsyncAudioInterface(), + callback_agent_response=agent_response_callback, + callback_agent_response_correction=agent_response_correction_callback, + callback_user_transcript=user_transcript_callback, + callback_latency_measurement=latency_measurement_callback, + callback_end_session=end_session_callback, + ) + + # Run the test + with patch("elevenlabs.conversational_ai.conversation.websockets.connect") as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_ws + + await conversation.start_session() + + # Wait for callbacks to be processed + await asyncio.sleep(0.2) + 
+ await conversation.end_session() + await conversation.wait_for_session_end() + + # Assertions + agent_response_callback.assert_called_with("Hello there!") + agent_response_correction_callback.assert_called_with("Hello ther!", "Hello there!") + user_transcript_callback.assert_called_with("Hi, how are you?") + latency_measurement_callback.assert_called_with(50) + end_session_callback.assert_called_once() + assert conversation._conversation_id == TEST_CONVERSATION_ID + assert conversation._last_interrupt_id == 456