Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/elevenlabs/realtime/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import typing
from enum import Enum

if typing.TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection


class RealtimeEvents(str, Enum):
"""Events emitted by the RealtimeConnection"""
Expand Down Expand Up @@ -55,7 +58,7 @@ class RealtimeConnection:
```
"""

def __init__(self, websocket, current_sample_rate: int, ffmpeg_process: typing.Optional[subprocess.Popen] = None):
def __init__(self, websocket: "ClientConnection", current_sample_rate: int, ffmpeg_process: typing.Optional[subprocess.Popen] = None):
Comment thread
PaulAsjes marked this conversation as resolved.
self.websocket = websocket
self.current_sample_rate = current_sample_rate
self.ffmpeg_process = ffmpeg_process
Expand Down
48 changes: 31 additions & 17 deletions src/elevenlabs/realtime/scribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import subprocess
import typing
from enum import Enum
from typing import overload

from typing_extensions import Required

try:
import websockets
from websockets.asyncio.client import connect as websocket_connect
except ImportError:
raise ImportError(
"The websockets package is required for realtime speech-to-text. "
Expand All @@ -17,10 +20,13 @@

class AudioFormat(str, Enum):
    """Audio format options for realtime transcription.

    Each member is a ``str`` whose value is the identifier sent to the
    realtime endpoint (for example ``"pcm_16000"``).

    Note: when building URLs or other text, interpolate ``member.value``
    rather than the member itself. On Python 3.11+ ``format()`` and
    f-strings render a ``(str, Enum)`` member as ``AudioFormat.PCM_16000``
    instead of its value.
    """

    PCM_8000 = "pcm_8000"
    PCM_16000 = "pcm_16000"
    PCM_22050 = "pcm_22050"
    PCM_24000 = "pcm_24000"
    PCM_44100 = "pcm_44100"
    PCM_48000 = "pcm_48000"
    ULAW_8000 = "ulaw_8000"


class CommitStrategy(str, Enum):
Expand Down Expand Up @@ -50,9 +56,9 @@ class RealtimeAudioOptions(typing.TypedDict, total=False):
language_code: An ISO-639-1 or ISO-639-3 language_code corresponding to the language of the audio file. Can sometimes improve transcription performance if known beforehand.
include_timestamps: Whether to receive the committed_transcript_with_timestamps event after committing the segment (optional, defaults to False)
"""
model_id: str
audio_format: AudioFormat
sample_rate: int
model_id: Required[str]
audio_format: Required[AudioFormat]
sample_rate: Required[int]
commit_strategy: CommitStrategy
vad_silence_threshold_secs: float
vad_threshold: float
Expand All @@ -77,8 +83,8 @@ class RealtimeUrlOptions(typing.TypedDict, total=False):
language_code: An ISO-639-1 or ISO-639-3 language_code corresponding to the language of the audio file. Can sometimes improve transcription performance if known beforehand.
include_timestamps: Whether to receive the committed_transcript_with_timestamps event after committing the segment (optional, defaults to False)
"""
model_id: str
url: str
model_id: Required[str]
url: Required[str]
commit_strategy: CommitStrategy
vad_silence_threshold_secs: float
vad_threshold: float
Expand Down Expand Up @@ -121,6 +127,18 @@ def __init__(self, api_key: str, base_url: str = "wss://api.elevenlabs.io"):
self.api_key = api_key
self.base_url = base_url

# Typing-only overload: `connect` called with RealtimeAudioOptions.
# The stub has no body (`...`); the implementation is the un-decorated
# `connect` that follows the overloads.
@overload
async def connect(
self,
options: RealtimeAudioOptions
) -> RealtimeConnection: ...

# Typing-only overload: `connect` called with RealtimeUrlOptions.
# Both overloads declare the same return type, RealtimeConnection.
@overload
async def connect(
self,
options: RealtimeUrlOptions
) -> RealtimeConnection: ...

async def connect(
self,
options: typing.Union[RealtimeAudioOptions, RealtimeUrlOptions]
Expand Down Expand Up @@ -185,8 +203,7 @@ async def _connect_audio(self, options: RealtimeAudioOptions) -> RealtimeConnect
# Build WebSocket URL with query parameters
ws_url = self._build_websocket_url(
model_id=model_id,
encoding=audio_format.value,
sample_rate=sample_rate,
audio_format=audio_format.value,
commit_strategy=commit_strategy.value,
vad_silence_threshold_secs=vad_silence_threshold_secs,
vad_threshold=vad_threshold,
Expand All @@ -197,7 +214,7 @@ async def _connect_audio(self, options: RealtimeAudioOptions) -> RealtimeConnect
)

# Connect to WebSocket
websocket = await websockets.connect(
websocket = await websocket_connect(
ws_url,
additional_headers={"xi-api-key": self.api_key}
)
Expand Down Expand Up @@ -232,13 +249,12 @@ async def _connect_url(self, options: RealtimeUrlOptions) -> RealtimeConnection:

# Default to 16kHz for URL streaming
sample_rate = 16000
encoding = "pcm_16000"
audio_format = AudioFormat.PCM_16000

# Build WebSocket URL
ws_url = self._build_websocket_url(
model_id=model_id,
encoding=encoding,
sample_rate=sample_rate,
# Pass the enum's string value, not the member itself: the URL builder
# interpolates this argument with an f-string, and on Python 3.11+ a
# (str, Enum) member formats as "AudioFormat.PCM_16000", not "pcm_16000".
audio_format=audio_format.value,
commit_strategy=commit_strategy.value,
vad_silence_threshold_secs=vad_silence_threshold_secs,
vad_threshold=vad_threshold,
Expand All @@ -249,7 +265,7 @@ async def _connect_url(self, options: RealtimeUrlOptions) -> RealtimeConnection:
)

# Connect to WebSocket
websocket = await websockets.connect(
websocket = await websocket_connect(
ws_url,
additional_headers={"xi-api-key": self.api_key}
)
Expand Down Expand Up @@ -341,8 +357,7 @@ async def _stream_ffmpeg_to_websocket(self, connection: RealtimeConnection) -> N
def _build_websocket_url(
self,
model_id: str,
encoding: str,
sample_rate: int,
audio_format: str,
commit_strategy: str,
vad_silence_threshold_secs: typing.Optional[float] = None,
vad_threshold: typing.Optional[float] = None,
Expand All @@ -358,8 +373,7 @@ def _build_websocket_url(
# Build query parameters
params = [
f"model_id={model_id}",
f"encoding={encoding}",
f"sample_rate={sample_rate}",
f"audio_format={audio_format}",
f"commit_strategy={commit_strategy}"
]

Expand Down