Skip to content

Commit 03d98f1

Browse files
committed
reduce duplication
1 parent cf2e177 commit 03d98f1

1 file changed

Lines changed: 199 additions & 104 deletions

File tree

src/elevenlabs/conversational_ai/conversation.py

Lines changed: 199 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,113 @@ def _create_initiation_message(self):
327327
}
328328
)
329329

330+
def _handle_message_core(self, message, message_handler):
331+
"""Core message handling logic shared between sync and async implementations.
332+
333+
Args:
334+
message: The parsed message dictionary
335+
message_handler: Handler object with methods for different operations
336+
"""
337+
if message["type"] == "conversation_initiation_metadata":
338+
event = message["conversation_initiation_metadata_event"]
339+
assert self._conversation_id is None
340+
self._conversation_id = event["conversation_id"]
341+
342+
elif message["type"] == "audio":
343+
event = message["audio_event"]
344+
if int(event["event_id"]) <= self._last_interrupt_id:
345+
return
346+
audio = base64.b64decode(event["audio_base_64"])
347+
message_handler.handle_audio_output(audio)
348+
349+
elif message["type"] == "agent_response":
350+
if message_handler.callback_agent_response:
351+
event = message["agent_response_event"]
352+
message_handler.handle_agent_response(event["agent_response"].strip())
353+
354+
elif message["type"] == "agent_response_correction":
355+
if message_handler.callback_agent_response_correction:
356+
event = message["agent_response_correction_event"]
357+
message_handler.handle_agent_response_correction(
358+
event["original_agent_response"].strip(),
359+
event["corrected_agent_response"].strip()
360+
)
361+
362+
elif message["type"] == "user_transcript":
363+
if message_handler.callback_user_transcript:
364+
event = message["user_transcription_event"]
365+
message_handler.handle_user_transcript(event["user_transcript"].strip())
366+
367+
elif message["type"] == "interruption":
368+
event = message["interruption_event"]
369+
self._last_interrupt_id = int(event["event_id"])
370+
message_handler.handle_interruption()
371+
372+
elif message["type"] == "ping":
373+
event = message["ping_event"]
374+
message_handler.handle_ping(event)
375+
if message_handler.callback_latency_measurement and event["ping_ms"]:
376+
message_handler.handle_latency_measurement(int(event["ping_ms"]))
377+
378+
elif message["type"] == "client_tool_call":
379+
tool_call = message.get("client_tool_call", {})
380+
tool_name = tool_call.get("tool_name")
381+
parameters = {"tool_call_id": tool_call["tool_call_id"], **tool_call.get("parameters", {})}
382+
message_handler.handle_client_tool_call(tool_name, parameters)
383+
else:
384+
pass # Ignore all other message types.
385+
386+
async def _handle_message_core_async(self, message, message_handler):
387+
"""Async wrapper for core message handling logic."""
388+
if message["type"] == "conversation_initiation_metadata":
389+
event = message["conversation_initiation_metadata_event"]
390+
assert self._conversation_id is None
391+
self._conversation_id = event["conversation_id"]
392+
393+
elif message["type"] == "audio":
394+
event = message["audio_event"]
395+
if int(event["event_id"]) <= self._last_interrupt_id:
396+
return
397+
audio = base64.b64decode(event["audio_base_64"])
398+
await message_handler.handle_audio_output(audio)
399+
400+
elif message["type"] == "agent_response":
401+
if message_handler.callback_agent_response:
402+
event = message["agent_response_event"]
403+
await message_handler.handle_agent_response(event["agent_response"].strip())
404+
405+
elif message["type"] == "agent_response_correction":
406+
if message_handler.callback_agent_response_correction:
407+
event = message["agent_response_correction_event"]
408+
await message_handler.handle_agent_response_correction(
409+
event["original_agent_response"].strip(),
410+
event["corrected_agent_response"].strip()
411+
)
412+
413+
elif message["type"] == "user_transcript":
414+
if message_handler.callback_user_transcript:
415+
event = message["user_transcription_event"]
416+
await message_handler.handle_user_transcript(event["user_transcript"].strip())
417+
418+
elif message["type"] == "interruption":
419+
event = message["interruption_event"]
420+
self._last_interrupt_id = int(event["event_id"])
421+
await message_handler.handle_interruption()
422+
423+
elif message["type"] == "ping":
424+
event = message["ping_event"]
425+
await message_handler.handle_ping(event)
426+
if message_handler.callback_latency_measurement and event["ping_ms"]:
427+
await message_handler.handle_latency_measurement(int(event["ping_ms"]))
428+
429+
elif message["type"] == "client_tool_call":
430+
tool_call = message.get("client_tool_call", {})
431+
tool_name = tool_call.get("tool_name")
432+
parameters = {"tool_call_id": tool_call["tool_call_id"], **tool_call.get("parameters", {})}
433+
message_handler.handle_client_tool_call(tool_name, parameters)
434+
else:
435+
pass # Ignore all other message types.
436+
330437

331438
class Conversation(BaseConversation):
332439
audio_interface: AudioInterface
@@ -524,59 +631,52 @@ def input_callback(audio):
524631
self._ws = None
525632

526633
def _handle_message(self, message, ws):
527-
if message["type"] == "conversation_initiation_metadata":
528-
event = message["conversation_initiation_metadata_event"]
529-
assert self._conversation_id is None
530-
self._conversation_id = event["conversation_id"]
531-
532-
elif message["type"] == "audio":
533-
event = message["audio_event"]
534-
if int(event["event_id"]) <= self._last_interrupt_id:
535-
return
536-
audio = base64.b64decode(event["audio_base_64"])
537-
self.audio_interface.output(audio)
538-
elif message["type"] == "agent_response":
539-
if self.callback_agent_response:
540-
event = message["agent_response_event"]
541-
self.callback_agent_response(event["agent_response"].strip())
542-
elif message["type"] == "agent_response_correction":
543-
if self.callback_agent_response_correction:
544-
event = message["agent_response_correction_event"]
545-
self.callback_agent_response_correction(
546-
event["original_agent_response"].strip(), event["corrected_agent_response"].strip()
547-
)
548-
elif message["type"] == "user_transcript":
549-
if self.callback_user_transcript:
550-
event = message["user_transcription_event"]
551-
self.callback_user_transcript(event["user_transcript"].strip())
552-
elif message["type"] == "interruption":
553-
event = message["interruption_event"]
554-
self._last_interrupt_id = int(event["event_id"])
555-
self.audio_interface.interrupt()
556-
elif message["type"] == "ping":
557-
event = message["ping_event"]
558-
ws.send(
559-
json.dumps(
560-
{
561-
"type": "pong",
562-
"event_id": event["event_id"],
563-
}
634+
class SyncMessageHandler:
635+
def __init__(self, conversation, ws):
636+
self.conversation = conversation
637+
self.ws = ws
638+
self.callback_agent_response = conversation.callback_agent_response
639+
self.callback_agent_response_correction = conversation.callback_agent_response_correction
640+
self.callback_user_transcript = conversation.callback_user_transcript
641+
self.callback_latency_measurement = conversation.callback_latency_measurement
642+
643+
def handle_audio_output(self, audio):
644+
self.conversation.audio_interface.output(audio)
645+
646+
def handle_agent_response(self, response):
647+
self.conversation.callback_agent_response(response)
648+
649+
def handle_agent_response_correction(self, original, corrected):
650+
self.conversation.callback_agent_response_correction(original, corrected)
651+
652+
def handle_user_transcript(self, transcript):
653+
self.conversation.callback_user_transcript(transcript)
654+
655+
def handle_interruption(self):
656+
self.conversation.audio_interface.interrupt()
657+
658+
def handle_ping(self, event):
659+
self.ws.send(
660+
json.dumps(
661+
{
662+
"type": "pong",
663+
"event_id": event["event_id"],
664+
}
665+
)
564666
)
565-
)
566-
if self.callback_latency_measurement and event["ping_ms"]:
567-
self.callback_latency_measurement(int(event["ping_ms"]))
568-
elif message["type"] == "client_tool_call":
569-
tool_call = message.get("client_tool_call", {})
570-
tool_name = tool_call.get("tool_name")
571-
parameters = {"tool_call_id": tool_call["tool_call_id"], **tool_call.get("parameters", {})}
572-
573-
def send_response(response):
574-
if not self._should_stop.is_set():
575-
ws.send(json.dumps(response))
576-
577-
self.client_tools.execute_tool(tool_name, parameters, send_response)
578-
else:
579-
pass # Ignore all other message types.
667+
668+
def handle_latency_measurement(self, latency):
669+
self.conversation.callback_latency_measurement(latency)
670+
671+
def handle_client_tool_call(self, tool_name, parameters):
672+
def send_response(response):
673+
if not self.conversation._should_stop.is_set():
674+
self.ws.send(json.dumps(response))
675+
676+
self.conversation.client_tools.execute_tool(tool_name, parameters, send_response)
677+
678+
handler = SyncMessageHandler(self, ws)
679+
self._handle_message_core(message, handler)
580680

581681

582682
class AsyncConversation(BaseConversation):
@@ -779,56 +879,51 @@ async def input_callback(audio):
779879
self._ws = None
780880

781881
async def _handle_message(self, message, ws):
782-
if message["type"] == "conversation_initiation_metadata":
783-
event = message["conversation_initiation_metadata_event"]
784-
assert self._conversation_id is None
785-
self._conversation_id = event["conversation_id"]
786-
787-
elif message["type"] == "audio":
788-
event = message["audio_event"]
789-
if int(event["event_id"]) <= self._last_interrupt_id:
790-
return
791-
audio = base64.b64decode(event["audio_base_64"])
792-
await self.audio_interface.output(audio)
793-
elif message["type"] == "agent_response":
794-
if self.callback_agent_response:
795-
event = message["agent_response_event"]
796-
await self.callback_agent_response(event["agent_response"].strip())
797-
elif message["type"] == "agent_response_correction":
798-
if self.callback_agent_response_correction:
799-
event = message["agent_response_correction_event"]
800-
await self.callback_agent_response_correction(
801-
event["original_agent_response"].strip(), event["corrected_agent_response"].strip()
802-
)
803-
elif message["type"] == "user_transcript":
804-
if self.callback_user_transcript:
805-
event = message["user_transcription_event"]
806-
await self.callback_user_transcript(event["user_transcript"].strip())
807-
elif message["type"] == "interruption":
808-
event = message["interruption_event"]
809-
self._last_interrupt_id = int(event["event_id"])
810-
await self.audio_interface.interrupt()
811-
elif message["type"] == "ping":
812-
event = message["ping_event"]
813-
await ws.send(
814-
json.dumps(
815-
{
816-
"type": "pong",
817-
"event_id": event["event_id"],
818-
}
882+
class AsyncMessageHandler:
883+
def __init__(self, conversation, ws):
884+
self.conversation = conversation
885+
self.ws = ws
886+
self.callback_agent_response = conversation.callback_agent_response
887+
self.callback_agent_response_correction = conversation.callback_agent_response_correction
888+
self.callback_user_transcript = conversation.callback_user_transcript
889+
self.callback_latency_measurement = conversation.callback_latency_measurement
890+
891+
async def handle_audio_output(self, audio):
892+
await self.conversation.audio_interface.output(audio)
893+
894+
async def handle_agent_response(self, response):
895+
await self.conversation.callback_agent_response(response)
896+
897+
async def handle_agent_response_correction(self, original, corrected):
898+
await self.conversation.callback_agent_response_correction(original, corrected)
899+
900+
async def handle_user_transcript(self, transcript):
901+
await self.conversation.callback_user_transcript(transcript)
902+
903+
async def handle_interruption(self):
904+
await self.conversation.audio_interface.interrupt()
905+
906+
async def handle_ping(self, event):
907+
await self.ws.send(
908+
json.dumps(
909+
{
910+
"type": "pong",
911+
"event_id": event["event_id"],
912+
}
913+
)
819914
)
820-
)
821-
if self.callback_latency_measurement and event["ping_ms"]:
822-
await self.callback_latency_measurement(int(event["ping_ms"]))
823-
elif message["type"] == "client_tool_call":
824-
tool_call = message.get("client_tool_call", {})
825-
tool_name = tool_call.get("tool_name")
826-
parameters = {"tool_call_id": tool_call["tool_call_id"], **tool_call.get("parameters", {})}
827-
828-
def send_response(response):
829-
if not self._should_stop.is_set():
830-
asyncio.create_task(ws.send(json.dumps(response)))
831-
832-
self.client_tools.execute_tool(tool_name, parameters, send_response)
833-
else:
834-
pass # Ignore all other message types.
915+
916+
async def handle_latency_measurement(self, latency):
917+
await self.conversation.callback_latency_measurement(latency)
918+
919+
def handle_client_tool_call(self, tool_name, parameters):
920+
def send_response(response):
921+
if not self.conversation._should_stop.is_set():
922+
asyncio.create_task(self.ws.send(json.dumps(response)))
923+
924+
self.conversation.client_tools.execute_tool(tool_name, parameters, send_response)
925+
926+
handler = AsyncMessageHandler(self, ws)
927+
928+
# Use the shared core message handling logic with async wrapper
929+
await self._handle_message_core_async(message, handler)

0 commit comments

Comments
 (0)