22
33import asyncio
44import base64
5- import io
6- import warnings
75from typing import Any
86
9- import soundfile as sf
7+ import aiohttp
108
119from helpers import plugins
12- from helpers .notification import (
13- NotificationManager ,
14- NotificationPriority ,
15- NotificationType ,
16- )
1710from helpers .print_style import PrintStyle
1811from plugins ._kokoro_tts .helpers import migration
1912
2013
21- warnings .filterwarnings ("ignore" , category = FutureWarning )
22- warnings .filterwarnings ("ignore" , category = UserWarning )
23-
24-
2514PLUGIN_NAME = "_kokoro_tts"
2615DEFAULT_CONFIG = {
27- "voice" : "am_puck, am_onyx" ,
16+ "voice" : "am_onyx+am_echo " ,
2817 "speed" : 1.1 ,
18+ "remote_url" : "http://ares.moon-dragon.us:18890" ,
19+ "response_format" : "mp3" ,
2920}
3021
31- _pipeline = None
32- is_updating_model = False
22+ VALID_FORMATS = {"wav" , "mp3" , "opus" , "flac" }
23+ MIME_TYPES = {
24+ "wav" : "audio/wav" ,
25+ "mp3" : "audio/mpeg" ,
26+ "opus" : "audio/opus" ,
27+ "flac" : "audio/flac" ,
28+ }
29+
30+ _remote_healthy : bool | None = None
3331
3432
3533def normalize_config (config : dict [str , Any ] | None ) -> dict [str , Any ]:
@@ -48,6 +46,14 @@ def normalize_config(config: dict[str, Any] | None) -> dict[str, Any]:
4846 except (TypeError , ValueError ):
4947 pass
5048
49+ remote_url = str (config .get ("remote_url" , normalized ["remote_url" ]) or "" ).strip ()
50+ if remote_url :
51+ normalized ["remote_url" ] = remote_url .rstrip ("/" )
52+
53+ response_format = str (config .get ("response_format" , normalized ["response_format" ]) or "" ).strip ().lower ()
54+ if response_format in VALID_FORMATS :
55+ normalized ["response_format" ] = response_format
56+
5157 return normalized
5258
5359
@@ -68,79 +74,77 @@ async def preload(config: dict[str, Any] | None = None):
6874
6975
async def _preload():
    """Probe the remote Kokoro TTS API and cache whether it is healthy.

    Issues ``GET {remote_url}/health`` with a 5 second total timeout and
    stores the outcome in the module-level ``_remote_healthy`` flag:
    ``True`` for an HTTP 200 response, ``False`` for any other status or
    for any failure while performing the check. Never raises; failures are
    reported through ``PrintStyle.error``.
    """
    global _remote_healthy
    try:
        # Normalize the same way synthesize_sentences() does so both code
        # paths build their URLs from an identically cleaned remote_url.
        cfg = normalize_config(get_config())
        remote_url = cfg.get("remote_url", DEFAULT_CONFIG["remote_url"])
        probe_timeout = aiohttp.ClientTimeout(total=5)
        async with aiohttp.ClientSession() as session, session.get(
            f"{remote_url}/health",
            timeout=probe_timeout,
        ) as resp:
            _remote_healthy = resp.status == 200
            if _remote_healthy:
                PrintStyle.standard("Kokoro TTS remote API is healthy.")
            else:
                PrintStyle.error(
                    f"Kokoro TTS remote API unhealthy: status {resp.status}"
                )
    except Exception as e:
        # Deliberately broad: this is a best-effort probe, and any failure
        # (config, DNS, connection, timeout) just means "not healthy".
        _remote_healthy = False
        PrintStyle.error(f"Kokoro TTS remote API check failed: {e}")
9994
10095
async def is_downloading() -> bool:
    """Report whether a model download is in progress.

    Always returns ``False``: this function performs no work and tracks no
    download state.
    """
    return False
10398
10499
async def is_downloaded() -> bool:
    """Report whether the remote TTS API is known to be healthy.

    If no health probe has run yet (``_remote_healthy`` is ``None``), runs
    ``_preload()`` once to populate the cached flag. A cached ``False`` is
    returned as-is and is not re-probed here.

    Returns:
        ``True`` only when the cached health flag is exactly ``True``.
    """
    if _remote_healthy is None:
        await _preload()
    return _remote_healthy is True
107104
108105
async def synthesize_sentences(
    sentences: list[str], config: dict[str, Any] | None = None
) -> tuple[str, str]:
    """Synthesize speech for *sentences* using the plugin configuration.

    Args:
        sentences: Text fragments to speak.
        config: Optional settings override. When falsy (``None`` or an
            empty dict), the result of ``get_config()`` is used instead.
            Either way the settings pass through ``normalize_config``.

    Returns:
        The ``(base64_audio, mime_type)`` pair produced by
        ``_synthesize_sentences``.
    """
    cfg = normalize_config(config or get_config())
    return await _synthesize_sentences(
        sentences,
        voice=str(cfg["voice"]),
        speed=float(cfg["speed"]),
        remote_url=str(cfg["remote_url"]),
        response_format=str(cfg["response_format"]),
    )
118117
119118
async def _synthesize_sentences(
    sentences: list[str],
    *,
    voice: str,
    speed: float,
    remote_url: str,
    response_format: str,
) -> tuple[str, str]:
    """Request synthesized audio for *sentences* from the remote API.

    The sentences are stripped, empty ones are dropped, and the rest are
    joined with single spaces into one input string that is POSTed as JSON
    to ``{remote_url}/v1/audio/speech`` with a 30 second total timeout.

    Args:
        sentences: Text fragments to speak.
        voice: Voice identifier sent as the ``voice`` field.
        speed: Speed value sent as the ``speed`` field.
        remote_url: Base URL of the remote API, without a trailing slash.
        response_format: Audio format sent as ``response_format``; also
            selects the returned MIME type.

    Returns:
        ``(base64_audio, mime_type)``. ``base64_audio`` is the response
        body base64-encoded as a UTF-8 string, or ``""`` when there is no
        non-blank text to speak (no request is made in that case).
        ``mime_type`` comes from ``MIME_TYPES`` and falls back to
        ``"audio/mpeg"`` for an unknown format.

    Raises:
        Exception: Any error from the HTTP request, including a non-2xx
            status via ``raise_for_status()``, is logged and re-raised.
    """
    mime_type = MIME_TYPES.get(response_format, "audio/mpeg")

    # Strip each sentence once and discard the ones that end up empty.
    text = " ".join(filter(None, (s.strip() for s in sentences)))
    if not text:
        return "", mime_type

    payload = {
        "model": "kokoro",
        "input": text,
        "voice": voice,
        "response_format": response_format,
        "speed": speed,
    }

    try:
        async with aiohttp.ClientSession() as session, session.post(
            f"{remote_url}/v1/audio/speech",
            json=payload,
            timeout=aiohttp.ClientTimeout(total=30),
        ) as resp:
            resp.raise_for_status()
            audio_bytes = await resp.read()
    except Exception as e:
        PrintStyle.error(f"Error in remote Kokoro TTS synthesis: {e}")
        raise

    return base64.b64encode(audio_bytes).decode("utf-8"), mime_type
0 commit comments