diff --git a/src/elevenlabs/conversational_ai/conversation.py b/src/elevenlabs/conversational_ai/conversation.py
index 16cb7e7e..497dc8ea 100644
--- a/src/elevenlabs/conversational_ai/conversation.py
+++ b/src/elevenlabs/conversational_ai/conversation.py
@@ -2,16 +2,71 @@
 import base64
 import json
 import threading
-from typing import Callable, Optional, Awaitable, Union, Any
+from typing import Callable, Optional, Awaitable, Union, Any, Literal, Dict, Tuple
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
 
-from websockets.sync.client import connect, ClientConnection
+from websockets.sync.client import connect, Connection
 from websockets.exceptions import ConnectionClosedOK
 
 from ..base_client import BaseElevenLabs
 
 
+class ClientToOrchestratorEvent(str, Enum):
+    """Event types that can be sent from client to orchestrator."""
+    # Response to a ping request.
+    PONG = "pong"
+    CLIENT_TOOL_RESULT = "client_tool_result"
+    CONVERSATION_INITIATION_CLIENT_DATA = "conversation_initiation_client_data"
+    FEEDBACK = "feedback"
+    # Non-interrupting content that is sent to the server to update the conversation state.
+    CONTEXTUAL_UPDATE = "contextual_update"
+    # User text message.
+    USER_MESSAGE = "user_message"
+    USER_ACTIVITY = "user_activity"
+
+
+class UserMessageClientToOrchestratorEvent:
+    """Event for sending user text messages."""
+
+    def __init__(self, text: Optional[str] = None):
+        self.type: Literal[ClientToOrchestratorEvent.USER_MESSAGE] = ClientToOrchestratorEvent.USER_MESSAGE
+        self.text = text
+
+    def to_dict(self) -> dict:
+        return {
+            "type": self.type.value,
+            "text": self.text
+        }
+
+
+class UserActivityClientToOrchestratorEvent:
+    """Event for registering user activity (ping to prevent timeout)."""
+
+    def __init__(self) -> None:
+        self.type: Literal[ClientToOrchestratorEvent.USER_ACTIVITY] = ClientToOrchestratorEvent.USER_ACTIVITY
+
+    def to_dict(self) -> dict:
+        return {
+            "type": self.type.value
+        }
+
+
+class ContextualUpdateClientToOrchestratorEvent:
+    """Event for sending non-interrupting contextual updates to the conversation state."""
+
+    def __init__(self, text: str):
+        self.type: Literal[ClientToOrchestratorEvent.CONTEXTUAL_UPDATE] = ClientToOrchestratorEvent.CONTEXTUAL_UPDATE
+        self.text = text
+
+    def to_dict(self) -> dict:
+        return {
+            "type": self.type.value,
+            "text": self.text
+        }
+
+
 class AudioInterface(ABC):
     """AudioInterface provides an abstraction for handling audio input and output."""
 
@@ -63,8 +118,8 @@ class ClientTools:
     ensuring non-blocking operation of the main conversation thread.
     """
 
-    def __init__(self):
-        self.tools: dict[str, tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {}
+    def __init__(self) -> None:
+        self.tools: Dict[str, Tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {}
         self.lock = threading.Lock()
         self._loop = None
         self._thread = None
@@ -141,6 +196,9 @@ def execute_tool(self, tool_name: str, parameters: dict, callback: Callable[[dic
         """
         if not self._running.is_set():
             raise RuntimeError("ClientTools event loop is not running")
+
+        if self._loop is None:
+            raise RuntimeError("Event loop is not available")
 
         async def _execute_and_callback():
             try:
@@ -193,6 +251,7 @@ class Conversation:
     _should_stop: threading.Event
     _conversation_id: Optional[str]
     _last_interrupt_id: int
+    _ws: Optional[Connection]
 
     def __init__(
         self,
@@ -240,7 +299,7 @@ def __init__(
         self.client_tools.start()
 
         self._thread = None
-        self._ws: Optional[ClientConnection] = None
+        self._ws = None
         self._should_stop = threading.Event()
         self._conversation_id = None
         self._last_interrupt_id = 0
@@ -273,8 +332,68 @@ def wait_for_session_end(self) -> Optional[str]:
         self._thread.join()
         return self._conversation_id
 
+    def _send_event(self, payload: dict, action: str) -> None:
+        """Serialize ``payload`` to JSON and send it over the active websocket.
+
+        Args:
+            payload: JSON-serializable event body.
+            action: Description used in the error log (e.g. "sending user message").
+
+        Raises:
+            RuntimeError: If the session is not active or websocket is not connected.
+        """
+        # Snapshot once: _run() resets self._ws to None, possibly concurrently.
+        ws = self._ws
+        if ws is None:
+            raise RuntimeError("Session not started or websocket not connected.")
+
+        try:
+            ws.send(json.dumps(payload))
+        except Exception as e:
+            print(f"Error {action}: {e}")
+            raise
+
+    def send_user_message(self, text: str) -> None:
+        """Send a text message from the user to the agent.
+
+        Args:
+            text: The text message to send to the agent.
+
+        Raises:
+            RuntimeError: If the session is not active or websocket is not connected.
+        """
+        event = UserMessageClientToOrchestratorEvent(text=text)
+        self._send_event(event.to_dict(), "sending user message")
+
+    def register_user_activity(self) -> None:
+        """Register user activity to prevent session timeout.
+
+        This sends a ping to the orchestrator to reset the timeout timer.
+
+        Raises:
+            RuntimeError: If the session is not active or websocket is not connected.
+        """
+        event = UserActivityClientToOrchestratorEvent()
+        self._send_event(event.to_dict(), "registering user activity")
+
+    def send_contextual_update(self, text: str) -> None:
+        """Send a contextual update to the conversation.
+
+        Contextual updates are non-interrupting content that is sent to the server
+        to update the conversation state without directly prompting the agent.
+
+        Args:
+            text: The contextual information to send to the conversation.
+
+        Raises:
+            RuntimeError: If the session is not active or websocket is not connected.
+        """
+        event = ContextualUpdateClientToOrchestratorEvent(text=text)
+        self._send_event(event.to_dict(), "sending contextual update")
+
     def _run(self, ws_url: str):
         with connect(ws_url, max_size=16 * 1024 * 1024) as ws:
+            self._ws = ws
             ws.send(
                 json.dumps(
                     {
@@ -316,6 +435,8 @@ def input_callback(audio):
                 except Exception as e:
                     print(f"Error receiving message: {e}")
                     self.end_session()
+
+            self._ws = None
 
     def _handle_message(self, message, ws):
         if message["type"] == "conversation_initiation_metadata":
@@ -372,16 +493,6 @@ def send_response(response):
         else:
             pass  # Ignore all other message types.
 
-    def send_contextual_update(self, text: str):
-        if not self._ws:
-            raise RuntimeError("WebSocket is not connected")
-
-        payload = {
-            "type": "contextual_update",
-            "text": text,
-        }
-        self._ws.send(json.dumps(payload))
-
     def _get_wss_url(self):
         base_ws_url = self.client._client_wrapper.get_environment().wss
         return f"{base_ws_url}/v1/convai/conversation?agent_id={self.agent_id}"