Skip to content

Commit 75af8a6

Browse files
committed
Ver 1.4.4: fix bug in stopping audio playback
1 parent ad6cb43 commit 75af8a6

5 files changed

Lines changed: 153 additions & 66 deletions

File tree

fish/chat.py

Lines changed: 100 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ def __init__(self):
485485
self.state = ChatState()
486486
self.thread_pool = QThreadPool.globalInstance()
487487
self.loop = asyncio.get_event_loop()
488+
self.async_msg_task = None
488489
self.initUI()
489490
self.init_messages()
490491

@@ -537,10 +538,15 @@ def initUI(self):
537538
self.clear_button.clicked.connect(self.clear_messages)
538539
self.clear_button.setStyleSheet(CLEAN_QSS)
539540

541+
self.stop_button = QPushButton(_t("ChatWidget.stop_btn"))
542+
self.stop_button.clicked.connect(self.stop_message_task)
543+
self.stop_button.setStyleSheet(CLEAN_QSS)
544+
540545
input_layout.addWidget(self.voice_mode_button)
541546
input_layout.addWidget(self.input_field)
542547
input_layout.addWidget(self.send_button)
543548
input_layout.addWidget(self.clear_button)
549+
input_layout.addWidget(self.stop_button)
544550
main_layout.addLayout(input_layout)
545551

546552
# Add the settings button as an overlay in the top-right corner
@@ -670,10 +676,7 @@ def add_message(
670676
)
671677
self.scroll_layout.addWidget(bubble)
672678
QApplication.processEvents()
673-
# Drag scroll bar to the bottom
674-
self.scroll_area.verticalScrollBar().setValue(
675-
self.scroll_area.verticalScrollBar().maximum()
676-
)
679+
677680
return bubble
678681

679682
def clear_messages(self):
@@ -715,10 +718,19 @@ def start_message_task(self, *, text: str = None, audio: str = None):
715718
message_worker.update_bubble_signal.connect(self.on_update_bubble)
716719
message_worker.update_duration_signal.connect(self.on_update_duration)
717720
message_worker.update_text_signal.connect(self.on_update_text)
718-
719721
# worker -> QRunnable -> QThreadPool
720-
async_task = AsyncTaskWorker(message_worker)
721-
self.thread_pool.start(async_task)
722+
self.async_msg_task = AsyncTaskWorker(message_worker)
723+
self.thread_pool.start(self.async_msg_task)
724+
pass
725+
726+
def stop_message_task(self):
727+
if self.async_msg_task:
728+
self.async_msg_task.cancel()
729+
self.async_msg_task = None
730+
731+
def on_message_task_finished(self, audio):
732+
self.audio_files.append(audio)
733+
logger.info("Message Task Complete")
722734
pass
723735

724736
def on_add_message(self, text, is_sender, is_voice, audio, duration):
@@ -746,21 +758,24 @@ def on_update_text(self, text):
746758
item: MessageBubble = self.scroll_layout.itemAt(num_bubbles - 1).widget()
747759
item.update_text(text)
748760

749-
def on_message_task_finished(self, audio):
750-
self.audio_files.append(audio)
751-
logger.info("Message Task Complete")
752-
pass
753-
754761
def update_bubble_size(self, mode: str = "all"):
755762
start_idx = self.scroll_layout.count() - 1 if mode == "last" else 0
756763
for i in range(start_idx, self.scroll_layout.count(), 1):
757764
item = self.scroll_layout.itemAt(i).widget()
758765
if isinstance(item, MessageBubble):
759766
item.msg.setFixedWidth(item.get_dynamic_width(self.width()))
767+
if mode == "last":
768+
# Drag scroll bar to the bottom
769+
self.scroll_area.verticalScrollBar().setValue(
770+
self.scroll_area.verticalScrollBar().maximum()
771+
)
760772
pass
761773

762774
def keyPressEvent(self, event):
763-
if event.key() in (Qt.Key.Key_Return, Qt.Key.Key_Enter):
775+
if event.modifiers() == Qt.Modifier.CTRL and event.key() == Qt.Key.Key_B:
776+
self.stop_message_task()
777+
event.accept()
778+
elif event.key() in (Qt.Key.Key_Return, Qt.Key.Key_Enter):
764779
if self.input_field.hasFocus():
765780
self.send_message_text()
766781
else:
@@ -820,8 +835,9 @@ def __init__(
820835
self.system_prompt = system_prompt
821836
self.system_audios = system_audios
822837
self.loop = loop
838+
self._task = None
823839

824-
async def send_message_async(self):
840+
async def send_message_async(self, cancel_event: asyncio.Event):
825841
text = self.input_text
826842
audio = self.input_audio
827843
agent = self.agent
@@ -875,32 +891,84 @@ async def send_message_async(self):
875891
self.update_bubble_signal.emit("last")
876892

877893
# Step 3: Generate audio and text segments in real-time
878-
async def infostream_generator():
894+
async def wave_generator(audio_data: bytes, cancel_event: asyncio.Event):
895+
chunk_size = 32768 # 32KB = 16K samples = 16384 / 44100 = 0.372 s
896+
offset = 0
897+
898+
while offset + chunk_size <= len(audio_data):
899+
# one method to stop async audioplayer is to cut off the wav-stream
900+
if cancel_event.is_set():
901+
break
902+
yield audio_data[offset : offset + chunk_size]
903+
offset += chunk_size
904+
905+
if cancel_event.is_set():
906+
yield b""
907+
elif offset < len(audio_data):
908+
yield audio_data[offset:]
909+
910+
async def infostream_generator(cancel_event: asyncio.Event):
879911
total_seg_time = 0.0
880-
yield wav_chunk_header()
881-
async for event in agent.stream(
882-
chat_ctx={"messages": self.state.conversation}
883-
):
884-
if event.type == FishE2EEventType.SPEECH_SEGMENT:
885-
self.state.append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
886-
total_seg_time += len(event.vq_codes[0]) / 21
887-
yield bytes(event.frame.data)
888-
self.update_duration_signal.emit(total_seg_time)
889-
elif event.type == FishE2EEventType.TEXT_SEGMENT:
890-
self.state.append_to_chat_ctx(ServeTextPart(text=event.text))
891-
self.update_text_signal.emit(
892-
self.state.repr_message(self.state.conversation[-1])
893-
)
894-
self.update_bubble_signal.emit("last")
912+
yield wav_chunk_header() # Initial header
913+
914+
try:
915+
async for event in agent.stream(
916+
chat_ctx={"messages": self.state.conversation}
917+
):
918+
if cancel_event.is_set():
919+
break
920+
921+
if event.type == FishE2EEventType.SPEECH_SEGMENT:
922+
self.state.append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
923+
total_seg_time += len(event.vq_codes[0]) / 21
924+
925+
audio_data = bytes(event.frame.data)
926+
async for chunk in wave_generator(audio_data, cancel_event):
927+
yield chunk
928+
929+
self.update_duration_signal.emit(total_seg_time)
930+
931+
elif event.type == FishE2EEventType.TEXT_SEGMENT:
932+
self.state.append_to_chat_ctx(ServeTextPart(text=event.text))
933+
self.update_text_signal.emit(
934+
self.state.repr_message(self.state.conversation[-1])
935+
)
936+
self.update_bubble_signal.emit("last")
937+
938+
except asyncio.CancelledError:
939+
logger.warning("Infostream generator was cancelled.")
940+
raise # Re-raise to assure interruption
895941

896942
# Step 4: Play audio (streaming)
943+
897944
audio_player = AudioPlayWorker(audio_path=temp_wavfile, streaming=True)
898-
await audio_player.run_async(infostream_generator())
945+
audio_player.set_chunks(infostream_generator(cancel_event))
946+
await audio_player.run_async()
899947
self.finished.emit(temp_wavfile)
900948

901949
def run(self):
902950
# Run asynchronous tasks in a new event loop using asyncio.run
903-
self.loop.run_until_complete(self.send_message_async())
951+
self.cancel_event = asyncio.Event()
952+
self._task = self.loop.create_task(self.send_message_async(self.cancel_event))
953+
self._task.add_done_callback(self.on_task_done)
954+
try:
955+
self.loop.run_until_complete(self._task)
956+
except asyncio.CancelledError:
957+
pass # Don't show redundant error
958+
959+
def cancel(self):
960+
if self._task:
961+
self.cancel_event.set()
962+
self._task.cancel()
963+
self._task = None
964+
965+
def on_task_done(self, task: asyncio.Task):
966+
if task.cancelled():
967+
logger.warning("Task was cancelled")
968+
elif task.exception():
969+
logger.error(f"Task encountered an exception: {task.exception()}")
970+
else:
971+
logger.info("Task completed successfully")
904972

905973

906974
if __name__ == "__main__":

fish/gui.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,7 @@ def start_conversion(self):
954954
**kwargs,
955955
)
956956
self.tts_worker.finished_signal.connect(self.on_conversion_finished)
957+
self.tts_worker.error_signal.connect(self.stop_conversion)
957958
self.tts_worker.packet_delay.connect(
958959
lambda t: self.latency_label.setText(
959960
_t("action.latency").format(latency=(t * 1000.0))

fish/modules/worker.py

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
import wave
66
from pathlib import Path
7-
from typing import Iterator, List
7+
from typing import AsyncIterator, Iterator, List
88

99
import numpy as np
1010
import ormsgpack
@@ -145,24 +145,22 @@ def __init__(
145145
self,
146146
audio_path: str,
147147
streaming: bool,
148-
iterable_chunks: Iterator[bytes] = None,
149148
frames_per_buffer: int = 16384,
150149
):
151150
super().__init__()
152151
self.audio_path = audio_path
153152
self.streaming = streaming
154-
self.iterable_chunks = iterable_chunks
155153
self.frames_per_buffer = frames_per_buffer
156-
self.is_interrupted = False
154+
self.iterable_chunks = None
157155

156+
self.is_interrupted = False
158157
self.elapsed = 0
159158
self.p = None
160159
self.stream = None
161160

162161
self.time_worker = TimeWorker(pause_time=0.1)
163162
self.time_worker.time_signal.connect(self.calc_elapsed)
164163

165-
# Sync Methods:
166164
def calc_elapsed(self, elapsed):
167165
self.elapsed = elapsed
168166
self.packet_delay.emit(elapsed)
@@ -190,6 +188,8 @@ def start_audio_streaming(self):
190188

191189
def audio_streaming(self):
192190
first_packet_time = None
191+
if not self.iterable_chunks:
192+
return
193193
for chunk in self.iterable_chunks:
194194
if self.is_interrupted:
195195
break
@@ -203,29 +203,11 @@ def audio_streaming(self):
203203
first_packet_time = self.elapsed
204204
self.time_worker.stop()
205205

206-
def stop_audio_streaming(self):
207-
if self.streaming and self.stream:
208-
self.stream.stop_stream()
209-
self.stream.close()
210-
self.p.terminate()
211-
self.f.close()
212-
213-
def run(self):
214-
self.start_audio_streaming()
215-
if self.iterable_chunks:
216-
self.audio_streaming()
217-
self.stop_audio_streaming()
218-
if not self.is_interrupted:
219-
self.finished_signal.emit(self.audio_path)
220-
221-
def stop(self):
222-
self.is_interrupted = True
223-
self.time_worker.stop()
224-
225-
# Async Methods:
226-
async def async_audio_streaming(self, async_chunks):
206+
async def async_audio_streaming(self):
227207
first_packet_time = None
228-
async for chunk in async_chunks:
208+
if not self.iterable_chunks:
209+
return
210+
async for chunk in self.iterable_chunks:
229211
if self.is_interrupted:
230212
break
231213
if self.streaming:
@@ -238,17 +220,44 @@ async def async_audio_streaming(self, async_chunks):
238220
first_packet_time = self.elapsed
239221
self.time_worker.stop()
240222

241-
async def run_async(self, async_chunks):
242-
self.time_worker.start()
223+
def stop_audio_streaming(self):
224+
if self.streaming and self.stream:
225+
self.stream.stop_stream()
226+
self.stream.close()
227+
self.p.terminate()
228+
self.f.close()
229+
logger.info("Playback Finished")
230+
231+
def set_chunks(self, chunks: Iterator[bytes] | AsyncIterator[bytes] = None):
232+
self.iterable_chunks = chunks
233+
234+
def run(self):
235+
logger.info("Sync Playback Started")
243236
self.start_audio_streaming()
244-
await self.async_audio_streaming(async_chunks)
237+
self.audio_streaming()
245238
self.stop_audio_streaming()
246239
if not self.is_interrupted:
240+
logger.info("Sync Playback Finished")
247241
self.finished_signal.emit(self.audio_path)
242+
243+
async def run_async(self):
244+
logger.info("Async Playback Started")
245+
self.start_audio_streaming()
246+
await self.async_audio_streaming()
247+
self.stop_audio_streaming()
248+
if not self.is_interrupted:
249+
logger.info("Async Playback Finished")
250+
self.finished_signal.emit(self.audio_path)
251+
252+
def stop(self):
253+
self.is_interrupted = True
248254
self.time_worker.stop()
255+
logger.info("Playback Stopped")
249256

250257

251258
class TTSWorker(AudioPlayWorker):
259+
error_signal = pyqtSignal()
260+
252261
def __init__(
253262
self,
254263
ref_files: List[str],
@@ -267,6 +276,7 @@ def __init__(
267276
self.text = text
268277
self.api_key = api_key
269278
self.kwargs = kwargs
279+
self.streaming = streaming
270280

271281
def get_pre_files(self):
272282
return [f for f in self.ref_files if not f.endswith(".lab")]
@@ -316,14 +326,15 @@ def run(self):
316326
},
317327
)
318328
response.raise_for_status()
319-
self.iterable_chunks = response.iter_content(
320-
chunk_size=self.frames_per_buffer
321-
)
329+
audio_chunks = response.iter_content(chunk_size=self.frames_per_buffer)
330+
self.set_chunks(audio_chunks)
322331
super().run()
323332
except requests.RequestException as e:
324333
logger.error(f"Network request failed: {e}")
334+
self.error_signal.emit()
325335
finally:
326336
self.stop() # Ensure the thread stops gracefully if there's an error
337+
response.close()
327338

328339

329340
class AudioRecordWorker(QThread):
@@ -413,3 +424,9 @@ def __init__(self, worker):
413424

414425
def run(self):
415426
self.worker.run()
427+
428+
def cancel(self):
429+
self.worker.cancel()
430+
431+
def stop(self):
432+
self.worker.stop()

locales/en_US.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,4 +262,5 @@ ChatWidget:
262262
send_btn: "Send"
263263
clear_confirm: "Are you sure to clear chat history?"
264264
clear_btn: "Clear"
265+
stop_btn: "Stop"
265266
recording: "Recording: {dur:.1f} s"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "fish-speech-gui"
3-
version = "1.4.3"
3+
version = "1.4.4"
44
description = "fish-speech comfortable GUI"
55
readme = "README.md"
66
requires-python = "<3.12,>=3.10"

0 commit comments

Comments
 (0)