1- from abc import ABC , abstractmethod
1+ import asyncio
22import base64
33import json
44import threading
5- from typing import Callable , Optional , Awaitable , Union , Any , Literal , Dict , Tuple
6- import asyncio
5+ from abc import ABC , abstractmethod
76from concurrent .futures import ThreadPoolExecutor
87from enum import Enum
8+ from typing import Any , Awaitable , Callable , Dict , Literal , Optional , Protocol , Tuple , Union
99
10- from websockets .sync .client import connect , Connection
1110import websockets
12- from websockets .exceptions import ConnectionClosedOK
13-
1411from ..base_client import BaseElevenLabs
1512from ..version import __version__
13+ from websockets .exceptions import ConnectionClosedOK
14+ from websockets .sync .client import Connection , connect
1615
1716
1817class ClientToOrchestratorEvent (str , Enum ):
@@ -276,6 +275,71 @@ def __init__(
276275 self .user_id = user_id
277276
278277
278+ class MessageHandler (Protocol ):
279+ """Protocol defining the interface for message handlers."""
280+
281+ callback_agent_response : Optional [Callable ]
282+ callback_agent_response_correction : Optional [Callable ]
283+ callback_user_transcript : Optional [Callable ]
284+ callback_latency_measurement : Optional [Callable ]
285+
286+ def handle_audio_output (self , audio : bytes ) -> Union [None , Awaitable [None ]]:
287+ """Handle audio output."""
288+ ...
289+
290+ def handle_agent_response (self , response : str ) -> Union [None , Awaitable [None ]]:
291+ """Handle agent response."""
292+ ...
293+
294+ def handle_agent_response_correction (self , original : str , corrected : str ) -> Union [None , Awaitable [None ]]:
295+ """Handle agent response correction."""
296+ ...
297+
298+ def handle_user_transcript (self , transcript : str ) -> Union [None , Awaitable [None ]]:
299+ """Handle user transcript."""
300+ ...
301+
302+ def handle_interruption (self ) -> Union [None , Awaitable [None ]]:
303+ """Handle interruption."""
304+ ...
305+
306+ def handle_ping (self , event : Dict [str , Any ]) -> Union [None , Awaitable [None ]]:
307+ """Handle ping event."""
308+ ...
309+
310+ def handle_latency_measurement (self , latency : int ) -> Union [None , Awaitable [None ]]:
311+ """Handle latency measurement."""
312+ ...
313+
314+ def handle_client_tool_call (self , tool_name : str , parameters : Dict [str , Any ]) -> None :
315+ """Handle client tool call."""
316+ ...
317+
318+
319+ class BaseMessageHandler :
320+ """Base implementation for message handlers with common functionality."""
321+
322+ def __init__ (self , conversation , ws_or_websocket ):
323+ self .conversation = conversation
324+ self .ws = ws_or_websocket
325+ self .callback_agent_response = conversation .callback_agent_response
326+ self .callback_agent_response_correction = conversation .callback_agent_response_correction
327+ self .callback_user_transcript = conversation .callback_user_transcript
328+ self .callback_latency_measurement = conversation .callback_latency_measurement
329+
330+ def handle_client_tool_call (self , tool_name : str , parameters : Dict [str , Any ]) -> None :
331+ """Handle client tool call - common implementation for both sync and async."""
332+ def send_response (response ):
333+ if not self .conversation ._should_stop .is_set ():
334+ self ._send_response (response )
335+
336+ self .conversation .client_tools .execute_tool (tool_name , parameters , send_response )
337+
338+ def _send_response (self , response : Dict [str , Any ]) -> None :
339+ """Send response - to be implemented by subclasses."""
340+ raise NotImplementedError
341+
342+
279343class BaseConversation :
280344 """Base class for conversation implementations with shared parameters and logic."""
281345
@@ -300,6 +364,42 @@ def __init__(
300364
301365 self ._conversation_id = None
302366 self ._last_interrupt_id = 0
367+
368+ def _create_sync_audio_callback (self , ws ) -> Callable [[bytes ], None ]:
369+ """Create sync audio input callback."""
370+ def callback (audio : bytes ) -> None :
371+ try :
372+ ws .send (
373+ json .dumps ({
374+ "user_audio_chunk" : base64 .b64encode (audio ).decode (),
375+ })
376+ )
377+ except ConnectionClosedOK :
378+ self .end_session ()
379+ except Exception as e :
380+ print (f"Error sending user audio chunk: { e } " )
381+ self .end_session ()
382+ return callback
383+
384+ def _create_async_audio_callback (self , ws ) -> Callable [[bytes ], Awaitable [None ]]:
385+ """Create async audio input callback."""
386+ async def callback (audio : bytes ) -> None :
387+ try :
388+ await ws .send (
389+ json .dumps ({
390+ "user_audio_chunk" : base64 .b64encode (audio ).decode (),
391+ })
392+ )
393+ except ConnectionClosedOK :
394+ await self .end_session ()
395+ except Exception as e :
396+ print (f"Error sending user audio chunk: { e } " )
397+ await self .end_session ()
398+ return callback
399+
400+ def _handle_connection_closed (self ) -> Union [None , Awaitable [None ]]:
401+ """Handle WebSocket connection closed - to be implemented by subclasses."""
402+ raise NotImplementedError
303403
304404 def _get_wss_url (self ):
305405 base_ws_url = self .client ._client_wrapper .get_environment ().wss
@@ -327,7 +427,7 @@ def _create_initiation_message(self):
327427 }
328428 )
329429
330- def _handle_message_core (self , message , message_handler ) :
430+ def _handle_message_core (self , message : Dict [ str , Any ], message_handler : MessageHandler ) -> None :
331431 """Core message handling logic shared between sync and async implementations.
332432
333433 Args:
@@ -383,7 +483,7 @@ def _handle_message_core(self, message, message_handler):
383483 else :
384484 pass # Ignore all other message types.
385485
386- async def _handle_message_core_async (self , message , message_handler ) :
486+ async def _handle_message_core_async (self , message : Dict [ str , Any ], message_handler : MessageHandler ) -> None :
387487 """Async wrapper for core message handling logic."""
388488 if message ["type" ] == "conversation_initiation_metadata" :
389489 event = message ["conversation_initiation_metadata_event" ]
@@ -592,35 +692,24 @@ def send_contextual_update(self, text: str):
592692 print (f"Error sending contextual update: { e } " )
593693 raise
594694
695+ def _handle_connection_closed (self ) -> None :
696+ self .end_session ()
697+
595698 def _run (self , ws_url : str ):
596699 with connect (ws_url , max_size = 16 * 1024 * 1024 ) as ws :
597700 self ._ws = ws
598701 ws .send (self ._create_initiation_message ())
599- self ._ws = ws
600-
601- def input_callback (audio ):
602- try :
603- ws .send (
604- json .dumps (
605- {
606- "user_audio_chunk" : base64 .b64encode (audio ).decode (),
607- }
608- )
609- )
610- except ConnectionClosedOK :
611- self .end_session ()
612- except Exception as e :
613- print (f"Error sending user audio chunk: { e } " )
614- self .end_session ()
615-
702+
703+ input_callback = self ._create_sync_audio_callback (ws )
616704 self .audio_interface .start (input_callback )
705+
617706 while not self ._should_stop .is_set ():
618707 try :
619708 message = json .loads (ws .recv (timeout = 0.5 ))
620709 if self ._should_stop .is_set ():
621710 return
622711 self ._handle_message (message , ws )
623- except ConnectionClosedOK as e :
712+ except ConnectionClosedOK :
624713 self .end_session ()
625714 except TimeoutError :
626715 pass
@@ -631,31 +720,23 @@ def input_callback(audio):
631720 self ._ws = None
632721
633722 def _handle_message (self , message , ws ):
634- class SyncMessageHandler :
635- def __init__ (self , conversation , ws ):
636- self .conversation = conversation
637- self .ws = ws
638- self .callback_agent_response = conversation .callback_agent_response
639- self .callback_agent_response_correction = conversation .callback_agent_response_correction
640- self .callback_user_transcript = conversation .callback_user_transcript
641- self .callback_latency_measurement = conversation .callback_latency_measurement
642-
643- def handle_audio_output (self , audio ):
723+ class SyncMessageHandler (BaseMessageHandler ):
724+ def handle_audio_output (self , audio : bytes ) -> None :
644725 self .conversation .audio_interface .output (audio )
645726
646- def handle_agent_response (self , response ) :
727+ def handle_agent_response (self , response : str ) -> None :
647728 self .conversation .callback_agent_response (response )
648729
649- def handle_agent_response_correction (self , original , corrected ) :
730+ def handle_agent_response_correction (self , original : str , corrected : str ) -> None :
650731 self .conversation .callback_agent_response_correction (original , corrected )
651732
652- def handle_user_transcript (self , transcript ) :
733+ def handle_user_transcript (self , transcript : str ) -> None :
653734 self .conversation .callback_user_transcript (transcript )
654735
655- def handle_interruption (self ):
736+ def handle_interruption (self ) -> None :
656737 self .conversation .audio_interface .interrupt ()
657738
658- def handle_ping (self , event ) :
739+ def handle_ping (self , event : Dict [ str , Any ]) -> None :
659740 self .ws .send (
660741 json .dumps (
661742 {
@@ -665,15 +746,11 @@ def handle_ping(self, event):
665746 )
666747 )
667748
668- def handle_latency_measurement (self , latency ) :
749+ def handle_latency_measurement (self , latency : int ) -> None :
669750 self .conversation .callback_latency_measurement (latency )
670751
671- def handle_client_tool_call (self , tool_name , parameters ):
672- def send_response (response ):
673- if not self .conversation ._should_stop .is_set ():
674- self .ws .send (json .dumps (response ))
675-
676- self .conversation .client_tools .execute_tool (tool_name , parameters , send_response )
752+ def _send_response (self , response : Dict [str , Any ]) -> None :
753+ self .ws .send (json .dumps (response ))
677754
678755 handler = SyncMessageHandler (self , ws )
679756 self ._handle_message_core (message , handler )
@@ -761,6 +838,14 @@ async def end_session(self):
761838 self .client_tools .stop ()
762839 self ._ws = None
763840 self ._should_stop .set ()
841+
842+ # Cleanup the background task
843+ if self ._task and not self ._task .done ():
844+ self ._task .cancel ()
845+ try :
846+ await self ._task
847+ except asyncio .CancelledError :
848+ pass
764849
765850 if self .callback_end_session :
766851 await self .callback_end_session ()
@@ -836,26 +921,15 @@ async def send_contextual_update(self, text: str):
836921 print (f"Error sending contextual update: { e } " )
837922 raise
838923
924+ async def _handle_connection_closed (self ) -> None :
925+ await self .end_session ()
926+
839927 async def _run (self , ws_url : str ):
840928 async with websockets .connect (ws_url , max_size = 16 * 1024 * 1024 ) as ws :
841929 self ._ws = ws
842930 await ws .send (self ._create_initiation_message ())
843-
844- async def input_callback (audio ):
845- try :
846- await ws .send (
847- json .dumps (
848- {
849- "user_audio_chunk" : base64 .b64encode (audio ).decode (),
850- }
851- )
852- )
853- except ConnectionClosedOK :
854- await self .end_session ()
855- except Exception as e :
856- print (f"Error sending user audio chunk: { e } " )
857- await self .end_session ()
858-
931+
932+ input_callback = self ._create_async_audio_callback (ws )
859933 await self .audio_interface .start (input_callback )
860934
861935 try :
@@ -879,31 +953,23 @@ async def input_callback(audio):
879953 self ._ws = None
880954
881955 async def _handle_message (self , message , ws ):
882- class AsyncMessageHandler :
883- def __init__ (self , conversation , ws ):
884- self .conversation = conversation
885- self .ws = ws
886- self .callback_agent_response = conversation .callback_agent_response
887- self .callback_agent_response_correction = conversation .callback_agent_response_correction
888- self .callback_user_transcript = conversation .callback_user_transcript
889- self .callback_latency_measurement = conversation .callback_latency_measurement
890-
891- async def handle_audio_output (self , audio ):
956+ class AsyncMessageHandler (BaseMessageHandler ):
957+ async def handle_audio_output (self , audio : bytes ) -> None :
892958 await self .conversation .audio_interface .output (audio )
893959
894- async def handle_agent_response (self , response ) :
960+ async def handle_agent_response (self , response : str ) -> None :
895961 await self .conversation .callback_agent_response (response )
896962
897- async def handle_agent_response_correction (self , original , corrected ) :
963+ async def handle_agent_response_correction (self , original : str , corrected : str ) -> None :
898964 await self .conversation .callback_agent_response_correction (original , corrected )
899965
900- async def handle_user_transcript (self , transcript ) :
966+ async def handle_user_transcript (self , transcript : str ) -> None :
901967 await self .conversation .callback_user_transcript (transcript )
902968
903- async def handle_interruption (self ):
969+ async def handle_interruption (self ) -> None :
904970 await self .conversation .audio_interface .interrupt ()
905971
906- async def handle_ping (self , event ) :
972+ async def handle_ping (self , event : Dict [ str , Any ]) -> None :
907973 await self .ws .send (
908974 json .dumps (
909975 {
@@ -913,15 +979,11 @@ async def handle_ping(self, event):
913979 )
914980 )
915981
916- async def handle_latency_measurement (self , latency ) :
982+ async def handle_latency_measurement (self , latency : int ) -> None :
917983 await self .conversation .callback_latency_measurement (latency )
918984
919- def handle_client_tool_call (self , tool_name , parameters ):
920- def send_response (response ):
921- if not self .conversation ._should_stop .is_set ():
922- asyncio .create_task (self .ws .send (json .dumps (response )))
923-
924- self .conversation .client_tools .execute_tool (tool_name , parameters , send_response )
985+ def _send_response (self , response : Dict [str , Any ]) -> None :
986+ asyncio .create_task (self .ws .send (json .dumps (response )))
925987
926988 handler = AsyncMessageHandler (self , ws )
927989
0 commit comments