33os .environ ["no_proxy" ] = "localhost, 127.0.0.1, 0.0.0.0"
44import re
55import subprocess
6+ import time
7+ import wave
68from pathlib import Path
79
810import httpx
911import ormsgpack
1012import psutil
11- from PyQt6 .QtCore import QMutex , QMutexLocker , QThread , QWaitCondition , pyqtSignal
13+ import pyaudio
14+ from PyQt6 .QtCore import QMutex , QMutexLocker , QThread , pyqtSignal
1215
16+ from fish .modules .log import logger
1317from fish .utils .audio import ServeReferenceAudio , ServeTTSRequest
1418from fish .utils .i18n import _t
1519
@@ -106,8 +110,26 @@ def terminate_process(self):
106110 self .process = None
107111
108112
109- class TTSWorker (QThread ):
110- finished = pyqtSignal ()
113+ class TimeWorker (QThread ):
114+ time_signal = pyqtSignal (float )
115+
116+ def __init__ (self , pause_time = 0.1 , parent = None ):
117+ super ().__init__ (parent )
118+ self .start_time = time .time ()
119+ self ._stop_requested = False
120+ self .pause_time = pause_time
121+
122+ def run (self ):
123+ while not self ._stop_requested :
124+ time .sleep (self .pause_time )
125+ self .time_signal .emit (time .time () - self .start_time )
126+
127+ def stop (self ):
128+ self ._stop_requested = True
129+
130+
131+ class TTSWorker (BaseWorker ):
132+ packet_delay = pyqtSignal (float )
111133
112134 def __init__ (
113135 self ,
@@ -120,26 +142,91 @@ def __init__(
120142 ** kwargs ,
121143 ):
122144 super ().__init__ ()
123- self .mutex = QMutex ()
124- self .wait_condition = QWaitCondition ()
125- self ._stop_requested = False
145+
126146 self .ref_files = ref_files
127147 self .ref_id = ref_id if len (ref_id ) > 0 else None
128148 self .backend = backend
129149 self .text = text
130150 self .api_key = api_key
131151 self .audio_path = audio_path
132152 self .kwargs = kwargs
153+ self .time_worker = TimeWorker (pause_time = 0.1 )
154+ self .time_worker .time_signal .connect (self .calc_elapsed )
155+ self .elapsed = 0
156+
157+ def calc_elapsed (self , elapsed ):
158+ self .elapsed = elapsed
159+ self .packet_delay .emit (elapsed )
133160
134161 def run (self ):
135- pre_files = [f for f in self .ref_files if not f .endswith (".lab" )]
136- audio_files = [
162+ self ._process_audio_stream ()
163+
164+ def _process_audio_stream (self ):
165+ pre_files = self .get_pre_files ()
166+ audio_files = self .filter_audio_files (pre_files )
167+ streaming = self .kwargs .get ("stream" , False )
168+ request = self .create_tts_request (audio_files , streaming )
169+ frames_per_buffer = 16384
170+ first_packet_time = None
171+
172+ self .time_worker .start ()
173+
174+ if streaming :
175+ p , stream = self .initialize_audio_stream (frames_per_buffer )
176+ self .p = p
177+ self .stream = stream
178+ f = wave .open (self .audio_path , "wb" )
179+ f .setnchannels (1 )
180+ f .setsampwidth (2 )
181+ f .setframerate (44100 )
182+ else :
183+ f = open (self .audio_path , "wb" )
184+
185+ self .f = f
186+ with httpx .Client () as client :
187+ with client .stream (
188+ "POST" ,
189+ self .backend ,
190+ content = ormsgpack .packb (
191+ request , option = ormsgpack .OPT_SERIALIZE_PYDANTIC
192+ ),
193+ headers = {
194+ "authorization" : f"Bearer { self .api_key } " ,
195+ "content-type" : "application/msgpack" ,
196+ },
197+ timeout = None ,
198+ ) as response :
199+ for chunk in response .iter_bytes (chunk_size = frames_per_buffer ):
200+ if first_packet_time is None :
201+ first_packet_time = self .elapsed
202+ self .time_worker .stop ()
203+
204+ if self .is_interrupted :
205+ return
206+
207+ if streaming :
208+ stream .write (chunk )
209+ f .writeframesraw (chunk )
210+ else :
211+ f .write (chunk )
212+
213+ self .finish ()
214+
215+ if not self .is_interrupted :
216+ self .finished_signal .emit (self .audio_path )
217+
218+ def get_pre_files (self ):
219+ return [f for f in self .ref_files if not f .endswith (".lab" )]
220+
221+ def filter_audio_files (self , pre_files : list ):
222+ return [
137223 f
138224 for f in pre_files
139225 if Path (f ).exists () and Path (f ).with_suffix (".lab" ).exists ()
140226 ]
141227
142- request = ServeTTSRequest (
228+ def create_tts_request (self , audio_files : list , streaming : bool ):
229+ return ServeTTSRequest (
143230 text = self .text ,
144231 references = [
145232 ServeReferenceAudio (
@@ -149,42 +236,40 @@ def run(self):
149236 for f in audio_files
150237 ],
151238 reference_id = self .ref_id ,
152- streaming = False ,
153- format = "mp3" ,
154- chunk_length = self .kwargs [ "chunk_length" ] ,
155- top_p = self .kwargs [ "top_p" ] ,
156- repetition_penalty = self .kwargs [ "repetition_penalty" ] ,
157- max_new_tokens = self .kwargs [ "max_new_tokens" ] ,
158- temperature = self .kwargs [ "temperature" ] ,
159- mp3_bitrate = self .kwargs [ "mp3_bitrate" ] ,
239+ streaming = streaming ,
240+ format = "wav" if streaming else " mp3" ,
241+ chunk_length = self .kwargs . get ( "chunk_length" ) ,
242+ top_p = self .kwargs . get ( "top_p" ) ,
243+ repetition_penalty = self .kwargs . get ( "repetition_penalty" ) ,
244+ max_new_tokens = self .kwargs . get ( "max_new_tokens" ) ,
245+ temperature = self .kwargs . get ( "temperature" ) ,
246+ mp3_bitrate = self .kwargs . get ( "mp3_bitrate" ) ,
160247 )
161248
162- with httpx .Client () as client , open (f"{ self .audio_path } " , "wb" ) as f :
163- with client .stream (
164- "POST" ,
165- self .backend ,
166- content = ormsgpack .packb (
167- request , option = ormsgpack .OPT_SERIALIZE_PYDANTIC
168- ),
169- headers = {
170- "authorization" : f"Bearer { self .api_key } " ,
171- "content-type" : "application/msgpack" ,
172- },
173- timeout = None ,
174- ) as response :
175- for chunk in response .iter_bytes ():
176- self .mutex .lock ()
177- if self ._stop_requested :
178- print ("TTS is interrupted!" )
179- self .mutex .unlock ()
180- break
181- self .mutex .unlock ()
182- f .write (chunk )
183-
184- self .finished .emit ()
249+ def initialize_audio_stream (self , frames_per_buffer : int ):
250+ p = pyaudio .PyAudio ()
251+ stream = p .open (
252+ format = pyaudio .paInt16 ,
253+ channels = 1 ,
254+ rate = 44100 ,
255+ output = True ,
256+ frames_per_buffer = frames_per_buffer ,
257+ )
258+ return p , stream
185259
186260 def stop (self ):
187- self .mutex .lock ()
188- self ._stop_requested = True
189- self .mutex .unlock ()
190- self .wait_condition .wakeAll ()
261+ self .is_interrupted = True
262+ logger .info ("Stop requested!" )
263+ self .finish ()
264+
265+ def finish (self ):
266+ streaming = self .kwargs .get ("stream" , False )
267+ if streaming :
268+ self .stream .stop_stream ()
269+ self .stream .close ()
270+ self .p .terminate ()
271+ logger .warning ("Stop streaming!" )
272+ self .time_worker .stop ()
273+ logger .info ("Timer off!" )
274+ self .f .close ()
275+ logger .info ("File closed!" )
0 commit comments