diff --git a/webui.py b/webui.py index 24b795136..bcbe30744 100644 --- a/webui.py +++ b/webui.py @@ -20,9 +20,10 @@ import torchaudio import random import librosa +import soundfile as sf ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR)) -from cosyvoice.cli.cosyvoice import AutoModel +from cosyvoice.cli.cosyvoice import AutoModel, CosyVoice3 from cosyvoice.utils.file_utils import logging from cosyvoice.utils.common import set_all_random_seed @@ -33,6 +34,8 @@ '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'} stream_mode_list = [('否', False), ('是', True)] max_val = 0.8 +# CosyVoice3's LLM requires an <|endofprompt|> token in prompt_text / tts_text; earlier generations do not. +cosyvoice3_system_prompt = 'You are a helpful assistant.<|endofprompt|>' def generate_seed(): @@ -75,8 +78,8 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro if prompt_wav is None: gr.Warning('prompt音频为空,您是否忘记输入prompt音频?') yield (cosyvoice.sample_rate, default_data) - if torchaudio.info(prompt_wav).sample_rate < prompt_sr: - gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr)) + if sf.info(prompt_wav).samplerate < prompt_sr: + gr.Warning('prompt音频采样率{}低于{}'.format(sf.info(prompt_wav).samplerate, prompt_sr)) yield (cosyvoice.sample_rate, default_data) # sft mode only use sft_dropdown if mode_checkbox_group in ['预训练音色']: @@ -101,12 +104,14 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro elif mode_checkbox_group == '3s极速复刻': logging.info('get zero_shot inference request') set_all_random_seed(seed) - for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_wav, stream=stream, speed=speed): + zero_shot_prompt_text = cosyvoice3_system_prompt + prompt_text if isinstance(cosyvoice, CosyVoice3) else prompt_text + for i in cosyvoice.inference_zero_shot(tts_text, zero_shot_prompt_text, prompt_wav, stream=stream, speed=speed): yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) elif mode_checkbox_group == '跨语种复刻': logging.info('get cross_lingual inference request') set_all_random_seed(seed) - for i in cosyvoice.inference_cross_lingual(tts_text, prompt_wav, stream=stream, speed=speed): + cross_lingual_tts_text = cosyvoice3_system_prompt + tts_text if isinstance(cosyvoice, CosyVoice3) else tts_text + for i in cosyvoice.inference_cross_lingual(cross_lingual_tts_text, prompt_wav, stream=stream, speed=speed): yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) else: logging.info('get instruct inference request')