@@ -327,6 +327,113 @@ def _create_initiation_message(self):
327327 }
328328 )
329329
330+ def _handle_message_core (self , message , message_handler ):
331+ """Core message handling logic shared between sync and async implementations.
332+
333+ Args:
334+ message: The parsed message dictionary
335+ message_handler: Handler object with methods for different operations
336+ """
337+ if message ["type" ] == "conversation_initiation_metadata" :
338+ event = message ["conversation_initiation_metadata_event" ]
339+ assert self ._conversation_id is None
340+ self ._conversation_id = event ["conversation_id" ]
341+
342+ elif message ["type" ] == "audio" :
343+ event = message ["audio_event" ]
344+ if int (event ["event_id" ]) <= self ._last_interrupt_id :
345+ return
346+ audio = base64 .b64decode (event ["audio_base_64" ])
347+ message_handler .handle_audio_output (audio )
348+
349+ elif message ["type" ] == "agent_response" :
350+ if message_handler .callback_agent_response :
351+ event = message ["agent_response_event" ]
352+ message_handler .handle_agent_response (event ["agent_response" ].strip ())
353+
354+ elif message ["type" ] == "agent_response_correction" :
355+ if message_handler .callback_agent_response_correction :
356+ event = message ["agent_response_correction_event" ]
357+ message_handler .handle_agent_response_correction (
358+ event ["original_agent_response" ].strip (),
359+ event ["corrected_agent_response" ].strip ()
360+ )
361+
362+ elif message ["type" ] == "user_transcript" :
363+ if message_handler .callback_user_transcript :
364+ event = message ["user_transcription_event" ]
365+ message_handler .handle_user_transcript (event ["user_transcript" ].strip ())
366+
367+ elif message ["type" ] == "interruption" :
368+ event = message ["interruption_event" ]
369+ self ._last_interrupt_id = int (event ["event_id" ])
370+ message_handler .handle_interruption ()
371+
372+ elif message ["type" ] == "ping" :
373+ event = message ["ping_event" ]
374+ message_handler .handle_ping (event )
375+ if message_handler .callback_latency_measurement and event ["ping_ms" ]:
376+ message_handler .handle_latency_measurement (int (event ["ping_ms" ]))
377+
378+ elif message ["type" ] == "client_tool_call" :
379+ tool_call = message .get ("client_tool_call" , {})
380+ tool_name = tool_call .get ("tool_name" )
381+ parameters = {"tool_call_id" : tool_call ["tool_call_id" ], ** tool_call .get ("parameters" , {})}
382+ message_handler .handle_client_tool_call (tool_name , parameters )
383+ else :
384+ pass # Ignore all other message types.
385+
386+ async def _handle_message_core_async (self , message , message_handler ):
387+ """Async wrapper for core message handling logic."""
388+ if message ["type" ] == "conversation_initiation_metadata" :
389+ event = message ["conversation_initiation_metadata_event" ]
390+ assert self ._conversation_id is None
391+ self ._conversation_id = event ["conversation_id" ]
392+
393+ elif message ["type" ] == "audio" :
394+ event = message ["audio_event" ]
395+ if int (event ["event_id" ]) <= self ._last_interrupt_id :
396+ return
397+ audio = base64 .b64decode (event ["audio_base_64" ])
398+ await message_handler .handle_audio_output (audio )
399+
400+ elif message ["type" ] == "agent_response" :
401+ if message_handler .callback_agent_response :
402+ event = message ["agent_response_event" ]
403+ await message_handler .handle_agent_response (event ["agent_response" ].strip ())
404+
405+ elif message ["type" ] == "agent_response_correction" :
406+ if message_handler .callback_agent_response_correction :
407+ event = message ["agent_response_correction_event" ]
408+ await message_handler .handle_agent_response_correction (
409+ event ["original_agent_response" ].strip (),
410+ event ["corrected_agent_response" ].strip ()
411+ )
412+
413+ elif message ["type" ] == "user_transcript" :
414+ if message_handler .callback_user_transcript :
415+ event = message ["user_transcription_event" ]
416+ await message_handler .handle_user_transcript (event ["user_transcript" ].strip ())
417+
418+ elif message ["type" ] == "interruption" :
419+ event = message ["interruption_event" ]
420+ self ._last_interrupt_id = int (event ["event_id" ])
421+ await message_handler .handle_interruption ()
422+
423+ elif message ["type" ] == "ping" :
424+ event = message ["ping_event" ]
425+ await message_handler .handle_ping (event )
426+ if message_handler .callback_latency_measurement and event ["ping_ms" ]:
427+ await message_handler .handle_latency_measurement (int (event ["ping_ms" ]))
428+
429+ elif message ["type" ] == "client_tool_call" :
430+ tool_call = message .get ("client_tool_call" , {})
431+ tool_name = tool_call .get ("tool_name" )
432+ parameters = {"tool_call_id" : tool_call ["tool_call_id" ], ** tool_call .get ("parameters" , {})}
433+ message_handler .handle_client_tool_call (tool_name , parameters )
434+ else :
435+ pass # Ignore all other message types.
436+
330437
331438class Conversation (BaseConversation ):
332439 audio_interface : AudioInterface
@@ -524,59 +631,52 @@ def input_callback(audio):
524631 self ._ws = None
525632
526633 def _handle_message (self , message , ws ):
527- if message ["type" ] == "conversation_initiation_metadata" :
528- event = message ["conversation_initiation_metadata_event" ]
529- assert self ._conversation_id is None
530- self ._conversation_id = event ["conversation_id" ]
531-
532- elif message ["type" ] == "audio" :
533- event = message ["audio_event" ]
534- if int (event ["event_id" ]) <= self ._last_interrupt_id :
535- return
536- audio = base64 .b64decode (event ["audio_base_64" ])
537- self .audio_interface .output (audio )
538- elif message ["type" ] == "agent_response" :
539- if self .callback_agent_response :
540- event = message ["agent_response_event" ]
541- self .callback_agent_response (event ["agent_response" ].strip ())
542- elif message ["type" ] == "agent_response_correction" :
543- if self .callback_agent_response_correction :
544- event = message ["agent_response_correction_event" ]
545- self .callback_agent_response_correction (
546- event ["original_agent_response" ].strip (), event ["corrected_agent_response" ].strip ()
547- )
548- elif message ["type" ] == "user_transcript" :
549- if self .callback_user_transcript :
550- event = message ["user_transcription_event" ]
551- self .callback_user_transcript (event ["user_transcript" ].strip ())
552- elif message ["type" ] == "interruption" :
553- event = message ["interruption_event" ]
554- self ._last_interrupt_id = int (event ["event_id" ])
555- self .audio_interface .interrupt ()
556- elif message ["type" ] == "ping" :
557- event = message ["ping_event" ]
558- ws .send (
559- json .dumps (
560- {
561- "type" : "pong" ,
562- "event_id" : event ["event_id" ],
563- }
634+ class SyncMessageHandler :
635+ def __init__ (self , conversation , ws ):
636+ self .conversation = conversation
637+ self .ws = ws
638+ self .callback_agent_response = conversation .callback_agent_response
639+ self .callback_agent_response_correction = conversation .callback_agent_response_correction
640+ self .callback_user_transcript = conversation .callback_user_transcript
641+ self .callback_latency_measurement = conversation .callback_latency_measurement
642+
643+ def handle_audio_output (self , audio ):
644+ self .conversation .audio_interface .output (audio )
645+
646+ def handle_agent_response (self , response ):
647+ self .conversation .callback_agent_response (response )
648+
649+ def handle_agent_response_correction (self , original , corrected ):
650+ self .conversation .callback_agent_response_correction (original , corrected )
651+
652+ def handle_user_transcript (self , transcript ):
653+ self .conversation .callback_user_transcript (transcript )
654+
655+ def handle_interruption (self ):
656+ self .conversation .audio_interface .interrupt ()
657+
658+ def handle_ping (self , event ):
659+ self .ws .send (
660+ json .dumps (
661+ {
662+ "type" : "pong" ,
663+ "event_id" : event ["event_id" ],
664+ }
665+ )
564666 )
565- )
566- if self .callback_latency_measurement and event ["ping_ms" ]:
567- self .callback_latency_measurement (int (event ["ping_ms" ]))
568- elif message ["type" ] == "client_tool_call" :
569- tool_call = message .get ("client_tool_call" , {})
570- tool_name = tool_call .get ("tool_name" )
571- parameters = {"tool_call_id" : tool_call ["tool_call_id" ], ** tool_call .get ("parameters" , {})}
572-
573- def send_response (response ):
574- if not self ._should_stop .is_set ():
575- ws .send (json .dumps (response ))
576-
577- self .client_tools .execute_tool (tool_name , parameters , send_response )
578- else :
579- pass # Ignore all other message types.
667+
668+ def handle_latency_measurement (self , latency ):
669+ self .conversation .callback_latency_measurement (latency )
670+
671+ def handle_client_tool_call (self , tool_name , parameters ):
672+ def send_response (response ):
673+ if not self .conversation ._should_stop .is_set ():
674+ self .ws .send (json .dumps (response ))
675+
676+ self .conversation .client_tools .execute_tool (tool_name , parameters , send_response )
677+
678+ handler = SyncMessageHandler (self , ws )
679+ self ._handle_message_core (message , handler )
580680
581681
582682class AsyncConversation (BaseConversation ):
@@ -779,56 +879,51 @@ async def input_callback(audio):
779879 self ._ws = None
780880
781881 async def _handle_message (self , message , ws ):
782- if message ["type" ] == "conversation_initiation_metadata" :
783- event = message ["conversation_initiation_metadata_event" ]
784- assert self ._conversation_id is None
785- self ._conversation_id = event ["conversation_id" ]
786-
787- elif message ["type" ] == "audio" :
788- event = message ["audio_event" ]
789- if int (event ["event_id" ]) <= self ._last_interrupt_id :
790- return
791- audio = base64 .b64decode (event ["audio_base_64" ])
792- await self .audio_interface .output (audio )
793- elif message ["type" ] == "agent_response" :
794- if self .callback_agent_response :
795- event = message ["agent_response_event" ]
796- await self .callback_agent_response (event ["agent_response" ].strip ())
797- elif message ["type" ] == "agent_response_correction" :
798- if self .callback_agent_response_correction :
799- event = message ["agent_response_correction_event" ]
800- await self .callback_agent_response_correction (
801- event ["original_agent_response" ].strip (), event ["corrected_agent_response" ].strip ()
802- )
803- elif message ["type" ] == "user_transcript" :
804- if self .callback_user_transcript :
805- event = message ["user_transcription_event" ]
806- await self .callback_user_transcript (event ["user_transcript" ].strip ())
807- elif message ["type" ] == "interruption" :
808- event = message ["interruption_event" ]
809- self ._last_interrupt_id = int (event ["event_id" ])
810- await self .audio_interface .interrupt ()
811- elif message ["type" ] == "ping" :
812- event = message ["ping_event" ]
813- await ws .send (
814- json .dumps (
815- {
816- "type" : "pong" ,
817- "event_id" : event ["event_id" ],
818- }
882+ class AsyncMessageHandler :
883+ def __init__ (self , conversation , ws ):
884+ self .conversation = conversation
885+ self .ws = ws
886+ self .callback_agent_response = conversation .callback_agent_response
887+ self .callback_agent_response_correction = conversation .callback_agent_response_correction
888+ self .callback_user_transcript = conversation .callback_user_transcript
889+ self .callback_latency_measurement = conversation .callback_latency_measurement
890+
891+ async def handle_audio_output (self , audio ):
892+ await self .conversation .audio_interface .output (audio )
893+
894+ async def handle_agent_response (self , response ):
895+ await self .conversation .callback_agent_response (response )
896+
897+ async def handle_agent_response_correction (self , original , corrected ):
898+ await self .conversation .callback_agent_response_correction (original , corrected )
899+
900+ async def handle_user_transcript (self , transcript ):
901+ await self .conversation .callback_user_transcript (transcript )
902+
903+ async def handle_interruption (self ):
904+ await self .conversation .audio_interface .interrupt ()
905+
906+ async def handle_ping (self , event ):
907+ await self .ws .send (
908+ json .dumps (
909+ {
910+ "type" : "pong" ,
911+ "event_id" : event ["event_id" ],
912+ }
913+ )
819914 )
820- )
821- if self . callback_latency_measurement and event [ "ping_ms" ] :
822- await self .callback_latency_measurement (int ( event [ "ping_ms" ]) )
823- elif message [ "type" ] == "client_tool_call" :
824- tool_call = message . get ( "client_tool_call" , {})
825- tool_name = tool_call . get ( "tool_name" )
826- parameters = { "tool_call_id" : tool_call [ "tool_call_id" ], ** tool_call . get ( "parameters" , {})}
827-
828- def send_response ( response ):
829- if not self ._should_stop . is_set ():
830- asyncio . create_task ( ws . send ( json . dumps ( response )))
831-
832- self . client_tools . execute_tool ( tool_name , parameters , send_response )
833- else :
834- pass # Ignore all other message types.
915+
916+ async def handle_latency_measurement ( self , latency ) :
917+ await self .conversation . callback_latency_measurement (latency )
918+
919+ def handle_client_tool_call ( self , tool_name , parameters ):
920+ def send_response ( response ):
921+ if not self . conversation . _should_stop . is_set ():
922+ asyncio . create_task ( self . ws . send ( json . dumps ( response )))
923+
924+ self .conversation . client_tools . execute_tool ( tool_name , parameters , send_response )
925+
926+ handler = AsyncMessageHandler ( self , ws )
927+
928+ # Use the shared core message handling logic with async wrapper
929+ await self . _handle_message_core_async ( message , handler )
0 commit comments