Skip to content

Commit 763508e

Browse files
committed
Ver 1.3.0: Streaming & Prettier ?
1 parent 893383f commit 763508e

7 files changed

Lines changed: 195 additions & 72 deletions

File tree

fish/gui.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from PyQt6.QtGui import QIcon, QPixmap
1111
from PyQt6.QtMultimedia import QAudioOutput, QMediaPlayer
1212
from PyQt6.QtWidgets import (
13+
QCheckBox,
1314
QComboBox,
1415
QFileDialog,
1516
QGridLayout,
@@ -40,8 +41,9 @@
4041
FAPTranscribeWidget,
4142
)
4243
from fish.input import TextEditorWidget
43-
from fish.modules.console import ConsoleStream, ConsoleWidget
44+
from fish.modules.console import ConsoleWidget
4445
from fish.modules.globals import STOP_BUTTON_QSS
46+
from fish.modules.log import stderr_stream, stdout_stream
4547
from fish.modules.registry import widget_registry
4648
from fish.modules.worker import TTSWorker
4749
from fish.utils.audio import get_devices
@@ -101,23 +103,14 @@ def __init__(self):
101103
self.main_layout.setAlignment(Qt.AlignmentFlag.AlignTop)
102104
self.setup_action_buttons(self.main_layout)
103105

104-
self.change_theme(self.theme_combo.currentIndex()) # initialize theme for 1st
106+
self.change_theme(self.theme_combo.currentIndex()) # initialize theme first
107+
108+
stdout_stream.new_message.connect(lambda msg: self.update_console(msg, "white"))
109+
stderr_stream.new_message.connect(lambda msg: self.update_console(msg, "red"))
105110

106111
# Use size hint to set a reasonable size
107112
self.setMinimumWidth(800)
108113

109-
# Redefined Stream
110-
self.stdout_stream = ConsoleStream()
111-
self.stderr_stream = ConsoleStream()
112-
self.stdout_stream.new_message.connect(
113-
lambda msg: self.update_console(msg, "white")
114-
)
115-
self.stderr_stream.new_message.connect(
116-
lambda msg: self.update_console(msg, "red")
117-
)
118-
sys.stdout = self.stdout_stream
119-
sys.stderr = self.stderr_stream
120-
121114
# Uploaded ref files
122115
self.files = []
123116

@@ -637,10 +630,14 @@ def setup_action_buttons(self, layout: QVBoxLayout):
637630
row_layout = QHBoxLayout()
638631
widget_registry.register(row, "action_widget")
639632

640-
self.now_audio = QLabel(_t("action.audio").format(audio_name="(null)"))
633+
self.now_audio = QLineEdit(_t("action.audio").format(audio_name="(null)"))
634+
self.now_audio.setMinimumWidth(200)
635+
641636
row_layout.addWidget(self.now_audio)
642-
row_layout.addStretch(1)
637+
# row_layout.addStretch(1)
643638

639+
self.stream = QCheckBox(_t("action.stream"))
640+
row_layout.addWidget(self.stream)
644641
self.start_button = QPushButton(_t("action.start"))
645642
self.start_button.clicked.connect(self.start_conversion)
646643
row_layout.addWidget(self.start_button)
@@ -674,11 +671,19 @@ def change_theme(self, index):
674671
}
675672
"""
676673
)
674+
self.now_audio.setStyleSheet(
675+
"""
676+
QLineEdit {
677+
border: none;
678+
color: black;
679+
}
680+
"""
681+
)
677682

678683
else:
679684
for widget in widget_registry.get_registered_widgets().values():
680685
widget.setStyleSheet("")
681-
686+
self.now_audio.setStyleSheet("")
682687
save_config()
683688
qdarktheme.setup_theme(config.theme)
684689

@@ -861,7 +866,7 @@ def open_file(self):
861866
)
862867
self.set_audio(file_name)
863868

864-
def set_audio(self, audio_file):
869+
def set_audio(self, audio_file: str):
865870
if Path(audio_file).exists():
866871
self.player.setSource(QUrl.fromLocalFile(audio_file))
867872
self.play_button.setText(_t("tts_output.play"))
@@ -916,7 +921,8 @@ def start_conversion(self):
916921
text = self.text_editor.input_edit.toPlainText()
917922

918923
audio_name = now.strftime("%Y%m%d_%H%M%S")
919-
audio_path = Path(self.save_audio_path.text()) / f"{audio_name}.mp3"
924+
wav_suffix = "wav" if self.stream.isChecked() else "mp3"
925+
audio_path = Path(self.save_audio_path.text()) / f"{audio_name}.{wav_suffix}"
920926
audio_path.parent.mkdir(parents=True, exist_ok=True)
921927
self.audio_path = str(audio_path)
922928
kwargs = dict(
@@ -926,6 +932,7 @@ def start_conversion(self):
926932
max_new_tokens=self.max_new_tokens_slider.value(),
927933
temperature=self.temperature_slider.value() / 1000.0,
928934
mp3_bitrate=int(self.mp3_bitrate_combo.currentText()),
935+
stream=self.stream.isChecked(),
929936
)
930937
self.tts_worker = TTSWorker(
931938
ref_files=self.files,
@@ -936,17 +943,22 @@ def start_conversion(self):
936943
audio_path=str(audio_path),
937944
**kwargs,
938945
)
939-
self.tts_worker.finished.connect(self.on_conversion_finished)
946+
self.tts_worker.finished_signal.connect(self.on_conversion_finished)
947+
self.tts_worker.packet_delay.connect(
948+
lambda t: self.latency_label.setText(
949+
_t("action.latency").format(latency=(t * 1000.0))
950+
)
951+
)
940952
self.tts_worker.start()
941953

942-
self.now_audio.setText(_t("action.audio").format(audio_name=str(audio_path)))
943-
944954
def stop_conversion(self):
945955
self.tts_worker.stop()
946-
self.tts_worker.wait()
956+
# self.tts_worker.wait()
947957
self.start_button.setEnabled(True)
948958
self.stop_button.setEnabled(False)
949959

950-
def on_conversion_finished(self):
951-
self.stop_conversion()
952-
self.set_audio(self.audio_path)()
960+
def on_conversion_finished(self, audio_path):
961+
self.now_audio.setText(_t("action.audio").format(audio_name=audio_path))
962+
self.set_audio(self.audio_path)
963+
self.start_button.setEnabled(True)
964+
self.stop_button.setEnabled(False)

fish/modules/log.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import logging
2+
import sys
3+
4+
from fish.modules.console import ConsoleStream
5+
6+
# Redefined Stream
7+
stdout_stream = ConsoleStream()
8+
stderr_stream = ConsoleStream()
9+
10+
sys.stdout = stdout_stream
11+
sys.stderr = stderr_stream
12+
13+
# Global logger
14+
logger = logging.getLogger()
15+
logger.setLevel(logging.INFO)
16+
stdout_handler = logging.StreamHandler(sys.stdout)
17+
stdout_handler.setLevel(logging.INFO)
18+
stderr_handler = logging.StreamHandler(sys.stdout)
19+
stderr_handler.setLevel(logging.WARNING)
20+
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
21+
stdout_handler.setFormatter(formatter)
22+
stderr_handler.setFormatter(formatter)
23+
logger.addHandler(stdout_handler)
24+
logger.addHandler(stderr_handler)

fish/modules/worker.py

Lines changed: 129 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
os.environ["no_proxy"] = "localhost, 127.0.0.1, 0.0.0.0"
44
import re
55
import subprocess
6+
import time
7+
import wave
68
from pathlib import Path
79

810
import httpx
911
import ormsgpack
1012
import psutil
11-
from PyQt6.QtCore import QMutex, QMutexLocker, QThread, QWaitCondition, pyqtSignal
13+
import pyaudio
14+
from PyQt6.QtCore import QMutex, QMutexLocker, QThread, pyqtSignal
1215

16+
from fish.modules.log import logger
1317
from fish.utils.audio import ServeReferenceAudio, ServeTTSRequest
1418
from fish.utils.i18n import _t
1519

@@ -106,8 +110,26 @@ def terminate_process(self):
106110
self.process = None
107111

108112

109-
class TTSWorker(QThread):
110-
finished = pyqtSignal()
113+
class TimeWorker(QThread):
114+
time_signal = pyqtSignal(float)
115+
116+
def __init__(self, pause_time=0.1, parent=None):
117+
super().__init__(parent)
118+
self.start_time = time.time()
119+
self._stop_requested = False
120+
self.pause_time = pause_time
121+
122+
def run(self):
123+
while not self._stop_requested:
124+
time.sleep(self.pause_time)
125+
self.time_signal.emit(time.time() - self.start_time)
126+
127+
def stop(self):
128+
self._stop_requested = True
129+
130+
131+
class TTSWorker(BaseWorker):
132+
packet_delay = pyqtSignal(float)
111133

112134
def __init__(
113135
self,
@@ -120,26 +142,91 @@ def __init__(
120142
**kwargs,
121143
):
122144
super().__init__()
123-
self.mutex = QMutex()
124-
self.wait_condition = QWaitCondition()
125-
self._stop_requested = False
145+
126146
self.ref_files = ref_files
127147
self.ref_id = ref_id if len(ref_id) > 0 else None
128148
self.backend = backend
129149
self.text = text
130150
self.api_key = api_key
131151
self.audio_path = audio_path
132152
self.kwargs = kwargs
153+
self.time_worker = TimeWorker(pause_time=0.1)
154+
self.time_worker.time_signal.connect(self.calc_elapsed)
155+
self.elapsed = 0
156+
157+
def calc_elapsed(self, elapsed):
158+
self.elapsed = elapsed
159+
self.packet_delay.emit(elapsed)
133160

134161
def run(self):
135-
pre_files = [f for f in self.ref_files if not f.endswith(".lab")]
136-
audio_files = [
162+
self._process_audio_stream()
163+
164+
def _process_audio_stream(self):
165+
pre_files = self.get_pre_files()
166+
audio_files = self.filter_audio_files(pre_files)
167+
streaming = self.kwargs.get("stream", False)
168+
request = self.create_tts_request(audio_files, streaming)
169+
frames_per_buffer = 16384
170+
first_packet_time = None
171+
172+
self.time_worker.start()
173+
174+
if streaming:
175+
p, stream = self.initialize_audio_stream(frames_per_buffer)
176+
self.p = p
177+
self.stream = stream
178+
f = wave.open(self.audio_path, "wb")
179+
f.setnchannels(1)
180+
f.setsampwidth(2)
181+
f.setframerate(44100)
182+
else:
183+
f = open(self.audio_path, "wb")
184+
185+
self.f = f
186+
with httpx.Client() as client:
187+
with client.stream(
188+
"POST",
189+
self.backend,
190+
content=ormsgpack.packb(
191+
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
192+
),
193+
headers={
194+
"authorization": f"Bearer {self.api_key}",
195+
"content-type": "application/msgpack",
196+
},
197+
timeout=None,
198+
) as response:
199+
for chunk in response.iter_bytes(chunk_size=frames_per_buffer):
200+
if first_packet_time is None:
201+
first_packet_time = self.elapsed
202+
self.time_worker.stop()
203+
204+
if self.is_interrupted:
205+
return
206+
207+
if streaming:
208+
stream.write(chunk)
209+
f.writeframesraw(chunk)
210+
else:
211+
f.write(chunk)
212+
213+
self.finish()
214+
215+
if not self.is_interrupted:
216+
self.finished_signal.emit(self.audio_path)
217+
218+
def get_pre_files(self):
219+
return [f for f in self.ref_files if not f.endswith(".lab")]
220+
221+
def filter_audio_files(self, pre_files: list):
222+
return [
137223
f
138224
for f in pre_files
139225
if Path(f).exists() and Path(f).with_suffix(".lab").exists()
140226
]
141227

142-
request = ServeTTSRequest(
228+
def create_tts_request(self, audio_files: list, streaming: bool):
229+
return ServeTTSRequest(
143230
text=self.text,
144231
references=[
145232
ServeReferenceAudio(
@@ -149,42 +236,40 @@ def run(self):
149236
for f in audio_files
150237
],
151238
reference_id=self.ref_id,
152-
streaming=False,
153-
format="mp3",
154-
chunk_length=self.kwargs["chunk_length"],
155-
top_p=self.kwargs["top_p"],
156-
repetition_penalty=self.kwargs["repetition_penalty"],
157-
max_new_tokens=self.kwargs["max_new_tokens"],
158-
temperature=self.kwargs["temperature"],
159-
mp3_bitrate=self.kwargs["mp3_bitrate"],
239+
streaming=streaming,
240+
format="wav" if streaming else "mp3",
241+
chunk_length=self.kwargs.get("chunk_length"),
242+
top_p=self.kwargs.get("top_p"),
243+
repetition_penalty=self.kwargs.get("repetition_penalty"),
244+
max_new_tokens=self.kwargs.get("max_new_tokens"),
245+
temperature=self.kwargs.get("temperature"),
246+
mp3_bitrate=self.kwargs.get("mp3_bitrate"),
160247
)
161248

162-
with httpx.Client() as client, open(f"{self.audio_path}", "wb") as f:
163-
with client.stream(
164-
"POST",
165-
self.backend,
166-
content=ormsgpack.packb(
167-
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
168-
),
169-
headers={
170-
"authorization": f"Bearer {self.api_key}",
171-
"content-type": "application/msgpack",
172-
},
173-
timeout=None,
174-
) as response:
175-
for chunk in response.iter_bytes():
176-
self.mutex.lock()
177-
if self._stop_requested:
178-
print("TTS is interrupted!")
179-
self.mutex.unlock()
180-
break
181-
self.mutex.unlock()
182-
f.write(chunk)
183-
184-
self.finished.emit()
249+
def initialize_audio_stream(self, frames_per_buffer: int):
250+
p = pyaudio.PyAudio()
251+
stream = p.open(
252+
format=pyaudio.paInt16,
253+
channels=1,
254+
rate=44100,
255+
output=True,
256+
frames_per_buffer=frames_per_buffer,
257+
)
258+
return p, stream
185259

186260
def stop(self):
187-
self.mutex.lock()
188-
self._stop_requested = True
189-
self.mutex.unlock()
190-
self.wait_condition.wakeAll()
261+
self.is_interrupted = True
262+
logger.info("Stop requested!")
263+
self.finish()
264+
265+
def finish(self):
266+
streaming = self.kwargs.get("stream", False)
267+
if streaming:
268+
self.stream.stop_stream()
269+
self.stream.close()
270+
self.p.terminate()
271+
logger.warning("Stop streaming!")
272+
self.time_worker.stop()
273+
logger.info("Timer off!")
274+
self.f.close()
275+
logger.info("File closed!")

locales/en_US.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ tts_output:
8282

8383
action:
8484
audio: "Now playing: {audio_name}"
85-
toggle_console: "Open/Close Console"
85+
stream: "Streaming"
8686
start: "Start Text To Speech"
8787
stop: "Stop Text To Speech"
8888
latency: "Latency: {latency:.2f} ms"

0 commit comments

Comments
 (0)