Skip to content

Commit e1dc0f5

Browse files
committed
angelo fix tests
1 parent 90246f1 commit e1dc0f5

1 file changed

Lines changed: 58 additions & 12 deletions

File tree

src/elevenlabs/conversational_ai/conversation.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -275,39 +275,80 @@ def __init__(
275275
self.user_id = user_id
276276

277277

278-
class MessageHandler(Protocol):
279-
"""Protocol defining the interface for message handlers."""
278+
class SyncMessageHandler(Protocol):
279+
"""Protocol defining the interface for sync message handlers."""
280280

281281
callback_agent_response: Optional[Callable]
282282
callback_agent_response_correction: Optional[Callable]
283283
callback_user_transcript: Optional[Callable]
284284
callback_latency_measurement: Optional[Callable]
285285

286-
def handle_audio_output(self, audio: bytes) -> Union[None, Awaitable[None]]:
286+
def handle_audio_output(self, audio: bytes) -> None:
287287
"""Handle audio output."""
288288
...
289289

290-
def handle_agent_response(self, response: str) -> Union[None, Awaitable[None]]:
290+
def handle_agent_response(self, response: str) -> None:
291291
"""Handle agent response."""
292292
...
293293

294-
def handle_agent_response_correction(self, original: str, corrected: str) -> Union[None, Awaitable[None]]:
294+
def handle_agent_response_correction(self, original: str, corrected: str) -> None:
295295
"""Handle agent response correction."""
296296
...
297297

298-
def handle_user_transcript(self, transcript: str) -> Union[None, Awaitable[None]]:
298+
def handle_user_transcript(self, transcript: str) -> None:
299299
"""Handle user transcript."""
300300
...
301301

302-
def handle_interruption(self) -> Union[None, Awaitable[None]]:
302+
def handle_interruption(self) -> None:
303303
"""Handle interruption."""
304304
...
305305

306-
def handle_ping(self, event: Dict[str, Any]) -> Union[None, Awaitable[None]]:
306+
def handle_ping(self, event: Dict[str, Any]) -> None:
307307
"""Handle ping event."""
308308
...
309309

310-
def handle_latency_measurement(self, latency: int) -> Union[None, Awaitable[None]]:
310+
def handle_latency_measurement(self, latency: int) -> None:
311+
"""Handle latency measurement."""
312+
...
313+
314+
def handle_client_tool_call(self, tool_name: str, parameters: Dict[str, Any]) -> None:
315+
"""Handle client tool call."""
316+
...
317+
318+
319+
class AsyncMessageHandler(Protocol):
320+
"""Protocol defining the interface for async message handlers."""
321+
322+
callback_agent_response: Optional[Callable]
323+
callback_agent_response_correction: Optional[Callable]
324+
callback_user_transcript: Optional[Callable]
325+
callback_latency_measurement: Optional[Callable]
326+
327+
async def handle_audio_output(self, audio: bytes) -> None:
328+
"""Handle audio output."""
329+
...
330+
331+
async def handle_agent_response(self, response: str) -> None:
332+
"""Handle agent response."""
333+
...
334+
335+
async def handle_agent_response_correction(self, original: str, corrected: str) -> None:
336+
"""Handle agent response correction."""
337+
...
338+
339+
async def handle_user_transcript(self, transcript: str) -> None:
340+
"""Handle user transcript."""
341+
...
342+
343+
async def handle_interruption(self) -> None:
344+
"""Handle interruption."""
345+
...
346+
347+
async def handle_ping(self, event: Dict[str, Any]) -> None:
348+
"""Handle ping event."""
349+
...
350+
351+
async def handle_latency_measurement(self, latency: int) -> None:
311352
"""Handle latency measurement."""
312353
...
313354

@@ -340,7 +381,7 @@ def _send_response(self, response: Dict[str, Any]) -> None:
340381
raise NotImplementedError
341382

342383

343-
class BaseConversation:
384+
class BaseConversation(ABC):
344385
"""Base class for conversation implementations with shared parameters and logic."""
345386

346387
def __init__(
@@ -397,6 +438,11 @@ async def callback(audio: bytes) -> None:
397438
await self.end_session()
398439
return callback
399440

441+
@abstractmethod
442+
def end_session(self) -> Union[None, Awaitable[None]]:
443+
"""End the conversation session - to be implemented by subclasses."""
444+
pass
445+
400446
def _handle_connection_closed(self) -> Union[None, Awaitable[None]]:
401447
"""Handle WebSocket connection closed - to be implemented by subclasses."""
402448
raise NotImplementedError
@@ -427,7 +473,7 @@ def _create_initiation_message(self):
427473
}
428474
)
429475

430-
def _handle_message_core(self, message: Dict[str, Any], message_handler: MessageHandler) -> None:
476+
def _handle_message_core(self, message: Dict[str, Any], message_handler: SyncMessageHandler) -> None:
431477
"""Core message handling logic shared between sync and async implementations.
432478
433479
Args:
@@ -483,7 +529,7 @@ def _handle_message_core(self, message: Dict[str, Any], message_handler: Message
483529
else:
484530
pass # Ignore all other message types.
485531

486-
async def _handle_message_core_async(self, message: Dict[str, Any], message_handler: MessageHandler) -> None:
532+
async def _handle_message_core_async(self, message: Dict[str, Any], message_handler: AsyncMessageHandler) -> None:
487533
"""Async wrapper for core message handling logic."""
488534
if message["type"] == "conversation_initiation_metadata":
489535
event = message["conversation_initiation_metadata_event"]

0 commit comments

Comments
 (0)