diff --git a/src/elevenlabs/realtime/connection.py b/src/elevenlabs/realtime/connection.py index b54eb42a..ef2993d0 100644 --- a/src/elevenlabs/realtime/connection.py +++ b/src/elevenlabs/realtime/connection.py @@ -4,6 +4,9 @@ import typing from enum import Enum +if typing.TYPE_CHECKING: + from websockets.asyncio.client import ClientConnection + class RealtimeEvents(str, Enum): """Events emitted by the RealtimeConnection""" @@ -55,7 +58,7 @@ class RealtimeConnection: ``` """ - def __init__(self, websocket, current_sample_rate: int, ffmpeg_process: typing.Optional[subprocess.Popen] = None): + def __init__(self, websocket: "ClientConnection", current_sample_rate: int, ffmpeg_process: typing.Optional[subprocess.Popen] = None): self.websocket = websocket self.current_sample_rate = current_sample_rate self.ffmpeg_process = ffmpeg_process diff --git a/src/elevenlabs/realtime/scribe.py b/src/elevenlabs/realtime/scribe.py index 47dddab8..a509c994 100644 --- a/src/elevenlabs/realtime/scribe.py +++ b/src/elevenlabs/realtime/scribe.py @@ -3,9 +3,12 @@ import subprocess import typing from enum import Enum +from typing import overload + +from typing_extensions import Required try: - import websockets + from websockets.asyncio.client import connect as websocket_connect except ImportError: raise ImportError( "The websockets package is required for realtime speech-to-text. " @@ -17,10 +20,13 @@ class AudioFormat(str, Enum): """Audio format options for realtime transcription""" + PCM_8000 = "pcm_8000" PCM_16000 = "pcm_16000" PCM_22050 = "pcm_22050" PCM_24000 = "pcm_24000" PCM_44100 = "pcm_44100" + PCM_48000 = "pcm_48000" + ULAW_8000 = "ulaw_8000" class CommitStrategy(str, Enum): @@ -50,9 +56,9 @@ class RealtimeAudioOptions(typing.TypedDict, total=False): language_code: An ISO-639-1 or ISO-639-3 language_code corresponding to the language of the audio file. Can sometimes improve transcription performance if known beforehand. include_timestamps: Whether to receive the committed_transcript_with_timestamps event after committing the segment (optional, defaults to False) """ - model_id: str - audio_format: AudioFormat - sample_rate: int + model_id: Required[str] + audio_format: Required[AudioFormat] + sample_rate: Required[int] commit_strategy: CommitStrategy vad_silence_threshold_secs: float vad_threshold: float @@ -77,8 +83,8 @@ class RealtimeUrlOptions(typing.TypedDict, total=False): language_code: An ISO-639-1 or ISO-639-3 language_code corresponding to the language of the audio file. Can sometimes improve transcription performance if known beforehand. include_timestamps: Whether to receive the committed_transcript_with_timestamps event after committing the segment (optional, defaults to False) """ - model_id: str - url: str + model_id: Required[str] + url: Required[str] commit_strategy: CommitStrategy vad_silence_threshold_secs: float vad_threshold: float @@ -121,6 +127,18 @@ def __init__(self, api_key: str, base_url: str = "wss://api.elevenlabs.io"): self.api_key = api_key self.base_url = base_url + @overload + async def connect( + self, + options: RealtimeAudioOptions + ) -> RealtimeConnection: ... + + @overload + async def connect( + self, + options: RealtimeUrlOptions + ) -> RealtimeConnection: ... + async def connect( self, options: typing.Union[RealtimeAudioOptions, RealtimeUrlOptions] @@ -185,8 +203,7 @@ async def _connect_audio(self, options: RealtimeAudioOptions) -> RealtimeConnect # Build WebSocket URL with query parameters ws_url = self._build_websocket_url( model_id=model_id, - encoding=audio_format.value, - sample_rate=sample_rate, + audio_format=audio_format.value, commit_strategy=commit_strategy.value, vad_silence_threshold_secs=vad_silence_threshold_secs, vad_threshold=vad_threshold, @@ -197,7 +214,7 @@ async def _connect_audio(self, options: RealtimeAudioOptions) -> RealtimeConnect ) # Connect to WebSocket - websocket = await websockets.connect( + websocket = await websocket_connect( ws_url, additional_headers={"xi-api-key": self.api_key} ) @@ -232,13 +249,12 @@ async def _connect_url(self, options: RealtimeUrlOptions) -> RealtimeConnection: # Default to 16kHz for URL streaming sample_rate = 16000 - encoding = "pcm_16000" + audio_format = AudioFormat.PCM_16000 # Build WebSocket URL ws_url = self._build_websocket_url( model_id=model_id, - encoding=encoding, - sample_rate=sample_rate, + audio_format=audio_format, commit_strategy=commit_strategy.value, vad_silence_threshold_secs=vad_silence_threshold_secs, vad_threshold=vad_threshold, @@ -249,7 +265,7 @@ async def _connect_url(self, options: RealtimeUrlOptions) -> RealtimeConnection: ) # Connect to WebSocket - websocket = await websockets.connect( + websocket = await websocket_connect( ws_url, additional_headers={"xi-api-key": self.api_key} ) @@ -341,8 +357,7 @@ async def _stream_ffmpeg_to_websocket(self, connection: RealtimeConnection) -> N def _build_websocket_url( self, model_id: str, - encoding: str, - sample_rate: int, + audio_format: str, commit_strategy: str, vad_silence_threshold_secs: typing.Optional[float] = None, vad_threshold: typing.Optional[float] = None, @@ -358,8 +373,7 @@ def _build_websocket_url( # Build query parameters params = [ f"model_id={model_id}", - f"encoding={encoding}", - f"sample_rate={sample_rate}", + f"audio_format={audio_format}", f"commit_strategy={commit_strategy}" ]