Skip to content

Commit d9de46c

Browse files
committed
fixes
1 parent 03d98f1 commit d9de46c

2 files changed

Lines changed: 149 additions & 87 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "elevenlabs"
33

44
[tool.poetry]
55
name = "elevenlabs"
6-
version = "v2.10.0"
6+
version = "v2.11.0"
77
description = ""
88
readme = "README.md"
99
authors = []

src/elevenlabs/conversational_ai/conversation.py

Lines changed: 148 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1-
from abc import ABC, abstractmethod
1+
import asyncio
22
import base64
33
import json
44
import threading
5-
from typing import Callable, Optional, Awaitable, Union, Any, Literal, Dict, Tuple
6-
import asyncio
5+
from abc import ABC, abstractmethod
76
from concurrent.futures import ThreadPoolExecutor
87
from enum import Enum
8+
from typing import Any, Awaitable, Callable, Dict, Literal, Optional, Protocol, Tuple, Union
99

10-
from websockets.sync.client import connect, Connection
1110
import websockets
12-
from websockets.exceptions import ConnectionClosedOK
13-
1411
from ..base_client import BaseElevenLabs
1512
from ..version import __version__
13+
from websockets.exceptions import ConnectionClosedOK
14+
from websockets.sync.client import Connection, connect
1615

1716

1817
class ClientToOrchestratorEvent(str, Enum):
@@ -276,6 +275,71 @@ def __init__(
276275
self.user_id = user_id
277276

278277

278+
class MessageHandler(Protocol):
    """Structural interface for objects that react to decoded conversation messages.

    Most ``handle_*`` methods are annotated as returning either ``None`` or an
    awaitable, so one protocol describes both blocking and coroutine-based
    handlers. ``handle_client_tool_call`` is the exception: it is always a
    plain synchronous call.
    """

    # Optional user callbacks exposed on the handler; each may be None.
    callback_agent_response: Optional[Callable]
    callback_agent_response_correction: Optional[Callable]
    callback_user_transcript: Optional[Callable]
    callback_latency_measurement: Optional[Callable]

    def handle_audio_output(self, audio: bytes) -> Union[None, Awaitable[None]]:
        """Handle audio output."""
        ...

    def handle_agent_response(self, response: str) -> Union[None, Awaitable[None]]:
        """Handle agent response."""
        ...

    def handle_agent_response_correction(self, original: str, corrected: str) -> Union[None, Awaitable[None]]:
        """Handle agent response correction."""
        ...

    def handle_user_transcript(self, transcript: str) -> Union[None, Awaitable[None]]:
        """Handle user transcript."""
        ...

    def handle_interruption(self) -> Union[None, Awaitable[None]]:
        """Handle interruption."""
        ...

    def handle_ping(self, event: Dict[str, Any]) -> Union[None, Awaitable[None]]:
        """Handle ping event."""
        # NOTE(review): the expected keys of `event` are not visible here - confirm against the dispatcher.
        ...

    def handle_latency_measurement(self, latency: int) -> Union[None, Awaitable[None]]:
        """Handle latency measurement."""
        # NOTE(review): unit of `latency` is not shown in this view - confirm.
        ...

    def handle_client_tool_call(self, tool_name: str, parameters: Dict[str, Any]) -> None:
        """Handle client tool call."""
        ...
318+
319+
class BaseMessageHandler:
    """Base implementation for message handlers with common functionality.

    Holds the owning conversation and the WebSocket object, copies the
    conversation's user callbacks onto the handler, and provides the shared
    client-tool-call logic. Subclasses supply the transport-specific
    ``_send_response``.
    """

    def __init__(self, conversation: Any, ws_or_websocket: Any) -> None:
        """Bind the handler to a conversation and its WebSocket.

        Args:
            conversation: Object exposing the four ``callback_*`` attributes
                read below, plus ``_should_stop`` and ``client_tools``.
            ws_or_websocket: The WebSocket object; stored as ``self.ws``.
        """
        self.conversation = conversation
        self.ws = ws_or_websocket
        # Snapshot the callbacks at construction time.
        self.callback_agent_response = conversation.callback_agent_response
        self.callback_agent_response_correction = conversation.callback_agent_response_correction
        self.callback_user_transcript = conversation.callback_user_transcript
        self.callback_latency_measurement = conversation.callback_latency_measurement

    def handle_client_tool_call(self, tool_name: str, parameters: Dict[str, Any]) -> None:
        """Handle client tool call - common implementation for both sync and async.

        Delegates execution to ``conversation.client_tools.execute_tool`` and
        passes it a callback that forwards the tool's response through
        ``_send_response``.
        """

        def send_response(response):
            # Drop the response once the conversation has been told to stop.
            if not self.conversation._should_stop.is_set():
                self._send_response(response)

        self.conversation.client_tools.execute_tool(tool_name, parameters, send_response)

    def _send_response(self, response: Dict[str, Any]) -> None:
        """Send response - to be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
341+
342+
279343
class BaseConversation:
280344
"""Base class for conversation implementations with shared parameters and logic."""
281345

@@ -300,6 +364,42 @@ def __init__(
300364

301365
self._conversation_id = None
302366
self._last_interrupt_id = 0
367+
368+
def _create_sync_audio_callback(self, ws) -> Callable[[bytes], None]:
    """Create sync audio input callback.

    Args:
        ws: WebSocket object whose ``send`` is called synchronously with a
            JSON string.

    Returns:
        A function taking one chunk of audio bytes. It base64-encodes the
        chunk and sends it as ``{"user_audio_chunk": <base64 text>}``. Any
        failure while sending ends the session instead of propagating.
    """

    def callback(audio: bytes) -> None:
        try:
            ws.send(
                json.dumps(
                    {
                        "user_audio_chunk": base64.b64encode(audio).decode(),
                    }
                )
            )
        except ConnectionClosedOK:
            # This exception type ends the session without printing anything.
            self.end_session()
        except Exception as e:
            # Any other failure is reported, then the session is ended as well.
            print(f"Error sending user audio chunk: {e}")
            self.end_session()

    return callback
383+
384+
def _create_async_audio_callback(self, ws) -> Callable[[bytes], Awaitable[None]]:
    """Create async audio input callback.

    Args:
        ws: WebSocket object whose ``send`` is awaited with a JSON string.

    Returns:
        A coroutine function taking one chunk of audio bytes. It
        base64-encodes the chunk and sends it as
        ``{"user_audio_chunk": <base64 text>}``. Any failure while sending
        ends the session (awaited) instead of propagating.
    """

    async def callback(audio: bytes) -> None:
        try:
            await ws.send(
                json.dumps(
                    {
                        "user_audio_chunk": base64.b64encode(audio).decode(),
                    }
                )
            )
        except ConnectionClosedOK:
            # This exception type ends the session without printing anything.
            await self.end_session()
        except Exception as e:
            # Any other failure is reported, then the session is ended as well.
            print(f"Error sending user audio chunk: {e}")
            await self.end_session()

    return callback
399+
400+
def _handle_connection_closed(self) -> Union[None, Awaitable[None]]:
    """Handle WebSocket connection closed - to be implemented by subclasses.

    The return annotation allows either a plain or an awaitable
    implementation.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
303403

304404
def _get_wss_url(self):
305405
base_ws_url = self.client._client_wrapper.get_environment().wss
@@ -327,7 +427,7 @@ def _create_initiation_message(self):
327427
}
328428
)
329429

330-
def _handle_message_core(self, message, message_handler):
430+
def _handle_message_core(self, message: Dict[str, Any], message_handler: MessageHandler) -> None:
331431
"""Core message handling logic shared between sync and async implementations.
332432
333433
Args:
@@ -383,7 +483,7 @@ def _handle_message_core(self, message, message_handler):
383483
else:
384484
pass # Ignore all other message types.
385485

386-
async def _handle_message_core_async(self, message, message_handler):
486+
async def _handle_message_core_async(self, message: Dict[str, Any], message_handler: MessageHandler) -> None:
387487
"""Async wrapper for core message handling logic."""
388488
if message["type"] == "conversation_initiation_metadata":
389489
event = message["conversation_initiation_metadata_event"]
@@ -592,35 +692,24 @@ def send_contextual_update(self, text: str):
592692
print(f"Error sending contextual update: {e}")
593693
raise
594694

695+
def _handle_connection_closed(self) -> None:
    """Synchronous connection-closed hook: end the session."""
    self.end_session()
697+
595698
def _run(self, ws_url: str):
596699
with connect(ws_url, max_size=16 * 1024 * 1024) as ws:
597700
self._ws = ws
598701
ws.send(self._create_initiation_message())
599-
self._ws = ws
600-
601-
def input_callback(audio):
602-
try:
603-
ws.send(
604-
json.dumps(
605-
{
606-
"user_audio_chunk": base64.b64encode(audio).decode(),
607-
}
608-
)
609-
)
610-
except ConnectionClosedOK:
611-
self.end_session()
612-
except Exception as e:
613-
print(f"Error sending user audio chunk: {e}")
614-
self.end_session()
615-
702+
703+
input_callback = self._create_sync_audio_callback(ws)
616704
self.audio_interface.start(input_callback)
705+
617706
while not self._should_stop.is_set():
618707
try:
619708
message = json.loads(ws.recv(timeout=0.5))
620709
if self._should_stop.is_set():
621710
return
622711
self._handle_message(message, ws)
623-
except ConnectionClosedOK as e:
712+
except ConnectionClosedOK:
624713
self.end_session()
625714
except TimeoutError:
626715
pass
@@ -631,31 +720,23 @@ def input_callback(audio):
631720
self._ws = None
632721

633722
def _handle_message(self, message, ws):
634-
class SyncMessageHandler:
635-
def __init__(self, conversation, ws):
636-
self.conversation = conversation
637-
self.ws = ws
638-
self.callback_agent_response = conversation.callback_agent_response
639-
self.callback_agent_response_correction = conversation.callback_agent_response_correction
640-
self.callback_user_transcript = conversation.callback_user_transcript
641-
self.callback_latency_measurement = conversation.callback_latency_measurement
642-
643-
def handle_audio_output(self, audio):
723+
class SyncMessageHandler(BaseMessageHandler):
724+
def handle_audio_output(self, audio: bytes) -> None:
644725
self.conversation.audio_interface.output(audio)
645726

646-
def handle_agent_response(self, response):
727+
def handle_agent_response(self, response: str) -> None:
647728
self.conversation.callback_agent_response(response)
648729

649-
def handle_agent_response_correction(self, original, corrected):
730+
def handle_agent_response_correction(self, original: str, corrected: str) -> None:
650731
self.conversation.callback_agent_response_correction(original, corrected)
651732

652-
def handle_user_transcript(self, transcript):
733+
def handle_user_transcript(self, transcript: str) -> None:
653734
self.conversation.callback_user_transcript(transcript)
654735

655-
def handle_interruption(self):
736+
def handle_interruption(self) -> None:
656737
self.conversation.audio_interface.interrupt()
657738

658-
def handle_ping(self, event):
739+
def handle_ping(self, event: Dict[str, Any]) -> None:
659740
self.ws.send(
660741
json.dumps(
661742
{
@@ -665,15 +746,11 @@ def handle_ping(self, event):
665746
)
666747
)
667748

668-
def handle_latency_measurement(self, latency):
749+
def handle_latency_measurement(self, latency: int) -> None:
669750
self.conversation.callback_latency_measurement(latency)
670751

671-
def handle_client_tool_call(self, tool_name, parameters):
672-
def send_response(response):
673-
if not self.conversation._should_stop.is_set():
674-
self.ws.send(json.dumps(response))
675-
676-
self.conversation.client_tools.execute_tool(tool_name, parameters, send_response)
752+
def _send_response(self, response: Dict[str, Any]) -> None:
753+
self.ws.send(json.dumps(response))
677754

678755
handler = SyncMessageHandler(self, ws)
679756
self._handle_message_core(message, handler)
@@ -761,6 +838,14 @@ async def end_session(self):
761838
self.client_tools.stop()
762839
self._ws = None
763840
self._should_stop.set()
841+
842+
# Cleanup the background task
843+
if self._task and not self._task.done():
844+
self._task.cancel()
845+
try:
846+
await self._task
847+
except asyncio.CancelledError:
848+
pass
764849

765850
if self.callback_end_session:
766851
await self.callback_end_session()
@@ -836,26 +921,15 @@ async def send_contextual_update(self, text: str):
836921
print(f"Error sending contextual update: {e}")
837922
raise
838923

924+
async def _handle_connection_closed(self) -> None:
    """Asynchronous connection-closed hook: await the end of the session."""
    await self.end_session()
926+
839927
async def _run(self, ws_url: str):
840928
async with websockets.connect(ws_url, max_size=16 * 1024 * 1024) as ws:
841929
self._ws = ws
842930
await ws.send(self._create_initiation_message())
843-
844-
async def input_callback(audio):
845-
try:
846-
await ws.send(
847-
json.dumps(
848-
{
849-
"user_audio_chunk": base64.b64encode(audio).decode(),
850-
}
851-
)
852-
)
853-
except ConnectionClosedOK:
854-
await self.end_session()
855-
except Exception as e:
856-
print(f"Error sending user audio chunk: {e}")
857-
await self.end_session()
858-
931+
932+
input_callback = self._create_async_audio_callback(ws)
859933
await self.audio_interface.start(input_callback)
860934

861935
try:
@@ -879,31 +953,23 @@ async def input_callback(audio):
879953
self._ws = None
880954

881955
async def _handle_message(self, message, ws):
882-
class AsyncMessageHandler:
883-
def __init__(self, conversation, ws):
884-
self.conversation = conversation
885-
self.ws = ws
886-
self.callback_agent_response = conversation.callback_agent_response
887-
self.callback_agent_response_correction = conversation.callback_agent_response_correction
888-
self.callback_user_transcript = conversation.callback_user_transcript
889-
self.callback_latency_measurement = conversation.callback_latency_measurement
890-
891-
async def handle_audio_output(self, audio):
956+
class AsyncMessageHandler(BaseMessageHandler):
957+
async def handle_audio_output(self, audio: bytes) -> None:
892958
await self.conversation.audio_interface.output(audio)
893959

894-
async def handle_agent_response(self, response):
960+
async def handle_agent_response(self, response: str) -> None:
895961
await self.conversation.callback_agent_response(response)
896962

897-
async def handle_agent_response_correction(self, original, corrected):
963+
async def handle_agent_response_correction(self, original: str, corrected: str) -> None:
898964
await self.conversation.callback_agent_response_correction(original, corrected)
899965

900-
async def handle_user_transcript(self, transcript):
966+
async def handle_user_transcript(self, transcript: str) -> None:
901967
await self.conversation.callback_user_transcript(transcript)
902968

903-
async def handle_interruption(self):
969+
async def handle_interruption(self) -> None:
904970
await self.conversation.audio_interface.interrupt()
905971

906-
async def handle_ping(self, event):
972+
async def handle_ping(self, event: Dict[str, Any]) -> None:
907973
await self.ws.send(
908974
json.dumps(
909975
{
@@ -913,15 +979,11 @@ async def handle_ping(self, event):
913979
)
914980
)
915981

916-
async def handle_latency_measurement(self, latency):
982+
async def handle_latency_measurement(self, latency: int) -> None:
917983
await self.conversation.callback_latency_measurement(latency)
918984

919-
def handle_client_tool_call(self, tool_name, parameters):
920-
def send_response(response):
921-
if not self.conversation._should_stop.is_set():
922-
asyncio.create_task(self.ws.send(json.dumps(response)))
923-
924-
self.conversation.client_tools.execute_tool(tool_name, parameters, send_response)
985+
def _send_response(self, response: Dict[str, Any]) -> None:
986+
asyncio.create_task(self.ws.send(json.dumps(response)))
925987

926988
handler = AsyncMessageHandler(self, ws)
927989

0 commit comments

Comments
 (0)