Skip to content

Commit 73d9ee4

Browse files
committed
address feedback
1 parent db21164 commit 73d9ee4

2 files changed

Lines changed: 61 additions & 55 deletions

File tree

README.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,22 +198,22 @@ from elevenlabs.conversational_ai.conversation import ClientTools
198198
async def main():
199199
# Get the current event loop
200200
custom_loop = asyncio.get_running_loop()
201-
201+
202202
# Create ClientTools with custom loop to prevent "different event loop" errors
203203
client_tools = ClientTools(loop=custom_loop)
204-
204+
205205
# Register your tools
206206
async def get_weather(params):
207207
location = params.get("location", "Unknown")
208208
# Your async logic here
209209
return f"Weather in {location}: Sunny, 72°F"
210-
210+
211211
client_tools.register("get_weather", get_weather, is_async=True)
212-
212+
213213
# Use with conversation
214214
conversation = Conversation(
215215
client=client,
216-
agent_id="your-agent-id",
216+
agent_id="your-agent-id",
217217
requires_auth=True,
218218
audio_interface=audio_interface,
219219
client_tools=client_tools
@@ -228,6 +228,9 @@ asyncio.run(main())
228228
- **Loop Management**: Prevent "Task got Future attached to a different event loop" errors
229229
- **Performance**: Better control over async task scheduling and execution
230230

231+
**Important:** When using a custom loop, you're responsible for its lifecycle
232+
Don't close the loop while ClientTools are still using it.
233+
231234
### Tool Registration
232235

233236
Register custom tools that the AI agent can call during conversations:
@@ -240,7 +243,7 @@ def calculate_sum(params):
240243
numbers = params.get("numbers", [])
241244
return sum(numbers)
242245

243-
# Async tool
246+
# Async tool
244247
async def fetch_data(params):
245248
url = params.get("url")
246249
# Your async HTTP request logic

src/elevenlabs/conversational_ai/conversation.py

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class ClientTools:
155155
156156
Supports both synchronous and asynchronous tools running in a dedicated event loop,
157157
ensuring non-blocking operation of the main conversation thread.
158-
158+
159159
Args:
160160
loop: Optional custom asyncio event loop to use for tool execution. If not provided,
161161
a new event loop will be created and run in a separate thread. Using a custom
@@ -276,12 +276,15 @@ async def _execute_and_callback():
276276
}
277277
callback(response)
278278

279+
self._schedule_coroutine(_execute_and_callback())
280+
281+
282+
def _schedule_coroutine(self, coro):
283+
"""Schedule a coroutine on the appropriate event loop."""
279284
if self._custom_loop is not None:
280-
# For custom loops, schedule the task on the custom loop
281-
self._loop.create_task(_execute_and_callback())
285+
return self._loop.create_task(coro)
282286
else:
283-
# For our own loop running in a separate thread, use run_coroutine_threadsafe
284-
asyncio.run_coroutine_threadsafe(_execute_and_callback(), self._loop)
287+
return asyncio.run_coroutine_threadsafe(coro, self._loop)
285288

286289

287290
class ConversationInitiationData:
@@ -302,7 +305,7 @@ def __init__(
302305

303306
class BaseConversation:
304307
"""Base class for conversation implementations with shared parameters and logic."""
305-
308+
306309
def __init__(
307310
self,
308311
client: BaseElevenLabs,
@@ -319,9 +322,9 @@ def __init__(
319322
self.requires_auth = requires_auth
320323
self.config = config or ConversationInitiationData()
321324
self.client_tools = client_tools or ClientTools()
322-
325+
323326
self.client_tools.start()
324-
327+
325328
self._conversation_id = None
326329
self._last_interrupt_id = 0
327330

@@ -353,7 +356,7 @@ def _create_initiation_message(self):
353356

354357
def _handle_message_core(self, message, message_handler):
355358
"""Core message handling logic shared between sync and async implementations.
356-
359+
357360
Args:
358361
message: The parsed message dictionary
359362
message_handler: Handler object with methods for different operations
@@ -369,36 +372,36 @@ def _handle_message_core(self, message, message_handler):
369372
return
370373
audio = base64.b64decode(event["audio_base_64"])
371374
message_handler.handle_audio_output(audio)
372-
375+
373376
elif message["type"] == "agent_response":
374377
if message_handler.callback_agent_response:
375378
event = message["agent_response_event"]
376379
message_handler.handle_agent_response(event["agent_response"].strip())
377-
380+
378381
elif message["type"] == "agent_response_correction":
379382
if message_handler.callback_agent_response_correction:
380383
event = message["agent_response_correction_event"]
381384
message_handler.handle_agent_response_correction(
382-
event["original_agent_response"].strip(),
385+
event["original_agent_response"].strip(),
383386
event["corrected_agent_response"].strip()
384387
)
385-
388+
386389
elif message["type"] == "user_transcript":
387390
if message_handler.callback_user_transcript:
388391
event = message["user_transcription_event"]
389392
message_handler.handle_user_transcript(event["user_transcript"].strip())
390-
393+
391394
elif message["type"] == "interruption":
392395
event = message["interruption_event"]
393396
self._last_interrupt_id = int(event["event_id"])
394397
message_handler.handle_interruption()
395-
398+
396399
elif message["type"] == "ping":
397400
event = message["ping_event"]
398401
message_handler.handle_ping(event)
399402
if message_handler.callback_latency_measurement and event["ping_ms"]:
400403
message_handler.handle_latency_measurement(int(event["ping_ms"]))
401-
404+
402405
elif message["type"] == "client_tool_call":
403406
tool_call = message.get("client_tool_call", {})
404407
tool_name = tool_call.get("tool_name")
@@ -420,36 +423,36 @@ async def _handle_message_core_async(self, message, message_handler):
420423
return
421424
audio = base64.b64decode(event["audio_base_64"])
422425
await message_handler.handle_audio_output(audio)
423-
426+
424427
elif message["type"] == "agent_response":
425428
if message_handler.callback_agent_response:
426429
event = message["agent_response_event"]
427430
await message_handler.handle_agent_response(event["agent_response"].strip())
428-
431+
429432
elif message["type"] == "agent_response_correction":
430433
if message_handler.callback_agent_response_correction:
431434
event = message["agent_response_correction_event"]
432435
await message_handler.handle_agent_response_correction(
433-
event["original_agent_response"].strip(),
436+
event["original_agent_response"].strip(),
434437
event["corrected_agent_response"].strip()
435438
)
436-
439+
437440
elif message["type"] == "user_transcript":
438441
if message_handler.callback_user_transcript:
439442
event = message["user_transcription_event"]
440443
await message_handler.handle_user_transcript(event["user_transcript"].strip())
441-
444+
442445
elif message["type"] == "interruption":
443446
event = message["interruption_event"]
444447
self._last_interrupt_id = int(event["event_id"])
445448
await message_handler.handle_interruption()
446-
449+
447450
elif message["type"] == "ping":
448451
event = message["ping_event"]
449452
await message_handler.handle_ping(event)
450453
if message_handler.callback_latency_measurement and event["ping_ms"]:
451454
await message_handler.handle_latency_measurement(int(event["ping_ms"]))
452-
455+
453456
elif message["type"] == "client_tool_call":
454457
tool_call = message.get("client_tool_call", {})
455458
tool_name = tool_call.get("tool_name")
@@ -514,7 +517,7 @@ def __init__(
514517
config=config,
515518
client_tools=client_tools,
516519
)
517-
520+
518521
self.audio_interface = audio_interface
519522
self.callback_agent_response = callback_agent_response
520523
self.callback_agent_response_correction = callback_agent_response_correction
@@ -663,22 +666,22 @@ def __init__(self, conversation, ws):
663666
self.callback_agent_response_correction = conversation.callback_agent_response_correction
664667
self.callback_user_transcript = conversation.callback_user_transcript
665668
self.callback_latency_measurement = conversation.callback_latency_measurement
666-
669+
667670
def handle_audio_output(self, audio):
668671
self.conversation.audio_interface.output(audio)
669-
672+
670673
def handle_agent_response(self, response):
671674
self.conversation.callback_agent_response(response)
672-
675+
673676
def handle_agent_response_correction(self, original, corrected):
674677
self.conversation.callback_agent_response_correction(original, corrected)
675-
678+
676679
def handle_user_transcript(self, transcript):
677680
self.conversation.callback_user_transcript(transcript)
678-
681+
679682
def handle_interruption(self):
680683
self.conversation.audio_interface.interrupt()
681-
684+
682685
def handle_ping(self, event):
683686
self.ws.send(
684687
json.dumps(
@@ -688,17 +691,17 @@ def handle_ping(self, event):
688691
}
689692
)
690693
)
691-
694+
692695
def handle_latency_measurement(self, latency):
693696
self.conversation.callback_latency_measurement(latency)
694-
697+
695698
def handle_client_tool_call(self, tool_name, parameters):
696699
def send_response(response):
697700
if not self.conversation._should_stop.is_set():
698701
self.ws.send(json.dumps(response))
699-
702+
700703
self.conversation.client_tools.execute_tool(tool_name, parameters, send_response)
701-
704+
702705
handler = SyncMessageHandler(self, ws)
703706
self._handle_message_core(message, handler)
704707

@@ -759,7 +762,7 @@ def __init__(
759762
config=config,
760763
client_tools=client_tools,
761764
)
762-
765+
763766
self.audio_interface = audio_interface
764767
self.callback_agent_response = callback_agent_response
765768
self.callback_agent_response_correction = callback_agent_response_correction
@@ -777,7 +780,7 @@ async def start_session(self):
777780
Will run in background task until `end_session` is called.
778781
"""
779782
ws_url = self._get_signed_url() if self.requires_auth else self._get_wss_url()
780-
self._task = asyncio.create_task(self._run(ws_url))
783+
self._task = self._schedule_coroutine(self._run(ws_url))
781784

782785
async def end_session(self):
783786
"""Ends the conversation session and cleans up resources."""
@@ -881,7 +884,7 @@ async def input_callback(audio):
881884
await self.end_session()
882885

883886
await self.audio_interface.start(input_callback)
884-
887+
885888
try:
886889
while not self._should_stop.is_set():
887890
try:
@@ -911,22 +914,22 @@ def __init__(self, conversation, ws):
911914
self.callback_agent_response_correction = conversation.callback_agent_response_correction
912915
self.callback_user_transcript = conversation.callback_user_transcript
913916
self.callback_latency_measurement = conversation.callback_latency_measurement
914-
917+
915918
async def handle_audio_output(self, audio):
916919
await self.conversation.audio_interface.output(audio)
917-
920+
918921
async def handle_agent_response(self, response):
919922
await self.conversation.callback_agent_response(response)
920-
923+
921924
async def handle_agent_response_correction(self, original, corrected):
922925
await self.conversation.callback_agent_response_correction(original, corrected)
923-
926+
924927
async def handle_user_transcript(self, transcript):
925928
await self.conversation.callback_user_transcript(transcript)
926-
929+
927930
async def handle_interruption(self):
928931
await self.conversation.audio_interface.interrupt()
929-
932+
930933
async def handle_ping(self, event):
931934
await self.ws.send(
932935
json.dumps(
@@ -936,18 +939,18 @@ async def handle_ping(self, event):
936939
}
937940
)
938941
)
939-
942+
940943
async def handle_latency_measurement(self, latency):
941944
await self.conversation.callback_latency_measurement(latency)
942-
945+
943946
def handle_client_tool_call(self, tool_name, parameters):
944947
def send_response(response):
945948
if not self.conversation._should_stop.is_set():
946-
asyncio.create_task(self.ws.send(json.dumps(response)))
947-
949+
self.conversation._schedule_coroutine(self.ws.send(json.dumps(response)))
950+
948951
self.conversation.client_tools.execute_tool(tool_name, parameters, send_response)
949-
952+
950953
handler = AsyncMessageHandler(self, ws)
951-
954+
952955
# Use the shared core message handling logic with async wrapper
953956
await self._handle_message_core_async(message, handler)

0 commit comments

Comments
 (0)