|
3 | 3 | import json |
4 | 4 | import time |
5 | 5 |
|
| 6 | +from elevenlabs.core.client_wrapper import SyncClientWrapper |
| 7 | +from elevenlabs.conversational_ai.conversation import Conversation |
| 8 | +from elevenlabs.realtime_tts import RealtimeTextToSpeechClient |
6 | 9 |
|
7 | 10 | class MockAudioInterface(AudioInterface): |
8 | 11 | def start(self, input_callback): |
@@ -256,3 +259,68 @@ def test_conversation_wss_url_generation_without_get_environment(): |
256 | 259 |
|
257 | 260 | except Exception as e: |
258 | 261 | assert False, f"Unexpected error generating WebSocket URL: {e}" |
| 262 | + |
| 263 | + |
| 264 | +def test_websocket_url_construction_edge_cases(): |
| 265 | + """Test WebSocket URL construction edge cases, specifically for trailing slash handling.""" |
| 266 | + |
| 267 | + # Test cases with various base URL formats |
| 268 | + test_cases = [ |
| 269 | + # Base URLs without trailing slashes (the main edge case) |
| 270 | + ("https://api.eu.residency.elevenlabs.io", "wss://api.eu.residency.elevenlabs.io", "wss://api.eu.residency.elevenlabs.io"), |
| 271 | + ("https://api.elevenlabs.io", "wss://api.elevenlabs.io", "wss://api.elevenlabs.io"), |
| 272 | + ("http://localhost:8000", "ws://localhost:8000", "wss://localhost:8000"), |
| 273 | + # Base URLs with trailing slashes (should still work) |
| 274 | + ("https://api.eu.residency.elevenlabs.io/", "wss://api.eu.residency.elevenlabs.io", "wss://api.eu.residency.elevenlabs.io/"), |
| 275 | + ("https://api.elevenlabs.io/", "wss://api.elevenlabs.io", "wss://api.elevenlabs.io/"), |
| 276 | + ("http://localhost:8000/", "ws://localhost:8000", "wss://localhost:8000/"), |
| 277 | + ] |
| 278 | + |
| 279 | + for base_url, expected_ws_base, expected_tts_ws_base in test_cases: |
| 280 | + # Test conversation WebSocket URL construction |
| 281 | + mock_client = MagicMock() |
| 282 | + mock_client._client_wrapper = SyncClientWrapper( |
| 283 | + base_url=base_url, |
| 284 | + api_key="test_key", |
| 285 | + httpx_client=MagicMock(), |
| 286 | + timeout=30.0 |
| 287 | + ) |
| 288 | + |
| 289 | + conversation = Conversation( |
| 290 | + client=mock_client, |
| 291 | + agent_id=TEST_AGENT_ID, |
| 292 | + requires_auth=False, |
| 293 | + audio_interface=MockAudioInterface() |
| 294 | + ) |
| 295 | + |
| 296 | + # Test conversation URL generation |
| 297 | + conv_url = conversation._get_wss_url() |
| 298 | + expected_conv_url = f"{expected_ws_base}/v1/convai/conversation" |
| 299 | + assert expected_conv_url in conv_url, f"Conversation URL should contain {expected_conv_url}, got {conv_url}" |
| 300 | + |
| 301 | + # Ensure no double slashes in the path (except after the protocol) |
| 302 | + url_path = conv_url.split("://", 1)[1] # Remove protocol |
| 303 | + assert "//" not in url_path, f"URL should not contain double slashes in path: {conv_url}" |
| 304 | + |
| 305 | + # Test realtime TTS WebSocket URL construction |
| 306 | + realtime_client = RealtimeTextToSpeechClient(client_wrapper=mock_client._client_wrapper) |
| 307 | + |
| 308 | + # Test the WebSocket base URL construction |
| 309 | + # Note: realtime TTS always uses wss scheme, not ws |
| 310 | + assert realtime_client._ws_base_url == expected_tts_ws_base, f"TTS WebSocket base URL should be {expected_tts_ws_base}, got {realtime_client._ws_base_url}" |
| 311 | + |
| 312 | + # Test full URL construction using urljoin (simulating the actual method) |
| 313 | + import urllib.parse |
| 314 | + test_voice_id = "test_voice_123" |
| 315 | + test_model = "eleven_turbo_v2_5" |
| 316 | + test_format = "mp3_44100_128" |
| 317 | + relative_path = f"v1/text-to-speech/{test_voice_id}/stream-input?model_id={test_model}&output_format={test_format}" |
| 318 | + |
| 319 | + full_tts_url = urllib.parse.urljoin(realtime_client._ws_base_url, relative_path) |
| 320 | + # For URLs with trailing slash, expect it to be preserved in the joined URL |
| 321 | + expected_tts_url_base = expected_tts_ws_base.rstrip('/') + "/v1/text-to-speech/" + test_voice_id + "/stream-input" |
| 322 | + assert expected_tts_url_base in full_tts_url, f"TTS URL should contain {expected_tts_url_base}, got {full_tts_url}" |
| 323 | + |
| 324 | + # Ensure no double slashes in the path |
| 325 | + tts_url_path = full_tts_url.split("://", 1)[1] |
| 326 | + assert "//" not in tts_url_path, f"TTS URL should not contain double slashes in path: {full_tts_url}" |
0 commit comments