Skip to content

Commit 18fad44

Browse files
jaebong-humanclaude
andcommitted
feat: add Typecast TTS provider
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e34d950 commit 18fad44

2 files changed

Lines changed: 380 additions & 0 deletions

File tree

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import json
2+
import os
3+
import uuid
4+
5+
from httpx import AsyncClient
6+
7+
from astrbot import logger
8+
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
9+
10+
from ..entities import ProviderType
11+
from ..provider import TTSProvider
12+
from ..register import register_provider_adapter
13+
14+
15+
def _safe_cast(value, type_func, default):
16+
try:
17+
return type_func(value)
18+
except (TypeError, ValueError):
19+
return default
20+
21+
22+
@register_provider_adapter(
23+
"typecast_tts",
24+
"Typecast TTS",
25+
provider_type=ProviderType.TEXT_TO_SPEECH,
26+
)
27+
class ProviderTypecastTTS(TTSProvider):
28+
API_URL = "https://api.typecast.ai/v1/text-to-speech"
29+
30+
def __init__(
31+
self,
32+
provider_config: dict,
33+
provider_settings: dict,
34+
) -> None:
35+
super().__init__(provider_config, provider_settings)
36+
37+
self.api_key: str = provider_config.get("api_key", "")
38+
if not self.api_key:
39+
raise ValueError("[Typecast TTS] api_key is required")
40+
self.voice_id: str = provider_config.get("typecast-voice-id", "")
41+
if not self.voice_id:
42+
raise ValueError("[Typecast TTS] typecast-voice-id is required")
43+
self.language: str = provider_config.get("language", "kor")
44+
self.emotion_preset: str = provider_config.get(
45+
"typecast-emotion-preset", "normal"
46+
)
47+
self.emotion_intensity: float = _safe_cast(
48+
provider_config.get("typecast-emotion-intensity", 1.0), float, 1.0
49+
)
50+
self.volume: int = _safe_cast(
51+
provider_config.get("typecast-volume", 100), int, 100
52+
)
53+
self.pitch: int = _safe_cast(
54+
provider_config.get("typecast-pitch", 0), int, 0
55+
)
56+
self.tempo: float = _safe_cast(
57+
provider_config.get("typecast-tempo", 1.0), float, 1.0
58+
)
59+
self.timeout: int = _safe_cast(
60+
provider_config.get("timeout", 30), int, 30
61+
)
62+
self.proxy: str = provider_config.get("proxy", "")
63+
64+
if self.proxy:
65+
logger.info(f"[Typecast TTS] Using proxy: {self.proxy}")
66+
67+
self.set_model(provider_config.get("model", "ssfm-v30"))
68+
69+
def _build_request_body(self, text: str) -> dict:
70+
return {
71+
"voice_id": self.voice_id,
72+
"text": text,
73+
"model": self.model_name,
74+
"language": self.language,
75+
"prompt": {
76+
"emotion_type": "preset",
77+
"emotion_preset": self.emotion_preset,
78+
"emotion_intensity": self.emotion_intensity,
79+
},
80+
"output": {
81+
"volume": self.volume,
82+
"audio_pitch": self.pitch,
83+
"audio_tempo": self.tempo,
84+
"audio_format": "wav",
85+
},
86+
}
87+
88+
async def get_audio(self, text: str) -> str:
89+
if not text or not text.strip():
90+
raise ValueError("[Typecast TTS] text must not be empty")
91+
if len(text) > 2000:
92+
raise ValueError(
93+
f"[Typecast TTS] text length {len(text)} exceeds maximum of 2000 characters"
94+
)
95+
96+
temp_dir = get_astrbot_temp_path()
97+
os.makedirs(temp_dir, exist_ok=True)
98+
path = os.path.join(temp_dir, f"typecast_tts_{uuid.uuid4()}.wav")
99+
100+
headers = {
101+
"Content-Type": "application/json",
102+
"X-API-KEY": self.api_key,
103+
}
104+
body = self._build_request_body(text)
105+
106+
async with AsyncClient(
107+
timeout=self.timeout,
108+
proxy=self.proxy if self.proxy else None,
109+
) as client, client.stream(
110+
"POST",
111+
self.API_URL,
112+
headers=headers,
113+
json=body,
114+
) as response:
115+
if response.status_code == 200 and response.headers.get(
116+
"content-type", ""
117+
).lower().startswith("audio/"):
118+
with open(path, "wb") as f:
119+
async for chunk in response.aiter_bytes():
120+
f.write(chunk)
121+
return path
122+
123+
error_bytes = await response.aread()
124+
error_text = error_bytes.decode("utf-8", errors="replace")[:1024]
125+
try:
126+
error_detail = json.loads(error_text).get("detail", error_text)
127+
except (json.JSONDecodeError, AttributeError):
128+
error_detail = error_text
129+
raise RuntimeError(
130+
f"Typecast API request failed: status {response.status_code}, "
131+
f"response: {error_detail}"
132+
)

tests/test_typecast_tts_source.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
import os
2+
from pathlib import Path
3+
from unittest.mock import AsyncMock, MagicMock
4+
5+
import pytest
6+
7+
from astrbot.core.provider.sources.typecast_tts_source import ProviderTypecastTTS
8+
9+
10+
def _make_provider(**overrides) -> ProviderTypecastTTS:
11+
config = {
12+
"id": "test-typecast",
13+
"type": "typecast_tts",
14+
"api_key": "test-api-key",
15+
"typecast-voice-id": "tc_60e5426de8b95f1d3000d7b5",
16+
"model": "ssfm-v30",
17+
"language": "kor",
18+
"typecast-emotion-preset": "normal",
19+
"typecast-emotion-intensity": 1.0,
20+
"typecast-volume": 100,
21+
"typecast-pitch": 0,
22+
"typecast-tempo": 1.0,
23+
"timeout": 30,
24+
}
25+
config.update(overrides)
26+
return ProviderTypecastTTS(provider_config=config, provider_settings={})
27+
28+
29+
@pytest.mark.asyncio
30+
async def test_get_audio_success(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
31+
"""Successful API call saves WAV and returns path."""
32+
provider = _make_provider()
33+
34+
monkeypatch.setattr(
35+
"astrbot.core.provider.sources.typecast_tts_source.get_astrbot_temp_path",
36+
lambda: str(tmp_path),
37+
)
38+
39+
fake_response = AsyncMock()
40+
fake_response.status_code = 200
41+
fake_response.headers = {"content-type": "audio/wav"}
42+
fake_response.aiter_bytes = lambda: _async_iter([b"RIFF", b"fake_wav_data"])
43+
44+
fake_client = AsyncMock()
45+
fake_client.__aenter__ = AsyncMock(return_value=fake_client)
46+
fake_client.__aexit__ = AsyncMock(return_value=False)
47+
fake_client.stream = MagicMock(return_value=_async_context_manager(fake_response))
48+
49+
monkeypatch.setattr(
50+
"astrbot.core.provider.sources.typecast_tts_source.AsyncClient",
51+
lambda **kwargs: fake_client,
52+
)
53+
54+
path = await provider.get_audio("Hello world")
55+
56+
assert path.endswith(".wav")
57+
assert os.path.exists(path)
58+
with open(path, "rb") as f:
59+
assert f.read() == b"RIFFfake_wav_data"
60+
61+
62+
@pytest.mark.asyncio
63+
async def test_get_audio_api_error(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
64+
"""API error raises RuntimeError with detail message."""
65+
provider = _make_provider()
66+
67+
monkeypatch.setattr(
68+
"astrbot.core.provider.sources.typecast_tts_source.get_astrbot_temp_path",
69+
lambda: str(tmp_path),
70+
)
71+
72+
fake_response = AsyncMock()
73+
fake_response.status_code = 401
74+
fake_response.headers = {"content-type": "application/json"}
75+
fake_response.aread = AsyncMock(return_value=b'{"detail": "Invalid API key"}')
76+
77+
fake_client = AsyncMock()
78+
fake_client.__aenter__ = AsyncMock(return_value=fake_client)
79+
fake_client.__aexit__ = AsyncMock(return_value=False)
80+
fake_client.stream = MagicMock(return_value=_async_context_manager(fake_response))
81+
82+
monkeypatch.setattr(
83+
"astrbot.core.provider.sources.typecast_tts_source.AsyncClient",
84+
lambda **kwargs: fake_client,
85+
)
86+
87+
with pytest.raises(RuntimeError, match="Invalid API key"):
88+
await provider.get_audio("Hello world")
89+
90+
91+
@pytest.mark.asyncio
92+
async def test_get_audio_request_body(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
93+
"""Verify the request body sent to Typecast API."""
94+
provider = _make_provider(
95+
**{
96+
"typecast-emotion-preset": "happy",
97+
"typecast-emotion-intensity": 1.5,
98+
"typecast-pitch": 3,
99+
"typecast-tempo": 1.2,
100+
}
101+
)
102+
103+
monkeypatch.setattr(
104+
"astrbot.core.provider.sources.typecast_tts_source.get_astrbot_temp_path",
105+
lambda: str(tmp_path),
106+
)
107+
108+
captured_kwargs = {}
109+
110+
fake_response = AsyncMock()
111+
fake_response.status_code = 200
112+
fake_response.headers = {"content-type": "audio/wav"}
113+
fake_response.aiter_bytes = lambda: _async_iter([b"fake_wav"])
114+
115+
fake_client = AsyncMock()
116+
fake_client.__aenter__ = AsyncMock(return_value=fake_client)
117+
fake_client.__aexit__ = AsyncMock(return_value=False)
118+
119+
def capture_stream(method, url, **kwargs):
120+
captured_kwargs.update(kwargs)
121+
return _async_context_manager(fake_response)
122+
123+
fake_client.stream = capture_stream
124+
125+
monkeypatch.setattr(
126+
"astrbot.core.provider.sources.typecast_tts_source.AsyncClient",
127+
lambda **kwargs: fake_client,
128+
)
129+
130+
await provider.get_audio("Test text")
131+
132+
body = captured_kwargs["json"]
133+
assert body["voice_id"] == "tc_60e5426de8b95f1d3000d7b5"
134+
assert body["text"] == "Test text"
135+
assert body["model"] == "ssfm-v30"
136+
assert body["language"] == "kor"
137+
assert body["prompt"]["emotion_type"] == "preset"
138+
assert body["prompt"]["emotion_preset"] == "happy"
139+
assert body["prompt"]["emotion_intensity"] == 1.5
140+
assert body["output"]["audio_pitch"] == 3
141+
assert body["output"]["audio_tempo"] == 1.2
142+
assert body["output"]["audio_format"] == "wav"
143+
assert body["output"]["volume"] == 100
144+
145+
146+
def test_provider_config_defaults():
147+
"""Default config values are applied correctly."""
148+
provider = ProviderTypecastTTS(
149+
provider_config={
150+
"id": "test-typecast",
151+
"type": "typecast_tts",
152+
"api_key": "test-api-key",
153+
"typecast-voice-id": "tc_60e5426de8b95f1d3000d7b5",
154+
},
155+
provider_settings={},
156+
)
157+
assert provider.voice_id == "tc_60e5426de8b95f1d3000d7b5"
158+
assert provider.model_name == "ssfm-v30"
159+
assert provider.language == "kor"
160+
assert provider.emotion_preset == "normal"
161+
assert provider.emotion_intensity == 1.0
162+
assert provider.volume == 100
163+
assert provider.pitch == 0
164+
assert provider.tempo == 1.0
165+
assert provider.timeout == 30
166+
167+
168+
def test_provider_config_missing_api_key():
169+
"""Missing api_key raises ValueError."""
170+
with pytest.raises(ValueError, match="api_key is required"):
171+
ProviderTypecastTTS(
172+
provider_config={
173+
"id": "test",
174+
"type": "typecast_tts",
175+
"typecast-voice-id": "tc_123",
176+
},
177+
provider_settings={},
178+
)
179+
180+
181+
def test_provider_config_missing_voice_id():
182+
"""Missing voice_id raises ValueError."""
183+
with pytest.raises(ValueError, match="typecast-voice-id is required"):
184+
ProviderTypecastTTS(
185+
provider_config={
186+
"id": "test",
187+
"type": "typecast_tts",
188+
"api_key": "test-key",
189+
},
190+
provider_settings={},
191+
)
192+
193+
194+
@pytest.mark.asyncio
195+
async def test_get_audio_empty_text():
196+
"""Empty text raises ValueError."""
197+
provider = _make_provider()
198+
with pytest.raises(ValueError, match="text must not be empty"):
199+
await provider.get_audio("")
200+
201+
202+
@pytest.mark.asyncio
203+
async def test_get_audio_text_too_long():
204+
"""Text exceeding 2000 chars raises ValueError."""
205+
provider = _make_provider()
206+
with pytest.raises(ValueError, match="exceeds maximum of 2000 characters"):
207+
await provider.get_audio("a" * 2001)
208+
209+
210+
def test_provider_config_invalid_numbers_use_defaults():
211+
"""Invalid numeric config values fall back to defaults."""
212+
provider = ProviderTypecastTTS(
213+
provider_config={
214+
"id": "test-typecast",
215+
"type": "typecast_tts",
216+
"api_key": "test-api-key",
217+
"typecast-voice-id": "tc_60e5426de8b95f1d3000d7b5",
218+
"typecast-emotion-intensity": "not-a-number",
219+
"typecast-volume": "not-a-number",
220+
"typecast-pitch": "not-a-number",
221+
"typecast-tempo": "not-a-number",
222+
"timeout": "not-a-number",
223+
},
224+
provider_settings={},
225+
)
226+
assert provider.emotion_intensity == 1.0
227+
assert provider.volume == 100
228+
assert provider.pitch == 0
229+
assert provider.tempo == 1.0
230+
assert provider.timeout == 30
231+
232+
233+
# --- Test helpers ---
234+
235+
async def _async_iter(items):
236+
for item in items:
237+
yield item
238+
239+
240+
class _async_context_manager:
241+
def __init__(self, response):
242+
self.response = response
243+
244+
async def __aenter__(self):
245+
return self.response
246+
247+
async def __aexit__(self, *args):
248+
pass

0 commit comments

Comments
 (0)