Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/elevenlabs/realtime/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import typing
from enum import Enum

if typing.TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection


class RealtimeEvents(str, Enum):
"""Events emitted by the RealtimeConnection"""
Expand Down Expand Up @@ -55,7 +58,7 @@ class RealtimeConnection:
```
"""

def __init__(self, websocket, current_sample_rate: int, ffmpeg_process: typing.Optional[subprocess.Popen] = None):
def __init__(self, websocket: "ClientConnection", current_sample_rate: int, ffmpeg_process: typing.Optional[subprocess.Popen] = None):
Comment thread
PaulAsjes marked this conversation as resolved.
self.websocket = websocket
self.current_sample_rate = current_sample_rate
self.ffmpeg_process = ffmpeg_process
Expand Down
48 changes: 31 additions & 17 deletions src/elevenlabs/realtime/scribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import subprocess
import typing
from enum import Enum
from typing import overload

from typing_extensions import Required

try:
import websockets
from websockets.asyncio.client import connect as websocket_connect
except ImportError:
raise ImportError(
"The websockets package is required for realtime speech-to-text. "
Expand All @@ -17,10 +20,13 @@

class AudioFormat(str, Enum):
Comment thread
PaulAsjes marked this conversation as resolved.
"""Audio format options for realtime transcription"""
PCM_8000 = "pcm_8000"
PCM_16000 = "pcm_16000"
PCM_22050 = "pcm_22050"
PCM_24000 = "pcm_24000"
PCM_44100 = "pcm_44100"
PCM_48000 = "pcm_48000"
ULAW_8000 = "ulaw_8000"


class CommitStrategy(str, Enum):
Expand Down Expand Up @@ -50,9 +56,9 @@ class RealtimeAudioOptions(typing.TypedDict, total=False):
language_code: An ISO-639-1 or ISO-639-3 language_code corresponding to the language of the audio file. Can sometimes improve transcription performance if known beforehand.
include_timestamps: Whether to receive the committed_transcript_with_timestamps event after committing the segment (optional, defaults to False)
"""
model_id: str
audio_format: AudioFormat
sample_rate: int
model_id: Required[str]
audio_format: Required[AudioFormat]
sample_rate: Required[int]
commit_strategy: CommitStrategy
vad_silence_threshold_secs: float
vad_threshold: float
Expand All @@ -77,8 +83,8 @@ class RealtimeUrlOptions(typing.TypedDict, total=False):
language_code: An ISO-639-1 or ISO-639-3 language_code corresponding to the language of the audio file. Can sometimes improve transcription performance if known beforehand.
include_timestamps: Whether to receive the committed_transcript_with_timestamps event after committing the segment (optional, defaults to False)
"""
model_id: str
url: str
model_id: Required[str]
url: Required[str]
commit_strategy: CommitStrategy
vad_silence_threshold_secs: float
vad_threshold: float
Expand Down Expand Up @@ -121,6 +127,18 @@ def __init__(self, api_key: str, base_url: str = "wss://api.elevenlabs.io"):
self.api_key = api_key
self.base_url = base_url

@overload
async def connect(
self,
options: RealtimeAudioOptions
) -> RealtimeConnection: ...

@overload
async def connect(
self,
options: RealtimeUrlOptions
) -> RealtimeConnection: ...

async def connect(
self,
options: typing.Union[RealtimeAudioOptions, RealtimeUrlOptions]
Expand Down Expand Up @@ -185,8 +203,7 @@ async def _connect_audio(self, options: RealtimeAudioOptions) -> RealtimeConnect
# Build WebSocket URL with query parameters
ws_url = self._build_websocket_url(
model_id=model_id,
encoding=audio_format.value,
sample_rate=sample_rate,
audio_format=audio_format.value,
commit_strategy=commit_strategy.value,
vad_silence_threshold_secs=vad_silence_threshold_secs,
vad_threshold=vad_threshold,
Expand All @@ -197,7 +214,7 @@ async def _connect_audio(self, options: RealtimeAudioOptions) -> RealtimeConnect
)

# Connect to WebSocket
websocket = await websockets.connect(
websocket = await websocket_connect(
ws_url,
additional_headers={"xi-api-key": self.api_key}
)
Expand Down Expand Up @@ -232,13 +249,12 @@ async def _connect_url(self, options: RealtimeUrlOptions) -> RealtimeConnection:

# Default to 16kHz for URL streaming
sample_rate = 16000
encoding = "pcm_16000"
audio_format = "pcm_16000"
Comment thread
PaulAsjes marked this conversation as resolved.
Outdated

# Build WebSocket URL
ws_url = self._build_websocket_url(
model_id=model_id,
encoding=encoding,
sample_rate=sample_rate,
audio_format=audio_format,
commit_strategy=commit_strategy.value,
vad_silence_threshold_secs=vad_silence_threshold_secs,
vad_threshold=vad_threshold,
Expand All @@ -249,7 +265,7 @@ async def _connect_url(self, options: RealtimeUrlOptions) -> RealtimeConnection:
)

# Connect to WebSocket
websocket = await websockets.connect(
websocket = await websocket_connect(
ws_url,
additional_headers={"xi-api-key": self.api_key}
)
Expand Down Expand Up @@ -341,8 +357,7 @@ async def _stream_ffmpeg_to_websocket(self, connection: RealtimeConnection) -> N
def _build_websocket_url(
self,
model_id: str,
encoding: str,
sample_rate: int,
audio_format: str,
commit_strategy: str,
vad_silence_threshold_secs: typing.Optional[float] = None,
vad_threshold: typing.Optional[float] = None,
Expand All @@ -358,8 +373,7 @@ def _build_websocket_url(
# Build query parameters
params = [
f"model_id={model_id}",
f"encoding={encoding}",
f"sample_rate={sample_rate}",
f"audio_format={audio_format}",
f"commit_strategy={commit_strategy}"
]

Expand Down