Skip to content

Commit 2d39717

Browse files
committed
Bug fixes and API surface updates
1 parent 6242787 commit 2d39717

7 files changed

Lines changed: 515 additions & 14 deletions

File tree

src/elevenlabs/speech_engine/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""ElevenLabs Speech Engine SDK module."""
22

3-
from .resource import SpeechEngineResource
3+
from .resource import SpeechEngineResource, verify_speech_engine_jwt
44
from .server import SpeechEngineServer
55
from .session import SpeechEngineSession
66
from .types import (
@@ -19,6 +19,7 @@
1919
"SpeechEngineServer",
2020
"SpeechEngineSession",
2121
"WebSocketLike",
22+
"verify_speech_engine_jwt",
2223
"CLOSE",
2324
"DISCONNECTED",
2425
"ERROR",

src/elevenlabs/speech_engine/resource.py

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,101 @@
11
"""SpeechEngineResource — client-facing handle for a speech engine instance."""
22

3+
import base64
4+
import hashlib
5+
import hmac
6+
import json
7+
import logging
8+
import time
39
import typing
410

511
from .server import SpeechEngineServer
612
from .session import SpeechEngineSession
713
from .types import WebSocketLike
814

15+
logger = logging.getLogger("elevenlabs.speech_engine")
16+
17+
_ISSUER = "https://api.elevenlabs.io/convai/speech-engine"
18+
_SUBJECT = "convai_speech_engine_upstream"
19+
_LEEWAY_SECONDS = 60
20+
21+
22+
def _base64url_decode(data: str) -> bytes:
23+
padded = data.replace("-", "+").replace("_", "/")
24+
remainder = len(padded) % 4
25+
if remainder:
26+
padded += "=" * (4 - remainder)
27+
return base64.b64decode(padded)
28+
29+
30+
def verify_speech_engine_jwt(value: str, api_key: str) -> typing.Dict[str, typing.Any]:
31+
"""Verify an HS256 JWT from the ElevenLabs Speech Engine API.
32+
33+
The HMAC secret is the SHA-256 hash of the API key. Returns the
34+
decoded payload on success, raises :class:`ValueError` on failure.
35+
"""
36+
token = value.strip()
37+
if token.lower().startswith("bearer "):
38+
token = token[7:].strip()
39+
40+
parts = token.split(".")
41+
if len(parts) != 3:
42+
raise ValueError("Invalid JWT: expected 3 parts")
43+
44+
header_b64, payload_b64, signature_b64 = parts
45+
46+
try:
47+
payload = json.loads(_base64url_decode(payload_b64))
48+
except Exception:
49+
raise ValueError("Invalid JWT: failed to decode payload")
50+
51+
trimmed_key = api_key.strip()
52+
secret = hashlib.sha256(trimmed_key.encode("utf-8")).digest()
53+
54+
expected_sig = hmac.new(
55+
secret, f"{header_b64}.{payload_b64}".encode(), hashlib.sha256
56+
).digest()
57+
actual_sig = _base64url_decode(signature_b64)
58+
59+
if not hmac.compare_digest(expected_sig, actual_sig):
60+
key_prefix = (
61+
f"{trimmed_key[:4]}...{trimmed_key[-4:]}"
62+
if len(trimmed_key) > 8
63+
else "****"
64+
)
65+
whitespace_note = (
66+
" — key had trailing whitespace that was trimmed"
67+
if len(trimmed_key) != len(api_key)
68+
else ""
69+
)
70+
raise ValueError(
71+
f"Invalid JWT: signature mismatch "
72+
f"(API key: {key_prefix}, {len(trimmed_key)} chars{whitespace_note})"
73+
)
74+
75+
if payload.get("iss") != _ISSUER:
76+
raise ValueError(
77+
f'Invalid JWT: expected issuer "{_ISSUER}", got "{payload.get("iss")}"'
78+
)
79+
if payload.get("sub") != _SUBJECT:
80+
raise ValueError(
81+
f'Invalid JWT: expected subject "{_SUBJECT}", got "{payload.get("sub")}"'
82+
)
83+
84+
now = int(time.time())
85+
86+
exp = payload.get("exp")
87+
if not isinstance(exp, (int, float)):
88+
raise ValueError("Invalid JWT: missing exp claim")
89+
iat = payload.get("iat")
90+
if not isinstance(iat, (int, float)):
91+
raise ValueError("Invalid JWT: missing iat claim")
92+
if exp + _LEEWAY_SECONDS < now:
93+
raise ValueError("Invalid JWT: token has expired")
94+
if iat - _LEEWAY_SECONDS > now:
95+
raise ValueError("Invalid JWT: iat is in the future")
96+
97+
return payload
98+
999

10100
class SpeechEngineResource:
11101
"""Represents a speech engine instance.
@@ -39,15 +129,49 @@ def __init__(
39129
self.engine_id = engine_id
40130
self._options = client_options
41131

132+
def _get_api_key(self) -> typing.Optional[str]:
133+
if self._options is not None and hasattr(self._options, "_api_key"):
134+
return self._options._api_key
135+
return None
136+
137+
def verify_request(
138+
self, headers: typing.Dict[str, typing.Any]
139+
) -> bool:
140+
"""Verify that an incoming request is from the ElevenLabs API.
141+
142+
Checks the ``X-Elevenlabs-Speech-Engine-Authorization`` header
143+
for a valid JWT signed with the SHA-256 hash of the API key.
144+
145+
Only needed when managing the WebSocket upgrade yourself.
146+
When using :meth:`serve`, verification is handled automatically.
147+
"""
148+
api_key = self._get_api_key()
149+
if not api_key:
150+
return False
151+
raw = headers.get("x-elevenlabs-speech-engine-authorization")
152+
if isinstance(raw, list):
153+
raw = raw[0] if raw else None
154+
if not raw:
155+
return False
156+
try:
157+
verify_speech_engine_jwt(raw, api_key)
158+
return True
159+
except ValueError:
160+
return False
161+
42162
async def serve(
43163
self,
44164
*,
45165
port: int = 3001,
166+
path: typing.Optional[str] = None,
46167
debug: bool = False,
47168
**handlers: typing.Any,
48169
) -> None:
49170
"""Start a standalone WebSocket server. Blocks until stopped."""
50-
server = SpeechEngineServer(port=port, debug=debug, **handlers)
171+
api_key = self._get_api_key()
172+
server = SpeechEngineServer(
173+
port=port, path=path, debug=debug, api_key=api_key, **handlers
174+
)
51175
await server.serve()
52176

53177
def create_session(

src/elevenlabs/speech_engine/server.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import logging
5+
import os
56
import typing
67

78
from .session import SpeechEngineSession, _wire_handlers
@@ -15,10 +16,14 @@ class SpeechEngineServer:
1516
instances for each incoming connection from the ElevenLabs Speech Engine
1617
API.
1718
19+
Every incoming connection is verified against the ElevenLabs API using
20+
the configured API key before being accepted.
21+
1822
Example::
1923
2024
server = SpeechEngineServer(
2125
port=3001,
26+
api_key="sk_...",
2227
debug=True,
2328
on_transcript=handle_transcript,
2429
)
@@ -29,15 +34,28 @@ def __init__(
2934
self,
3035
*,
3136
port: int = 3001,
37+
path: typing.Optional[str] = None,
38+
api_key: typing.Optional[str] = None,
3239
debug: bool = False,
3340
**handlers: typing.Any,
3441
) -> None:
3542
self._port = port
43+
self._path = path
44+
self._api_key = api_key
3645
self._debug = debug
3746
self._handlers = handlers
3847
self._stop_event = asyncio.Event()
3948
self._server = None # type: typing.Any
4049

50+
if debug:
51+
logger.setLevel(logging.DEBUG)
52+
if not logger.handlers:
53+
handler = logging.StreamHandler()
54+
handler.setFormatter(
55+
logging.Formatter("[SpeechEngine] %(message)s")
56+
)
57+
logger.addHandler(handler)
58+
4159
def handle_connection(self, ws: WebSocketLike) -> SpeechEngineSession:
4260
"""Wrap *ws* in a :class:`SpeechEngineSession` with the server's
4361
handlers wired up.
@@ -46,15 +64,51 @@ def handle_connection(self, ws: WebSocketLike) -> SpeechEngineSession:
4664
individual connections. The returned session's :meth:`run` must
4765
still be awaited by the caller.
4866
"""
67+
logger.debug("creating new session")
4968
session = SpeechEngineSession(ws, debug=self._debug)
5069
_wire_handlers(session, self._handlers)
5170
return session
5271

5372
async def serve(self) -> None:
5473
"""Start the WebSocket server. Blocks until :meth:`stop` is called."""
74+
from .resource import verify_speech_engine_jwt # noqa: E402
75+
5576
import websockets # noqa: E402 — keep import lazy
5677

78+
api_key = self._api_key or os.environ.get("ELEVENLABS_API_KEY")
79+
if not api_key:
80+
raise RuntimeError(
81+
"SpeechEngineServer requires an API key to verify incoming "
82+
"connections. Pass api_key= or set the ELEVENLABS_API_KEY "
83+
"environment variable."
84+
)
85+
5786
async def _handler(websocket: typing.Any, *_args: typing.Any) -> None:
87+
if self._path is not None and websocket.request.path != self._path:
88+
await websocket.close(4000, "not found")
89+
return
90+
91+
header_value = websocket.request.headers.get(
92+
"x-elevenlabs-speech-engine-authorization"
93+
)
94+
if not header_value:
95+
logger.debug(
96+
"rejected connection — missing "
97+
"X-Elevenlabs-Speech-Engine-Authorization header"
98+
)
99+
await websocket.close(
100+
4001, "missing authorization header"
101+
)
102+
return
103+
104+
try:
105+
verify_speech_engine_jwt(header_value, api_key)
106+
except ValueError as e:
107+
logger.debug("rejected connection — %s", e)
108+
await websocket.close(4001, str(e))
109+
return
110+
111+
logger.debug("verified connection, accepting WebSocket")
58112
session = self.handle_connection(websocket)
59113
await session.run()
60114

src/elevenlabs/speech_engine/session.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import typing
77

8-
from .types import ConversationMessage, WebSocketLike
8+
from .types import ConversationMessage, WebSocketLike, wrap_websocket
99

1010
logger = logging.getLogger("elevenlabs.speech_engine")
1111

@@ -159,11 +159,11 @@ async def handle(transcript):
159159

160160
def __init__(
161161
self,
162-
ws: WebSocketLike,
162+
ws: typing.Any,
163163
*,
164164
debug: bool = False,
165165
) -> None:
166-
self._ws = ws
166+
self._ws = wrap_websocket(ws)
167167
self._conversation_id = None # type: typing.Optional[str]
168168
self._current_task = None # type: typing.Optional[asyncio.Task] # type: ignore[type-arg]
169169
self._current_event_id = None # type: typing.Optional[int]
@@ -173,6 +173,12 @@ def __init__(
173173

174174
if debug:
175175
logger.setLevel(logging.DEBUG)
176+
if not logger.handlers:
177+
handler = logging.StreamHandler()
178+
handler.setFormatter(
179+
logging.Formatter("[SpeechEngine] %(message)s")
180+
)
181+
logger.addHandler(handler)
176182

177183
# ------------------------------------------------------------------
178184
# Event emitter interface
@@ -232,8 +238,7 @@ async def run(self) -> None:
232238
except asyncio.CancelledError:
233239
raise
234240
except Exception:
235-
# Connection closed or errored — exit the loop and
236-
# let the finally block emit "disconnected".
241+
logger.debug("WebSocket connection lost")
237242
break
238243

239244
try:
@@ -275,10 +280,19 @@ async def send_response(
275280
if self._closed:
276281
raise RuntimeError("Cannot send response: session is closed")
277282

283+
if self._current_event_id is None:
284+
logger.warning(
285+
"sendResponse() called outside of an on_transcript handler. "
286+
"Responses can only be sent in reply to a user transcript. "
287+
"To have the agent speak first, set a first message in your "
288+
"Speech Engine conversation config on the client."
289+
)
290+
return
291+
278292
if isinstance(response, str):
279293
logger.debug(
280-
"sending string response (%d chars), event_id=%s",
281-
len(response),
294+
'sending string response: "%s", event_id=%s',
295+
response,
282296
self._current_event_id,
283297
)
284298
await self._send_agent_response(response, False)
@@ -318,6 +332,19 @@ async def _handle_message(self, msg: typing.Dict[str, typing.Any]) -> None:
318332
await self._emit("init", self._conversation_id)
319333

320334
elif msg_type == "user_transcript":
335+
incoming_event_id = msg.get("event_id")
336+
337+
if (
338+
incoming_event_id == self._current_event_id
339+
and self._current_task is not None
340+
and not self._current_task.done()
341+
):
342+
logger.debug(
343+
"skipping duplicate transcript, event_id=%s",
344+
incoming_event_id,
345+
)
346+
return
347+
321348
was_active = (
322349
self._current_task is not None
323350
and not self._current_task.done()
@@ -328,10 +355,10 @@ async def _handle_message(self, msg: typing.Dict[str, typing.Any]) -> None:
328355
"interrupted: cancelling previous response "
329356
"(event_id=%s) for new transcript (event_id=%s)",
330357
self._current_event_id,
331-
msg.get("event_id"),
358+
incoming_event_id,
332359
)
333360

334-
self._current_event_id = msg.get("event_id")
361+
self._current_event_id = incoming_event_id
335362
transcript_data = msg.get("user_transcript", [])
336363
logger.debug(
337364
"received transcript, event_id=%s, messages=%d",

0 commit comments

Comments
 (0)