Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 126 additions & 15 deletions src/elevenlabs/conversational_ai/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,71 @@
import base64
import json
import threading
from typing import Callable, Optional, Awaitable, Union, Any
from typing import Callable, Optional, Awaitable, Union, Any, Literal, Dict, Tuple
import asyncio
from concurrent.futures import ThreadPoolExecutor
from enum import Enum

from websockets.sync.client import connect, ClientConnection
from websockets.sync.client import connect, Connection
from websockets.exceptions import ConnectionClosedOK

from ..base_client import BaseElevenLabs


class ClientToOrchestratorEvent(str, Enum):
"""Event types that can be sent from client to orchestrator."""
# Response to a ping request.
PONG = "pong"
CLIENT_TOOL_RESULT = "client_tool_result"
CONVERSATION_INITIATION_CLIENT_DATA = "conversation_initiation_client_data"
FEEDBACK = "feedback"
# Non-interrupting content that is sent to the server to update the conversation state.
CONTEXTUAL_UPDATE = "contextual_update"
# User text message.
USER_MESSAGE = "user_message"
USER_ACTIVITY = "user_activity"


class UserMessageClientToOrchestratorEvent:
"""Event for sending user text messages."""

def __init__(self, text: Optional[str] = None):
self.type: Literal[ClientToOrchestratorEvent.USER_MESSAGE] = ClientToOrchestratorEvent.USER_MESSAGE
self.text = text

def to_dict(self) -> dict:
return {
"type": self.type,
"text": self.text
}


class UserActivityClientToOrchestratorEvent:
"""Event for registering user activity (ping to prevent timeout)."""

def __init__(self) -> None:
self.type: Literal[ClientToOrchestratorEvent.USER_ACTIVITY] = ClientToOrchestratorEvent.USER_ACTIVITY

def to_dict(self) -> dict:
return {
"type": self.type
}


class ContextualUpdateClientToOrchestratorEvent:
"""Event for sending non-interrupting contextual updates to the conversation state."""

def __init__(self, content: str):
self.type: Literal[ClientToOrchestratorEvent.CONTEXTUAL_UPDATE] = ClientToOrchestratorEvent.CONTEXTUAL_UPDATE
self.content = content

def to_dict(self) -> dict:
return {
"type": self.type,
"content": self.content
}


class AudioInterface(ABC):
"""AudioInterface provides an abstraction for handling audio input and output."""

Expand Down Expand Up @@ -63,8 +118,8 @@ class ClientTools:
ensuring non-blocking operation of the main conversation thread.
"""

def __init__(self):
self.tools: dict[str, tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {}
def __init__(self) -> None:
self.tools: Dict[str, Tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {}
self.lock = threading.Lock()
self._loop = None
self._thread = None
Expand Down Expand Up @@ -141,6 +196,9 @@ def execute_tool(self, tool_name: str, parameters: dict, callback: Callable[[dic
"""
if not self._running.is_set():
raise RuntimeError("ClientTools event loop is not running")

if self._loop is None:
raise RuntimeError("Event loop is not available")

async def _execute_and_callback():
try:
Expand Down Expand Up @@ -193,6 +251,7 @@ class Conversation:
_should_stop: threading.Event
_conversation_id: Optional[str]
_last_interrupt_id: int
_ws: Optional[Connection]

def __init__(
self,
Expand Down Expand Up @@ -240,7 +299,7 @@ def __init__(
self.client_tools.start()

self._thread = None
self._ws: Optional[ClientConnection] = None
self._ws: Optional[Connection] = None
self._should_stop = threading.Event()
self._conversation_id = None
self._last_interrupt_id = 0
Expand Down Expand Up @@ -273,8 +332,68 @@ def wait_for_session_end(self) -> Optional[str]:
self._thread.join()
return self._conversation_id

def send_user_message(self, text: str):
"""Send a text message from the user to the agent.

Args:
text: The text message to send to the agent.

Raises:
RuntimeError: If the session is not active or websocket is not connected.
"""
if not self._ws:
raise RuntimeError("Session not started or websocket not connected.")

event = UserMessageClientToOrchestratorEvent(text=text)
try:
self._ws.send(json.dumps(event.to_dict()))
except Exception as e:
print(f"Error sending user message: {e}")
raise

def register_user_activity(self):
"""Register user activity to prevent session timeout.

This sends a ping to the orchestrator to reset the timeout timer.

Raises:
RuntimeError: If the session is not active or websocket is not connected.
"""
if not self._ws:
raise RuntimeError("Session not started or websocket not connected.")

event = UserActivityClientToOrchestratorEvent()
try:
self._ws.send(json.dumps(event.to_dict()))
except Exception as e:
print(f"Error registering user activity: {e}")
raise

def send_contextual_update(self, content: str):
"""Send a contextual update to the conversation.

Contextual updates are non-interrupting content that is sent to the server
to update the conversation state without directly prompting the agent.

Args:
content: The contextual information to send to the conversation.

Raises:
RuntimeError: If the session is not active or websocket is not connected.
"""
if not self._ws:
raise RuntimeError("Session not started or websocket not connected.")

event = ContextualUpdateClientToOrchestratorEvent(content=content)
try:
self._ws.send(json.dumps(event.to_dict()))
except Exception as e:
print(f"Error sending contextual update: {e}")
raise

def _run(self, ws_url: str):
with connect(ws_url, max_size=16 * 1024 * 1024) as ws:
self._ws = ws
ws.send(
json.dumps(
{
Expand Down Expand Up @@ -316,6 +435,8 @@ def input_callback(audio):
except Exception as e:
print(f"Error receiving message: {e}")
self.end_session()

self._ws = None

def _handle_message(self, message, ws):
if message["type"] == "conversation_initiation_metadata":
Expand Down Expand Up @@ -372,16 +493,6 @@ def send_response(response):
else:
pass # Ignore all other message types.

def send_contextual_update(self, text: str):
if not self._ws:
raise RuntimeError("WebSocket is not connected")

payload = {
"type": "contextual_update",
"text": text,
}
self._ws.send(json.dumps(payload))

def _get_wss_url(self):
base_ws_url = self.client._client_wrapper.get_environment().wss
return f"{base_ws_url}/v1/convai/conversation?agent_id={self.agent_id}"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_convai.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,5 +191,5 @@ def test_conversation_with_contextual_update():
conversation.wait_for_session_end()

# Assertions
expected = json.dumps({"type": "contextual_update", "text": "User appears to be looking at pricing page"})
expected = json.dumps({"type": "contextual_update", "content": "User appears to be looking at pricing page"})
mock_ws.send.assert_any_call(expected)