diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 246914ee0c..49def4a577 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -1623,10 +1623,14 @@ class ChatProviderTemplate(TypedDict):
                         "type": "gsvi_tts_api",
                         "provider": "gpt_sovits_inference",
                         "provider_type": "text_to_speech",
-                        "api_base": "http://127.0.0.1:5000",
-                        "character": "",
-                        "emotion": "default",
                         "enable": False,
+                        "api_key": "",
+                        "api_base": "http://127.0.0.1:8000",
+                        "version": "v4",
+                        "character": "",
+                        "prompt_text_lang": "中文",
+                        "emotion": "默认",
+                        "text_lang": "中文",
                         "timeout": 20,
                     },
                     "FishAudio TTS(API)": {
diff --git a/astrbot/core/provider/sources/gsvi_tts_source.py b/astrbot/core/provider/sources/gsvi_tts_source.py
index 425e801f46..55a0975de6 100644
--- a/astrbot/core/provider/sources/gsvi_tts_source.py
+++ b/astrbot/core/provider/sources/gsvi_tts_source.py
@@ -1,6 +1,5 @@
-import os
-import urllib.parse
 import uuid
+from pathlib import Path
 
 import aiohttp
 
@@ -23,37 +22,75 @@ def __init__(
         provider_settings: dict,
     ) -> None:
         super().__init__(provider_config, provider_settings)
-        self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000")
+        self.api_key = provider_config.get("api_key", "")
+        self.api_base = provider_config.get("api_base", "http://127.0.0.1:8000")
         self.api_base = self.api_base.removesuffix("/")
+        self.version = provider_config.get("version", "v4")
         self.character = provider_config.get("character")
-        self.emotion = provider_config.get("emotion")
+        self.prompt_text_lang = provider_config.get("prompt_text_lang", "中文")
+        self.emotion = provider_config.get("emotion", "默认")
+        self.text_lang = provider_config.get("text_lang", "中文")
 
     async def get_audio(self, text: str) -> str:
+        """Synthesize ``text`` via the GSVI ``/infer_single`` endpoint.
+
+        The synthesis request returns JSON carrying an ``audio_url``; the
+        audio behind that URL is downloaded and written to a ``.wav`` file
+        under the directory returned by ``get_astrbot_temp_path()``.
+
+        Args:
+            text: Text to synthesize.
+
+        Returns:
+            Path of the written audio file, as a string.
+
+        Raises:
+            Exception: If either HTTP request returns a non-200 status, the
+                server does not report success, or the response carries no
+                ``audio_url``.
+        """
         temp_dir = get_astrbot_temp_path()
-        path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
-        params = {"text": text}
+        path = Path(temp_dir) / f"gsvi_tts_{uuid.uuid4()}.wav"
+        url = f"{self.api_base}/infer_single"
 
-        if self.character:
-            params["character"] = self.character
-        if self.emotion:
-            params["emotion"] = self.emotion
+        headers = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
 
-        query_parts = []
-        for key, value in params.items():
-            encoded_value = urllib.parse.quote(str(value))
-            query_parts.append(f"{key}={encoded_value}")
-
-        url = f"{self.api_base}/tts?{'&'.join(query_parts)}"
+        data = {
+            "dl_url": self.api_base,
+            "version": self.version,
+            "model_name": self.character,
+            "prompt_text_lang": self.prompt_text_lang,
+            "emotion": self.emotion,
+            "text": text,
+            "text_lang": self.text_lang,
+        }
 
         async with aiohttp.ClientSession() as session:
-            async with session.get(url) as response:
-                if response.status == 200:
-                    with open(path, "wb") as f:
-                        f.write(await response.read())
-                else:
-                    error_text = await response.text()
-                    raise Exception(
-                        f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}",
-                    )
+            async with session.post(url, json=data, headers=headers) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    raise Exception(
+                        f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}",
+                    )
+                resp_json = await response.json()
+
+            # The POST response is fully read, so it is released before the
+            # download; validate the payload before trusting ``audio_url``.
+            msg = resp_json.get("msg")
+            if msg != "合成成功":
+                raise Exception(f"GSVI TTS API 合成失败: {msg}")
+            audio_url = resp_json.get("audio_url")
+            if not audio_url:
+                raise Exception("GSVI TTS API 合成失败: 响应中缺少 audio_url")
+
+            async with session.get(audio_url) as audio_response:
+                if audio_response.status != 200:
+                    error_text = await audio_response.text()
+                    raise Exception(
+                        f"GSVI TTS API 下载音频失败,状态码: {audio_response.status},错误: {error_text}",
+                    )
+                path.write_bytes(await audio_response.read())
 
-        return path
+        return str(path)