|
2 | 2 | import base64 |
3 | 3 | import json |
4 | 4 | import threading |
5 | | -from typing import Callable, Optional, Awaitable, Union, Any |
| 5 | +from typing import Callable, Optional, Awaitable, Union, Any, Literal |
6 | 6 | import asyncio |
7 | 7 | from concurrent.futures import ThreadPoolExecutor |
| 8 | +from enum import StrEnum |
8 | 9 |
|
9 | | -from websockets.sync.client import connect |
| 10 | +from websockets.sync.client import connect, ClientConnection |
10 | 11 | from websockets.exceptions import ConnectionClosedOK |
11 | 12 |
|
12 | 13 | from ..base_client import BaseElevenLabs |
13 | 14 |
|
14 | 15 |
|
| 16 | +class ClientToOrchestratorEvent(StrEnum): |
| 17 | + """Event types that can be sent from client to orchestrator.""" |
| 18 | + # Response to a ping request. |
| 19 | + PONG = "pong" |
| 20 | + CLIENT_TOOL_RESULT = "client_tool_result" |
| 21 | + CONVERSATION_INITIATION_CLIENT_DATA = "conversation_initiation_client_data" |
| 22 | + FEEDBACK = "feedback" |
| 23 | + # Non-interrupting content that is sent to the server to update the conversation state. |
| 24 | + CONTEXTUAL_UPDATE = "contextual_update" |
| 25 | + # User text message. |
| 26 | + USER_MESSAGE = "user_message" |
| 27 | + USER_ACTIVITY = "user_activity" |
| 28 | + |
| 29 | + |
| 30 | +class UserMessageClientToOrchestratorEvent: |
| 31 | + """Event for sending user text messages.""" |
| 32 | + |
| 33 | + def __init__(self, text: Optional[str] = None): |
| 34 | + self.type: Literal[ClientToOrchestratorEvent.USER_MESSAGE] = ClientToOrchestratorEvent.USER_MESSAGE |
| 35 | + self.text = text |
| 36 | + |
| 37 | + def to_dict(self) -> dict: |
| 38 | + return { |
| 39 | + "type": self.type, |
| 40 | + "text": self.text |
| 41 | + } |
| 42 | + |
| 43 | + |
| 44 | +class UserActivityClientToOrchestratorEvent: |
| 45 | + """Event for registering user activity (ping to prevent timeout).""" |
| 46 | + |
| 47 | + def __init__(self): |
| 48 | + self.type: Literal[ClientToOrchestratorEvent.USER_ACTIVITY] = ClientToOrchestratorEvent.USER_ACTIVITY |
| 49 | + |
| 50 | + def to_dict(self) -> dict: |
| 51 | + return { |
| 52 | + "type": self.type |
| 53 | + } |
| 54 | + |
| 55 | + |
| 56 | +class ContextualUpdateClientToOrchestratorEvent: |
| 57 | + """Event for sending non-interrupting contextual updates to the conversation state.""" |
| 58 | + |
| 59 | + def __init__(self, content: str): |
| 60 | + self.type: Literal[ClientToOrchestratorEvent.CONTEXTUAL_UPDATE] = ClientToOrchestratorEvent.CONTEXTUAL_UPDATE |
| 61 | + self.content = content |
| 62 | + |
| 63 | + def to_dict(self) -> dict: |
| 64 | + return { |
| 65 | + "type": self.type, |
| 66 | + "content": self.content |
| 67 | + } |
| 68 | + |
| 69 | + |
15 | 70 | class AudioInterface(ABC): |
16 | 71 | """AudioInterface provides an abstraction for handling audio input and output.""" |
17 | 72 |
|
@@ -193,6 +248,7 @@ class Conversation: |
193 | 248 | _should_stop: threading.Event |
194 | 249 | _conversation_id: Optional[str] |
195 | 250 | _last_interrupt_id: int |
| 251 | + _ws: Optional[ClientConnection] |
196 | 252 |
|
197 | 253 | def __init__( |
198 | 254 | self, |
@@ -243,6 +299,7 @@ def __init__( |
243 | 299 | self._should_stop = threading.Event() |
244 | 300 | self._conversation_id = None |
245 | 301 | self._last_interrupt_id = 0 |
| 302 | + self._ws = None |
246 | 303 |
|
247 | 304 | def start_session(self): |
248 | 305 | """Starts the conversation session. |
@@ -271,8 +328,68 @@ def wait_for_session_end(self) -> Optional[str]: |
271 | 328 | self._thread.join() |
272 | 329 | return self._conversation_id |
273 | 330 |
|
| 331 | + def send_user_message(self, text: str): |
| 332 | + """Send a text message from the user to the agent. |
| 333 | + |
| 334 | + Args: |
| 335 | + text: The text message to send to the agent. |
| 336 | + |
| 337 | + Raises: |
| 338 | + RuntimeError: If the session is not active or websocket is not connected. |
| 339 | + """ |
| 340 | + if not self._ws: |
| 341 | + raise RuntimeError("Session not started or websocket not connected.") |
| 342 | + |
| 343 | + event = UserMessageClientToOrchestratorEvent(text=text) |
| 344 | + try: |
| 345 | + self._ws.send(json.dumps(event.to_dict())) |
| 346 | + except Exception as e: |
| 347 | + print(f"Error sending user message: {e}") |
| 348 | + raise |
| 349 | + |
| 350 | + def register_user_activity(self): |
| 351 | + """Register user activity to prevent session timeout. |
| 352 | + |
| 353 | + This sends a ping to the orchestrator to reset the timeout timer. |
| 354 | + |
| 355 | + Raises: |
| 356 | + RuntimeError: If the session is not active or websocket is not connected. |
| 357 | + """ |
| 358 | + if not self._ws: |
| 359 | + raise RuntimeError("Session not started or websocket not connected.") |
| 360 | + |
| 361 | + event = UserActivityClientToOrchestratorEvent() |
| 362 | + try: |
| 363 | + self._ws.send(json.dumps(event.to_dict())) |
| 364 | + except Exception as e: |
| 365 | + print(f"Error registering user activity: {e}") |
| 366 | + raise |
| 367 | + |
| 368 | + def send_contextual_update(self, content: str): |
| 369 | + """Send a contextual update to the conversation. |
| 370 | + |
| 371 | + Contextual updates are non-interrupting content that is sent to the server |
| 372 | + to update the conversation state without directly prompting the agent. |
| 373 | + |
| 374 | + Args: |
| 375 | + content: The contextual information to send to the conversation. |
| 376 | + |
| 377 | + Raises: |
| 378 | + RuntimeError: If the session is not active or websocket is not connected. |
| 379 | + """ |
| 380 | + if not self._ws: |
| 381 | + raise RuntimeError("Session not started or websocket not connected.") |
| 382 | + |
| 383 | + event = ContextualUpdateClientToOrchestratorEvent(content=content) |
| 384 | + try: |
| 385 | + self._ws.send(json.dumps(event.to_dict())) |
| 386 | + except Exception as e: |
| 387 | + print(f"Error sending contextual update: {e}") |
| 388 | + raise |
| 389 | + |
274 | 390 | def _run(self, ws_url: str): |
275 | 391 | with connect(ws_url, max_size=16 * 1024 * 1024) as ws: |
| 392 | + self._ws = ws |
276 | 393 | ws.send( |
277 | 394 | json.dumps( |
278 | 395 | { |
@@ -313,6 +430,8 @@ def input_callback(audio): |
313 | 430 | except Exception as e: |
314 | 431 | print(f"Error receiving message: {e}") |
315 | 432 | self.end_session() |
| 433 | + |
| 434 | + self._ws = None |
316 | 435 |
|
317 | 436 | def _handle_message(self, message, ws): |
318 | 437 | if message["type"] == "conversation_initiation_metadata": |
|
0 commit comments