Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions cosyvoice/cli/cosyvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
import torch
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.file_utils import logging, flow_decoder_estimator_onnx_model
from cosyvoice.utils.class_utils import get_model_type


class CosyVoice:

def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, trt_bucket=False):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
Expand Down Expand Up @@ -57,9 +57,11 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_co
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
flow_decoder_estimator_onnx_model(model_dir),
trt_concurrent,
self.fp16)
self.fp16,
trt_bucket=trt_bucket,
model_dir=model_dir)
del configs

def list_available_spks(self):
Expand Down Expand Up @@ -138,7 +140,8 @@ def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):

class CosyVoice2(CosyVoice):

def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1,
trt_bucket=False):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
Expand Down Expand Up @@ -169,9 +172,11 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, f
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
flow_decoder_estimator_onnx_model(model_dir),
trt_concurrent,
self.fp16)
self.fp16,
trt_bucket=trt_bucket,
model_dir=model_dir)
del configs

def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
Expand All @@ -188,7 +193,7 @@ def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk

class CosyVoice3(CosyVoice2):

def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1, trt_bucket=False):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
Expand Down Expand Up @@ -219,9 +224,11 @@ def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_c
if self.fp16 is True:
logging.warning('DiT tensorRT fp16 engine have some performance issue, use at caution!')
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
flow_decoder_estimator_onnx_model(model_dir),
trt_concurrent,
self.fp16)
self.fp16,
trt_bucket=trt_bucket,
model_dir=model_dir)
del configs


Expand Down
144 changes: 114 additions & 30 deletions cosyvoice/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm, flow_decoder_estimator_bucket_plan
from cosyvoice.utils.common import TrtContextWrapper, TrtBucketedContextWrapper


class CosyVoiceModel:

TRT_BUCKET_MAX_LENS = (256, 768, 1536, 3000)
TRT_BUCKET_MEM_MB = {256: 1800.0, 768: 2600.0, 1536: 3400.0, 3000: 4600.0}

def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
Expand Down Expand Up @@ -80,22 +83,48 @@ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder

def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16,
trt_bucket=False, model_dir=''):
# Replace self.flow.decoder.estimator with a TensorRT-backed wrapper.
#
# flow_decoder_estimator_model: path of the serialized single-engine plan file.
# flow_decoder_onnx_model: ONNX file handed to convert_onnx_to_trt when a plan must be built.
# trt_concurrent: forwarded unchanged to the context wrapper.
# fp16: forwarded unchanged to convert_onnx_to_trt.
# trt_bucket: when True, one engine per entry of TRT_BUCKET_MAX_LENS is loaded instead of a single engine.
# model_dir: directory used for the per-bucket plan files; must be non-empty when trt_bucket is True.
#
# Fails with AssertionError when CUDA is unavailable, when model_dir is empty in bucket mode,
# or when an engine cannot be deserialized.
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
# NOTE(review): the next two lines repeat the "build plan if missing or empty" check that also
# appears just before the single-engine load further down. In this flattened diff view they look
# like the pre-change position of that check - confirm which copy is live in the real file.
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
# The existing estimator is dropped before any engine is loaded.
del self.flow.decoder.estimator
# tensorrt is imported here rather than at module import time.
import tensorrt as trt
# Bucketed mode: one engine per maximum length in TRT_BUCKET_MAX_LENS.
if trt_bucket is True:
assert model_dir, 'model_dir is required when trt_bucket is True'
bucket_engines = []
os.makedirs(os.path.join(model_dir, 'trt_bucket_plans'), exist_ok=True)
for max_len in self.TRT_BUCKET_MAX_LENS:
plan_path = flow_decoder_estimator_bucket_plan(model_dir, max_len)
# A missing or zero-byte plan is (re)built from the ONNX model with a profile capped at max_len.
if not os.path.exists(plan_path) or os.path.getsize(plan_path) == 0:
convert_onnx_to_trt(plan_path, self.get_trt_kwargs(max_len=max_len), flow_decoder_onnx_model, fp16)
with open(plan_path, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(plan_path)
# estimated_mem_mb comes from TRT_BUCKET_MEM_MB and falls back to 0.0 for lengths not listed there.
bucket_engines.append({
'max_seq_len': max_len,
'engine': estimator_engine,
'plan_path': plan_path,
'estimated_mem_mb': self.TRT_BUCKET_MEM_MB.get(max_len, 0.0),
})
self.flow.decoder.estimator = TrtBucketedContextWrapper(
bucket_engines=bucket_engines,
trt_concurrent=trt_concurrent,
device=self.device,
)
# Bucketed mode is complete; the single-engine path below is skipped.
return
# Single-engine mode: build the plan if it is missing or zero bytes, then deserialize and wrap it.
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
input_names = ["x", "mask", "mu", "cond"]
def get_trt_kwargs(self, max_len=3000):
    """Return the min/opt/max shape profiles for the estimator inputs.

    Args:
        max_len: Upper bound for the variable last dimension of the
            ``x``, ``mask``, ``mu`` and ``cond`` inputs. Values below 4
            are raised to 4, which is also the fixed minimum length.

    Returns:
        A dict with keys ``min_shape``, ``opt_shape``, ``max_shape`` (each a
        list of shape tuples, one per input) and ``input_names`` (the input
        names in the same order as the shape lists).
    """
    min_len = 4
    max_len = int(max(min_len, max_len))
    # The "opt" length is three fifths of the maximum, at least 64, and never
    # larger than the maximum itself.
    opt_len = int(min(max_len, max(64, max_len * 3 // 5)))

    input_names = ["x", "mask", "mu", "t", "spks", "cond"]

    def _profile(seq_len):
        # Single template so the three profiles cannot drift out of step with
        # input_names; "t" and "spks" have no variable dimension.
        return [
            (2, 80, seq_len),  # x
            (2, 1, seq_len),   # mask
            (2, 80, seq_len),  # mu
            (2,),              # t
            (2, 80),           # spks
            (2, 80, seq_len),  # cond
        ]

    return {
        'min_shape': _profile(min_len),
        'opt_shape': _profile(opt_len),
        'max_shape': _profile(max_len),
        'input_names': input_names,
    }

def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
Expand Down Expand Up @@ -272,7 +301,50 @@ def __init__(self,
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
self.condition_dict = {}
self.silent_tokens = []
# NOTE first chunk hop, matching token_hop_len / training chunk_size
self.first_token_hop_len = 25

def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
    """Run the LLM and publish its speech tokens for request ``uuid``.

    Tokens produced by the LLM are appended to
    ``self.tts_speech_token_dict[uuid]``. When the job ends,
    ``self.llm_end_dict[uuid]`` is set to True. If a condition is registered
    in ``self.condition_dict`` for ``uuid``, it is notified after every
    append and once more at the end.

    Args:
        text: Either a generator of text pieces (routed to
            ``self.llm.inference_bistream``) or a tensor-like object
            (routed to ``self.llm.inference``).
        prompt_text: Prompt text; moved to ``self.device`` before the call.
        llm_prompt_speech_token: Prompt speech tokens; moved to
            ``self.device`` before the call.
        llm_embedding: Embedding passed to the LLM; moved to ``self.device``.
        uuid: Key of the request in the per-request dictionaries.

    Raises:
        AssertionError: If ``text`` is a generator but streaming text input
            is not supported for this model / backend combination.
    """
    cur_silent_token_num, max_silent_token_num = 0, 5
    try:
        with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
            if isinstance(text, Generator):
                assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
                token_generator = self.llm.inference_bistream(text=text,
                                                              prompt_text=prompt_text.to(self.device),
                                                              prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                              prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                              embedding=llm_embedding.to(self.device))
            else:
                token_generator = self.llm.inference(text=text.to(self.device),
                                                     text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                                     prompt_text=prompt_text.to(self.device),
                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                     embedding=llm_embedding.to(self.device),
                                                     uuid=uuid)
            for i in token_generator:
                # Keep at most max_silent_token_num consecutive silent tokens;
                # any further ones in the same run are dropped.
                if i in self.silent_tokens:
                    cur_silent_token_num += 1
                    if cur_silent_token_num > max_silent_token_num:
                        continue
                else:
                    cur_silent_token_num = 0
                # Append and notify under a single acquisition of self.lock.
                # The per-request condition is built on self.lock, so entering
                # the condition while already holding a plain Lock would
                # deadlock; holding the lock is enough to call notify().
                with self.lock:
                    self.tts_speech_token_dict[uuid].append(i)
                    cond = self.condition_dict.get(uuid)
                    if cond is not None:
                        cond.notify()
    finally:
        # Always publish the end-of-stream flag, even when the LLM raised, so
        # a consumer waiting for more tokens is released instead of blocking
        # forever. The exception itself still propagates to the caller.
        with self.lock:
            self.llm_end_dict[uuid] = True
            cond = self.condition_dict.get(uuid)
            if cond is not None:
                cond.notify_all()

def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
Expand Down Expand Up @@ -335,32 +407,41 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
self.condition_dict[this_uuid] = threading.Condition(self.lock)
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
else:
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
p.start()
if stream is True:
token_offset = 0
prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
hop_len = self.token_hop_len
prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / hop_len) * hop_len - flow_prompt_speech_token.shape[1])
while True:
time.sleep(0.1)
this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
token_offset=token_offset,
uuid=this_uuid,
stream=stream,
finalize=False)
token_offset += this_token_hop_len
self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
yield {'tts_speech': this_tts_speech.cpu()}
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
this_token_hop_len = (self.first_token_hop_len + prompt_token_pad if token_offset == 0 else hop_len)
needed = this_token_hop_len + self.flow.pre_lookahead_len
with self.lock:
while len(self.tts_speech_token_dict[this_uuid]) - token_offset < needed and not self.llm_end_dict[this_uuid]:
self.condition_dict[this_uuid].wait(timeout=0.05)
tokens_slice = list(self.tts_speech_token_dict[this_uuid][:token_offset + needed])
llm_done = self.llm_end_dict[this_uuid]
if len(tokens_slice) - token_offset < needed and llm_done:
break
this_tts_speech_token = torch.tensor(tokens_slice).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
token_offset=token_offset,
uuid=this_uuid,
stream=stream,
finalize=False)
token_offset += this_token_hop_len
hop_len = min(self.token_max_hop_len, int(hop_len * self.stream_scale_factor))
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
if self.llm_end_dict[this_uuid] and len(self.tts_speech_token_dict[this_uuid]) - token_offset < needed:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
Expand All @@ -386,9 +467,10 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
speed=speed)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.tts_speech_token_dict.pop(this_uuid, None)
self.llm_end_dict.pop(this_uuid, None)
self.hift_cache_dict.pop(this_uuid, None)
self.condition_dict.pop(this_uuid, None)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()
Expand Down Expand Up @@ -421,6 +503,8 @@ def __init__(self,
self.hift_cache_dict = {}
# FSQ silent and breath token
self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
self.condition_dict = {}
self.first_token_hop_len = 25
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的 first_token_hop_len 可以改得更小一点，首包会更快，但音质可能会有损失。


def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
Expand Down
Loading