|
1 | 1 | """SpeechEngineResource — client-facing handle for a speech engine instance.""" |
2 | 2 |
|
| 3 | +import base64 |
| 4 | +import hashlib |
| 5 | +import hmac |
| 6 | +import json |
| 7 | +import logging |
| 8 | +import time |
3 | 9 | import typing |
4 | 10 |
|
5 | 11 | from .server import SpeechEngineServer |
6 | 12 | from .session import SpeechEngineSession |
7 | 13 | from .types import WebSocketLike |
8 | 14 |
|
| 15 | +logger = logging.getLogger("elevenlabs.speech_engine") |
| 16 | + |
| 17 | +_ISSUER = "https://api.elevenlabs.io/convai/speech-engine" |
| 18 | +_SUBJECT = "convai_speech_engine_upstream" |
| 19 | +_LEEWAY_SECONDS = 60 |
| 20 | + |
| 21 | + |
| 22 | +def _base64url_decode(data: str) -> bytes: |
| 23 | + padded = data.replace("-", "+").replace("_", "/") |
| 24 | + remainder = len(padded) % 4 |
| 25 | + if remainder: |
| 26 | + padded += "=" * (4 - remainder) |
| 27 | + return base64.b64decode(padded) |
| 28 | + |
| 29 | + |
| 30 | +def verify_speech_engine_jwt(value: str, api_key: str) -> typing.Dict[str, typing.Any]: |
| 31 | + """Verify an HS256 JWT from the ElevenLabs Speech Engine API. |
| 32 | +
|
| 33 | + The HMAC secret is the SHA-256 hash of the API key. Returns the |
| 34 | + decoded payload on success, raises :class:`ValueError` on failure. |
| 35 | + """ |
| 36 | + token = value.strip() |
| 37 | + if token.lower().startswith("bearer "): |
| 38 | + token = token[7:].strip() |
| 39 | + |
| 40 | + parts = token.split(".") |
| 41 | + if len(parts) != 3: |
| 42 | + raise ValueError("Invalid JWT: expected 3 parts") |
| 43 | + |
| 44 | + header_b64, payload_b64, signature_b64 = parts |
| 45 | + |
| 46 | + try: |
| 47 | + payload = json.loads(_base64url_decode(payload_b64)) |
| 48 | + except Exception: |
| 49 | + raise ValueError("Invalid JWT: failed to decode payload") |
| 50 | + |
| 51 | + trimmed_key = api_key.strip() |
| 52 | + secret = hashlib.sha256(trimmed_key.encode("utf-8")).digest() |
| 53 | + |
| 54 | + expected_sig = hmac.new( |
| 55 | + secret, f"{header_b64}.{payload_b64}".encode(), hashlib.sha256 |
| 56 | + ).digest() |
| 57 | + actual_sig = _base64url_decode(signature_b64) |
| 58 | + |
| 59 | + if not hmac.compare_digest(expected_sig, actual_sig): |
| 60 | + key_prefix = ( |
| 61 | + f"{trimmed_key[:4]}...{trimmed_key[-4:]}" |
| 62 | + if len(trimmed_key) > 8 |
| 63 | + else "****" |
| 64 | + ) |
| 65 | + whitespace_note = ( |
| 66 | + " — key had trailing whitespace that was trimmed" |
| 67 | + if len(trimmed_key) != len(api_key) |
| 68 | + else "" |
| 69 | + ) |
| 70 | + raise ValueError( |
| 71 | + f"Invalid JWT: signature mismatch " |
| 72 | + f"(API key: {key_prefix}, {len(trimmed_key)} chars{whitespace_note})" |
| 73 | + ) |
| 74 | + |
| 75 | + if payload.get("iss") != _ISSUER: |
| 76 | + raise ValueError( |
| 77 | + f'Invalid JWT: expected issuer "{_ISSUER}", got "{payload.get("iss")}"' |
| 78 | + ) |
| 79 | + if payload.get("sub") != _SUBJECT: |
| 80 | + raise ValueError( |
| 81 | + f'Invalid JWT: expected subject "{_SUBJECT}", got "{payload.get("sub")}"' |
| 82 | + ) |
| 83 | + |
| 84 | + now = int(time.time()) |
| 85 | + |
| 86 | + exp = payload.get("exp") |
| 87 | + if not isinstance(exp, (int, float)): |
| 88 | + raise ValueError("Invalid JWT: missing exp claim") |
| 89 | + iat = payload.get("iat") |
| 90 | + if not isinstance(iat, (int, float)): |
| 91 | + raise ValueError("Invalid JWT: missing iat claim") |
| 92 | + if exp + _LEEWAY_SECONDS < now: |
| 93 | + raise ValueError("Invalid JWT: token has expired") |
| 94 | + if iat - _LEEWAY_SECONDS > now: |
| 95 | + raise ValueError("Invalid JWT: iat is in the future") |
| 96 | + |
| 97 | + return payload |
| 98 | + |
9 | 99 |
|
10 | 100 | class SpeechEngineResource: |
11 | 101 | """Represents a speech engine instance. |
@@ -39,15 +129,49 @@ def __init__( |
39 | 129 | self.engine_id = engine_id |
40 | 130 | self._options = client_options |
41 | 131 |
|
| 132 | + def _get_api_key(self) -> typing.Optional[str]: |
| 133 | + if self._options is not None and hasattr(self._options, "_api_key"): |
| 134 | + return self._options._api_key |
| 135 | + return None |
| 136 | + |
| 137 | + def verify_request( |
| 138 | + self, headers: typing.Dict[str, typing.Any] |
| 139 | + ) -> bool: |
| 140 | + """Verify that an incoming request is from the ElevenLabs API. |
| 141 | +
|
| 142 | + Checks the ``X-Elevenlabs-Speech-Engine-Authorization`` header |
| 143 | + for a valid JWT signed with the SHA-256 hash of the API key. |
| 144 | +
|
| 145 | + Only needed when managing the WebSocket upgrade yourself. |
| 146 | + When using :meth:`serve`, verification is handled automatically. |
| 147 | + """ |
| 148 | + api_key = self._get_api_key() |
| 149 | + if not api_key: |
| 150 | + return False |
| 151 | + raw = headers.get("x-elevenlabs-speech-engine-authorization") |
| 152 | + if isinstance(raw, list): |
| 153 | + raw = raw[0] if raw else None |
| 154 | + if not raw: |
| 155 | + return False |
| 156 | + try: |
| 157 | + verify_speech_engine_jwt(raw, api_key) |
| 158 | + return True |
| 159 | + except ValueError: |
| 160 | + return False |
| 161 | + |
42 | 162 | async def serve( |
43 | 163 | self, |
44 | 164 | *, |
45 | 165 | port: int = 3001, |
| 166 | + path: typing.Optional[str] = None, |
46 | 167 | debug: bool = False, |
47 | 168 | **handlers: typing.Any, |
48 | 169 | ) -> None: |
49 | 170 | """Start a standalone WebSocket server. Blocks until stopped.""" |
50 | | - server = SpeechEngineServer(port=port, debug=debug, **handlers) |
| 171 | + api_key = self._get_api_key() |
| 172 | + server = SpeechEngineServer( |
| 173 | + port=port, path=path, debug=debug, api_key=api_key, **handlers |
| 174 | + ) |
51 | 175 | await server.serve() |
52 | 176 |
|
53 | 177 | def create_session( |
|
0 commit comments