diff --git a/src/elevenlabs/conversational_ai/conversation.py b/src/elevenlabs/conversational_ai/conversation.py index 8d39bc58..ac351b04 100644 --- a/src/elevenlabs/conversational_ai/conversation.py +++ b/src/elevenlabs/conversational_ai/conversation.py @@ -15,6 +15,7 @@ from websockets.sync.client import Connection, connect from ..base_client import BaseElevenLabs +from ..url_utils import build_ws_url from ..version import __version__ @@ -422,18 +423,14 @@ def _get_wss_url(self): return self.on_prem_config.on_prem_conversation_url base_http_url = self.client._client_wrapper.get_base_url() - base_ws_url = ( - urllib.parse.urlparse(base_http_url) - ._replace(scheme="wss" if base_http_url.startswith("https") else "ws") - .geturl() - ) - # Ensure base URL ends with '/' for proper joining - if not base_ws_url.endswith("/"): - base_ws_url += "/" - url = f"{base_ws_url}v1/convai/conversation?agent_id={self.agent_id}&source=python_sdk&version={__version__}" + params = [ + ("agent_id", self.agent_id), + ("source", "python_sdk"), + ("version", __version__), + ] if self.environment: - url += f"&environment={self.environment}" - return url + params.append(("environment", self.environment)) + return build_ws_url(base_http_url, ["v1", "convai", "conversation"], params) def _get_signed_url(self): response = self.client.conversational_ai.conversations.get_signed_url( @@ -442,8 +439,10 @@ def _get_signed_url(self): ) signed_url = response.signed_url # Append source and version query parameters to the signed URL - separator = "&" if "?" in signed_url else "?" - return f"{signed_url}{separator}source=python_sdk&version={__version__}" + parsed = urllib.parse.urlparse(signed_url) + existing_params = urllib.parse.parse_qsl(parsed.query, keep_blank_values=True) + existing_params.extend([("source", "python_sdk"), ("version", __version__)]) + return urllib.parse.urlunparse(parsed._replace(query=urllib.parse.urlencode(existing_params, quote_via=urllib.parse.quote))) def _create_on_prem_initiation_message(self): return json.dumps( diff --git a/src/elevenlabs/realtime/scribe.py b/src/elevenlabs/realtime/scribe.py index 61aaddbc..6409253d 100644 --- a/src/elevenlabs/realtime/scribe.py +++ b/src/elevenlabs/realtime/scribe.py @@ -15,6 +15,7 @@ "Install it with: pip install websockets" ) +from ..url_utils import build_ws_url from .connection import RealtimeConnection @@ -367,30 +368,22 @@ def _build_websocket_url( include_timestamps: typing.Optional[bool] = None ) -> str: """Build the WebSocket URL with query parameters""" - # Extract base domain - base = self.base_url.replace("https://", "wss://").replace("http://", "ws://") - - # Build query parameters params = [ - f"model_id={model_id}", - f"audio_format={audio_format}", - f"commit_strategy={commit_strategy}" + ("model_id", model_id), + ("audio_format", audio_format), + ("commit_strategy", commit_strategy), ] - - # Add optional VAD parameters - if vad_silence_threshold_secs is not None: - params.append(f"vad_silence_threshold_secs={vad_silence_threshold_secs}") - if vad_threshold is not None: - params.append(f"vad_threshold={vad_threshold}") - if min_speech_duration_ms is not None: - params.append(f"min_speech_duration_ms={min_speech_duration_ms}") - if min_silence_duration_ms is not None: - params.append(f"min_silence_duration_ms={min_silence_duration_ms}") - if language_code is not None: - params.append(f"language_code={language_code}") + for key, value in [ + ("vad_silence_threshold_secs", vad_silence_threshold_secs), + ("vad_threshold", vad_threshold), + ("min_speech_duration_ms", min_speech_duration_ms), + ("min_silence_duration_ms", min_silence_duration_ms), + ("language_code", language_code), + ]: + if value is not None: + params.append((key, str(value))) if include_timestamps is not None: - params.append(f"include_timestamps={str(include_timestamps).lower()}") + params.append(("include_timestamps", str(include_timestamps).lower())) - query_string = "&".join(params) - return f"{base}/v1/speech-to-text/realtime?{query_string}" + return build_ws_url(self.base_url, ["v1", "speech-to-text", "realtime"], params) diff --git a/src/elevenlabs/realtime_tts.py b/src/elevenlabs/realtime_tts.py index b23f2faf..9ab22de0 100644 --- a/src/elevenlabs/realtime_tts.py +++ b/src/elevenlabs/realtime_tts.py @@ -16,6 +16,7 @@ from .types.voice_settings import VoiceSettings from .text_to_speech.client import TextToSpeechClient from .types import OutputFormat +from .url_utils import build_ws_url # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) @@ -92,9 +93,10 @@ def get_text() -> typing.Iterator[str]: ) """ with connect( - urllib.parse.urljoin( - self._ws_base_url, - f"v1/text-to-speech/{jsonable_encoder(voice_id)}/stream-input?model_id={model_id}&output_format={output_format}" + build_ws_url( + self._ws_base_url, + ["v1", "text-to-speech", voice_id, "stream-input"], + {"model_id": model_id, "output_format": output_format}, ), additional_headers=jsonable_encoder( remove_none_from_dict( diff --git a/src/elevenlabs/url_utils.py b/src/elevenlabs/url_utils.py new file mode 100644 index 00000000..dc7d48a7 --- /dev/null +++ b/src/elevenlabs/url_utils.py @@ -0,0 +1,28 @@ +import urllib.parse +from typing import Any, Sequence, Tuple, Union + + +_WS_SCHEME = {"https": "wss", "http": "ws"} + + +def build_ws_url( + base_url: str, + path_segments: Sequence[Any], + params: Union[Sequence[Tuple[str, str]], dict], +) -> str: + """Build a WebSocket URL with proper percent-encoding. + + Converts http(s) schemes to ws(s), appends percent-encoded + *path_segments* beneath the existing base path, and encodes + *params* as the query string. + """ + parsed = urllib.parse.urlparse(base_url) + path = "/".join(urllib.parse.quote(str(seg), safe="") for seg in path_segments) + return urllib.parse.urlunparse(( + _WS_SCHEME.get(parsed.scheme, parsed.scheme), + parsed.netloc, + parsed.path.rstrip("/") + "/" + path, + "", + urllib.parse.urlencode(params, quote_via=urllib.parse.quote), + "", + )) diff --git a/tests/test_url_utils.py b/tests/test_url_utils.py new file mode 100644 index 00000000..a5e86389 --- /dev/null +++ b/tests/test_url_utils.py @@ -0,0 +1,80 @@ +"""Tests for the build_ws_url utility.""" + +import pytest + +from elevenlabs.url_utils import build_ws_url + + +class TestSchemeConversion: + def test_https_to_wss(self): + url = build_ws_url("https://api.example.com", ["v1"], {}) + assert url.startswith("wss://") + + def test_http_to_ws(self): + url = build_ws_url("http://localhost:8080", ["v1"], {}) + assert url.startswith("ws://") + + def test_wss_preserved(self): + url = build_ws_url("wss://api.example.com", ["v1"], {}) + assert url.startswith("wss://") + + def test_ws_preserved(self): + url = build_ws_url("ws://localhost:8080", ["v1"], {}) + assert url.startswith("ws://") + + +class TestPathSegments: + def test_segments_joined(self): + url = build_ws_url("wss://api.example.com", ["v1", "speech", "realtime"], {}) + assert url == "wss://api.example.com/v1/speech/realtime" + + def test_segments_percent_encoded(self): + url = build_ws_url("wss://api.example.com", ["v1", "hello world"], {}) + assert "/v1/hello%20world" in url + + def test_special_characters_encoded(self): + url = build_ws_url("wss://api.example.com", ["v1", "a/b", "c?d", "e&f"], {}) + assert "/v1/a%2Fb/c%3Fd/e%26f" in url + + def test_non_string_segments_converted(self): + url = build_ws_url("wss://api.example.com", ["v1", 42, True], {}) + assert "/v1/42/True" in url + + def test_appended_to_existing_base_path(self): + url = build_ws_url("wss://api.example.com/base", ["v1", "endpoint"], {}) + assert url == "wss://api.example.com/base/v1/endpoint" + + def test_base_path_trailing_slash_not_duplicated(self): + url = build_ws_url("wss://api.example.com/base/", ["v1"], {}) + assert "//v1" not in url + assert "/base/v1" in url + + +class TestQueryParams: + def test_dict_params(self): + url = build_ws_url("wss://api.example.com", ["v1"], {"key": "value"}) + assert url.endswith("?key=value") + + def test_tuple_params(self): + url = build_ws_url("wss://api.example.com", ["v1"], [("a", "1"), ("b", "2")]) + assert "a=1" in url + assert "b=2" in url + + def test_params_percent_encoded(self): + url = build_ws_url("wss://api.example.com", ["v1"], {"term": "hello world"}) + assert "term=hello%20world" in url + + def test_repeated_keys(self): + url = build_ws_url("wss://api.example.com", ["v1"], [("k", "a"), ("k", "b")]) + assert "k=a" in url + assert "k=b" in url + + def test_empty_params(self): + url = build_ws_url("wss://api.example.com", ["v1"], {}) + assert url == "wss://api.example.com/v1" + + +class TestPortPreserved: + def test_custom_port(self): + url = build_ws_url("http://localhost:9090", ["v1"], {}) + assert url == "ws://localhost:9090/v1"