Skip to content

Commit 949bb68

Browse files
authored
Fix url generation (#631)
* fix url generation * move import * move these back * fix tests * make ruff happy
1 parent 0430688 commit 949bb68

3 files changed

Lines changed: 143 additions & 35 deletions

File tree

src/elevenlabs/conversational_ai/conversation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import asyncio
77
from concurrent.futures import ThreadPoolExecutor
88
from enum import Enum
9+
import urllib.parse
910

1011
from websockets.sync.client import connect, Connection
1112
import websockets
@@ -329,8 +330,11 @@ def __init__(
329330

330331
def _get_wss_url(self):
331332
base_http_url = self.client._client_wrapper.get_base_url()
332-
base_ws_url = base_http_url.replace("https://", "wss://").replace("http://", "ws://")
333-
return f"{base_ws_url}/v1/convai/conversation?agent_id={self.agent_id}&source=python_sdk&version={__version__}"
333+
base_ws_url = urllib.parse.urlparse(base_http_url)._replace(scheme="wss" if base_http_url.startswith("https") else "ws").geturl()
334+
# Ensure base URL ends with '/' for proper joining
335+
if not base_ws_url.endswith('/'):
336+
base_ws_url += '/'
337+
return f"{base_ws_url}v1/convai/conversation?agent_id={self.agent_id}&source=python_sdk&version={__version__}"
334338

335339
def _get_signed_url(self):
336340
response = self.client.conversational_ai.conversations.get_signed_url(agent_id=self.agent_id)

tests/test_async_convai.py

Lines changed: 66 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import asyncio
22
import json
33
from unittest.mock import AsyncMock, MagicMock, patch
4+
45
import pytest
56

67
from elevenlabs.conversational_ai.conversation import (
7-
AsyncConversation,
88
AsyncAudioInterface,
9+
AsyncConversation,
910
ConversationInitiationData,
1011
)
1112

@@ -45,7 +46,6 @@ def create_mock_async_websocket(messages=None):
4546

4647
# Convert messages to JSON strings
4748
json_messages = [json.dumps(msg) for msg in messages]
48-
json_messages.extend(['{"type": "keep_alive"}'] * 10) # Add some keep-alive messages
4949

5050
# Create an iterator
5151
message_iter = iter(json_messages)
@@ -54,8 +54,9 @@ async def mock_recv():
5454
try:
5555
return next(message_iter)
5656
except StopIteration:
57-
# Simulate connection close after messages
58-
raise asyncio.TimeoutError()
57+
# After all messages, simulate timeout by sleeping forever
58+
# This will be caught by asyncio.wait_for timeout in the conversation
59+
await asyncio.sleep(float("inf"))
5960

6061
mock_ws.recv = mock_recv
6162
return mock_ws
@@ -66,6 +67,7 @@ async def test_async_conversation_basic_flow():
6667
# Mock setup
6768
mock_ws = create_mock_async_websocket()
6869
mock_client = MagicMock()
70+
mock_client._client_wrapper.get_base_url.return_value = "https://api.elevenlabs.io"
6971
agent_response_callback = AsyncMock()
7072
test_user_id = "test_user_123"
7173

@@ -94,7 +96,7 @@ async def test_async_conversation_basic_flow():
9496

9597
# Assertions - check the call was made with the right structure
9698
send_calls = [call[0][0] for call in mock_ws.send.call_args_list]
97-
init_messages = [json.loads(call) for call in send_calls if 'conversation_initiation_client_data' in call]
99+
init_messages = [json.loads(call) for call in send_calls if "conversation_initiation_client_data" in call]
98100
assert len(init_messages) == 1
99101
init_message = init_messages[0]
100102

@@ -148,6 +150,7 @@ async def test_async_conversation_with_dynamic_variables():
148150
# Mock setup
149151
mock_ws = create_mock_async_websocket()
150152
mock_client = MagicMock()
153+
mock_client._client_wrapper.get_base_url.return_value = "https://api.elevenlabs.io"
151154
agent_response_callback = AsyncMock()
152155

153156
dynamic_variables = {"name": "angelo"}
@@ -177,7 +180,7 @@ async def test_async_conversation_with_dynamic_variables():
177180

178181
# Assertions - check the call was made with the right structure
179182
send_calls = [call[0][0] for call in mock_ws.send.call_args_list]
180-
init_messages = [json.loads(call) for call in send_calls if 'conversation_initiation_client_data' in call]
183+
init_messages = [json.loads(call) for call in send_calls if "conversation_initiation_client_data" in call]
181184
assert len(init_messages) == 1
182185
init_message = init_messages[0]
183186

@@ -196,6 +199,7 @@ async def test_async_conversation_with_contextual_update():
196199
# Mock setup
197200
mock_ws = create_mock_async_websocket([])
198201
mock_client = MagicMock()
202+
mock_client._client_wrapper.get_base_url.return_value = "https://api.elevenlabs.io"
199203

200204
# Setup the conversation
201205
conversation = AsyncConversation(
@@ -228,6 +232,7 @@ async def test_async_conversation_send_user_message():
228232
# Mock setup
229233
mock_ws = create_mock_async_websocket([])
230234
mock_client = MagicMock()
235+
mock_client._client_wrapper.get_base_url.return_value = "https://api.elevenlabs.io"
231236

232237
# Setup the conversation
233238
conversation = AsyncConversation(
@@ -260,6 +265,7 @@ async def test_async_conversation_register_user_activity():
260265
# Mock setup
261266
mock_ws = create_mock_async_websocket([])
262267
mock_client = MagicMock()
268+
mock_client._client_wrapper.get_base_url.return_value = "https://api.elevenlabs.io"
263269

264270
# Setup the conversation
265271
conversation = AsyncConversation(
@@ -300,29 +306,21 @@ async def test_async_conversation_callback_flows():
300306
"type": "agent_response_correction",
301307
"agent_response_correction_event": {
302308
"original_agent_response": "Hello ther!",
303-
"corrected_agent_response": "Hello there!"
304-
}
305-
},
306-
{
307-
"type": "user_transcript",
308-
"user_transcription_event": {"user_transcript": "Hi, how are you?"}
309-
},
310-
{
311-
"type": "ping",
312-
"ping_event": {"event_id": "123", "ping_ms": 50}
313-
},
314-
{
315-
"type": "interruption",
316-
"interruption_event": {"event_id": "456"}
309+
"corrected_agent_response": "Hello there!",
310+
},
317311
},
312+
{"type": "user_transcript", "user_transcription_event": {"user_transcript": "Hi, how are you?"}},
313+
{"type": "ping", "ping_event": {"event_id": "123", "ping_ms": 50}},
314+
{"type": "interruption", "interruption_event": {"event_id": "456"}},
318315
{
319316
"type": "audio",
320-
"audio_event": {"event_id": "789", "audio_base_64": "dGVzdA=="} # "test" in base64
321-
}
317+
"audio_event": {"event_id": "789", "audio_base_64": "dGVzdA=="}, # "test" in base64
318+
},
322319
]
323320

324321
mock_ws = create_mock_async_websocket(messages)
325322
mock_client = MagicMock()
323+
mock_client._client_wrapper.get_base_url.return_value = "https://api.elevenlabs.io"
326324

327325
# Setup callbacks
328326
agent_response_callback = AsyncMock()
@@ -368,7 +366,6 @@ async def test_async_conversation_callback_flows():
368366

369367
@pytest.mark.asyncio
370368
async def test_async_conversation_wss_url_generation_without_get_environment():
371-
372369
from elevenlabs.core.client_wrapper import SyncClientWrapper
373370

374371
# Test with various base URL formats to ensure robustness
@@ -383,24 +380,20 @@ async def test_async_conversation_wss_url_generation_without_get_environment():
383380
# Create a real SyncClientWrapper to ensure it doesn't have get_environment method
384381
mock_client = MagicMock()
385382
mock_client._client_wrapper = SyncClientWrapper(
386-
base_url=base_url,
387-
api_key="test_key",
388-
httpx_client=MagicMock(),
389-
timeout=30.0
383+
base_url=base_url, api_key="test_key", httpx_client=MagicMock(), timeout=30.0
390384
)
391385

392386
conversation = AsyncConversation(
393-
client=mock_client,
394-
agent_id=TEST_AGENT_ID,
395-
requires_auth=False,
396-
audio_interface=MockAsyncAudioInterface()
387+
client=mock_client, agent_id=TEST_AGENT_ID, requires_auth=False, audio_interface=MockAsyncAudioInterface()
397388
)
398389

399390
try:
400391
wss_url = conversation._get_wss_url()
401392

402393
# Verify the URL is correctly generated
403-
expected_url = f"{expected_ws_base}/v1/convai/conversation?agent_id={TEST_AGENT_ID}&source=python_sdk&version="
394+
expected_url = (
395+
f"{expected_ws_base}/v1/convai/conversation?agent_id={TEST_AGENT_ID}&source=python_sdk&version="
396+
)
404397
assert wss_url.startswith(expected_url), f"URL should start with {expected_url}, got {wss_url}"
405398

406399
# Verify the URL contains version parameter
@@ -414,3 +407,44 @@ async def test_async_conversation_wss_url_generation_without_get_environment():
414407

415408
except Exception as e:
416409
assert False, f"Unexpected error generating WebSocket URL: {e}"
410+
411+
412+
@pytest.mark.asyncio
413+
async def test_async_websocket_url_construction_edge_cases():
414+
"""Test WebSocket URL construction edge cases for async conversation, specifically for trailing slash handling."""
415+
from elevenlabs.conversational_ai.conversation import AsyncConversation
416+
from elevenlabs.core.client_wrapper import SyncClientWrapper
417+
418+
# Test cases with various base URL formats
419+
test_cases = [
420+
# Base URLs without trailing slashes (the main edge case)
421+
("https://api.eu.residency.elevenlabs.io", "wss://api.eu.residency.elevenlabs.io"),
422+
("https://api.elevenlabs.io", "wss://api.elevenlabs.io"),
423+
("http://localhost:8000", "ws://localhost:8000"),
424+
# Base URLs with trailing slashes (should still work)
425+
("https://api.eu.residency.elevenlabs.io/", "wss://api.eu.residency.elevenlabs.io"),
426+
("https://api.elevenlabs.io/", "wss://api.elevenlabs.io"),
427+
("http://localhost:8000/", "ws://localhost:8000"),
428+
]
429+
430+
for base_url, expected_ws_base in test_cases:
431+
# Test async conversation WebSocket URL construction
432+
mock_client = MagicMock()
433+
mock_client._client_wrapper = SyncClientWrapper(
434+
base_url=base_url, api_key="test_key", httpx_client=MagicMock(), timeout=30.0
435+
)
436+
437+
conversation = AsyncConversation(
438+
client=mock_client, agent_id=TEST_AGENT_ID, requires_auth=False, audio_interface=MockAsyncAudioInterface()
439+
)
440+
441+
# Test conversation URL generation
442+
conv_url = conversation._get_wss_url()
443+
expected_conv_url = f"{expected_ws_base}/v1/convai/conversation"
444+
assert (
445+
expected_conv_url in conv_url
446+
), f"Async conversation URL should contain {expected_conv_url}, got {conv_url}"
447+
448+
# Ensure no double slashes in the path (except after the protocol)
449+
url_path = conv_url.split("://", 1)[1] # Remove protocol
450+
assert "//" not in url_path, f"Async conversation URL should not contain double slashes in path: {conv_url}"

tests/test_convai.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import json
44
import time
55

6-
76
class MockAudioInterface(AudioInterface):
87
def start(self, input_callback):
98
print("Audio interface started")
@@ -51,6 +50,7 @@ def test_conversation_basic_flow():
5150
# Mock setup
5251
mock_ws = create_mock_websocket()
5352
mock_client = MagicMock()
53+
mock_client._client_wrapper.get_base_url.return_value = "https://api.elevenlabs.io"
5454
agent_response_callback = MagicMock()
5555
test_user_id = "test_user_123"
5656

@@ -132,6 +132,7 @@ def test_conversation_with_dynamic_variables():
132132
# Mock setup
133133
mock_ws = create_mock_websocket()
134134
mock_client = MagicMock()
135+
mock_client._client_wrapper.get_base_url.return_value = "https://api.elevenlabs.io"
135136
agent_response_callback = MagicMock()
136137

137138
dynamic_variables = {"name": "angelo"}
@@ -181,6 +182,7 @@ def test_conversation_with_contextual_update():
181182
# Mock setup
182183
mock_ws = create_mock_websocket([])
183184
mock_client = MagicMock()
185+
mock_client._client_wrapper.get_base_url.return_value = "https://api.elevenlabs.io"
184186

185187
# Setup the conversation
186188
conversation = Conversation(
@@ -256,3 +258,71 @@ def test_conversation_wss_url_generation_without_get_environment():
256258

257259
except Exception as e:
258260
assert False, f"Unexpected error generating WebSocket URL: {e}"
261+
262+
263+
def test_websocket_url_construction_edge_cases():
264+
"""Test WebSocket URL construction edge cases, specifically for trailing slash handling."""
265+
from elevenlabs.core.client_wrapper import SyncClientWrapper
266+
from elevenlabs.conversational_ai.conversation import Conversation
267+
from elevenlabs.realtime_tts import RealtimeTextToSpeechClient
268+
269+
# Test cases with various base URL formats
270+
test_cases = [
271+
# Base URLs without trailing slashes (the main edge case)
272+
("https://api.eu.residency.elevenlabs.io", "wss://api.eu.residency.elevenlabs.io", "wss://api.eu.residency.elevenlabs.io"),
273+
("https://api.elevenlabs.io", "wss://api.elevenlabs.io", "wss://api.elevenlabs.io"),
274+
("http://localhost:8000", "ws://localhost:8000", "wss://localhost:8000"),
275+
# Base URLs with trailing slashes (should still work)
276+
("https://api.eu.residency.elevenlabs.io/", "wss://api.eu.residency.elevenlabs.io", "wss://api.eu.residency.elevenlabs.io/"),
277+
("https://api.elevenlabs.io/", "wss://api.elevenlabs.io", "wss://api.elevenlabs.io/"),
278+
("http://localhost:8000/", "ws://localhost:8000", "wss://localhost:8000/"),
279+
]
280+
281+
for base_url, expected_ws_base, expected_tts_ws_base in test_cases:
282+
# Test conversation WebSocket URL construction
283+
mock_client = MagicMock()
284+
mock_client._client_wrapper = SyncClientWrapper(
285+
base_url=base_url,
286+
api_key="test_key",
287+
httpx_client=MagicMock(),
288+
timeout=30.0
289+
)
290+
291+
conversation = Conversation(
292+
client=mock_client,
293+
agent_id=TEST_AGENT_ID,
294+
requires_auth=False,
295+
audio_interface=MockAudioInterface()
296+
)
297+
298+
# Test conversation URL generation
299+
conv_url = conversation._get_wss_url()
300+
expected_conv_url = f"{expected_ws_base}/v1/convai/conversation"
301+
assert expected_conv_url in conv_url, f"Conversation URL should contain {expected_conv_url}, got {conv_url}"
302+
303+
# Ensure no double slashes in the path (except after the protocol)
304+
url_path = conv_url.split("://", 1)[1] # Remove protocol
305+
assert "//" not in url_path, f"URL should not contain double slashes in path: {conv_url}"
306+
307+
# Test realtime TTS WebSocket URL construction
308+
realtime_client = RealtimeTextToSpeechClient(client_wrapper=mock_client._client_wrapper)
309+
310+
# Test the WebSocket base URL construction
311+
# Note: realtime TTS always uses wss scheme, not ws
312+
assert realtime_client._ws_base_url == expected_tts_ws_base, f"TTS WebSocket base URL should be {expected_tts_ws_base}, got {realtime_client._ws_base_url}"
313+
314+
# Test full URL construction using urljoin (simulating the actual method)
315+
import urllib.parse
316+
test_voice_id = "test_voice_123"
317+
test_model = "eleven_turbo_v2_5"
318+
test_format = "mp3_44100_128"
319+
relative_path = f"v1/text-to-speech/{test_voice_id}/stream-input?model_id={test_model}&output_format={test_format}"
320+
321+
full_tts_url = urllib.parse.urljoin(realtime_client._ws_base_url, relative_path)
322+
# For URLs with trailing slash, expect it to be preserved in the joined URL
323+
expected_tts_url_base = expected_tts_ws_base.rstrip('/') + "/v1/text-to-speech/" + test_voice_id + "/stream-input"
324+
assert expected_tts_url_base in full_tts_url, f"TTS URL should contain {expected_tts_url_base}, got {full_tts_url}"
325+
326+
# Ensure no double slashes in the path
327+
tts_url_path = full_tts_url.split("://", 1)[1]
328+
assert "//" not in tts_url_path, f"TTS URL should not contain double slashes in path: {full_tts_url}"

0 commit comments

Comments
 (0)