From b6ec1e500dd1322dbd1a6f5847e980565f47f420 Mon Sep 17 00:00:00 2001 From: BeckYang26 Date: Thu, 21 May 2026 05:04:51 +0000 Subject: [PATCH 1/2] =?UTF-8?q?perf(inference):=20=E9=99=8D=E4=BD=8E?= =?UTF-8?q?=E6=B5=81=E5=BC=8F=E9=A6=96=E5=8C=85=E5=BB=B6=E8=BF=9F=E5=B9=B6?= =?UTF-8?q?=E4=BC=98=E5=8C=96=20TRT=20=E6=8E=A8=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 流式推理用 Condition 替代 sleep 轮询,首包 hop_len 仍保持 25 - solve_euler 内复用 TRT context,用 wait_stream 替代全量 sync - 修复 TrtContextWrapper 中 CUDA Stream 对象创建方式 --- cosyvoice/cli/model.py | 95 +++++++++++++++++++++------ cosyvoice/flow/flow_matching.py | 110 +++++++++++++++++++------------- cosyvoice/utils/common.py | 2 +- 3 files changed, 141 insertions(+), 66 deletions(-) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 92a15d985..f302f0a76 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -272,7 +272,50 @@ def __init__(self, self.tts_speech_token_dict = {} self.llm_end_dict = {} self.hift_cache_dict = {} + self.condition_dict = {} self.silent_tokens = [] + # NOTE first chunk hop, matching token_hop_len / training chunk_size + self.first_token_hop_len = 25 + + def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): + cur_silent_token_num, max_silent_token_num = 0, 5 + with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False): + if isinstance(text, Generator): + assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!' + token_generator = self.llm.inference_bistream(text=text, + prompt_text=prompt_text.to(self.device), + prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), + prompt_speech_token=llm_prompt_speech_token.to(self.device), + prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), + embedding=llm_embedding.to(self.device)) + else: + token_generator = self.llm.inference(text=text.to(self.device), + text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), + prompt_text=prompt_text.to(self.device), + prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), + prompt_speech_token=llm_prompt_speech_token.to(self.device), + prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), + embedding=llm_embedding.to(self.device), + uuid=uuid) + for i in token_generator: + if i in self.silent_tokens: + cur_silent_token_num += 1 + if cur_silent_token_num > max_silent_token_num: + continue + else: + cur_silent_token_num = 0 + with self.lock: + self.tts_speech_token_dict[uuid].append(i) + cond = self.condition_dict.get(uuid) + if cond is not None: + with cond: + cond.notify() + with self.lock: + self.llm_end_dict[uuid] = True + cond = self.condition_dict.get(uuid) + if cond is not None: + with cond: + cond.notify() def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) @@ -335,6 +378,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None + self.condition_dict[this_uuid] = threading.Condition(self.lock) if source_speech_token.shape[1] == 0: p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) else: @@ -342,25 +386,33 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze p.start() if stream is True: token_offset = 0 - prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1]) + hop_len = self.token_hop_len + prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / hop_len) * hop_len - flow_prompt_speech_token.shape[1]) while True: - time.sleep(0.1) - this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len - if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) - this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - token_offset=token_offset, - uuid=this_uuid, - stream=stream, - finalize=False) - token_offset += this_token_hop_len - self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor) - yield {'tts_speech': this_tts_speech.cpu()} - if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len: + this_token_hop_len = (self.first_token_hop_len + prompt_token_pad if token_offset == 0 else hop_len) + needed = this_token_hop_len + self.flow.pre_lookahead_len + with self.lock: + while len(self.tts_speech_token_dict[this_uuid]) - token_offset < needed and not self.llm_end_dict[this_uuid]: + self.condition_dict[this_uuid].wait(timeout=0.05) + tokens_slice = list(self.tts_speech_token_dict[this_uuid][:token_offset + needed]) + llm_done = self.llm_end_dict[this_uuid] + if len(tokens_slice) - token_offset < needed and llm_done: break + this_tts_speech_token = torch.tensor(tokens_slice).unsqueeze(dim=0) + this_tts_speech = self.token2wav(token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + token_offset=token_offset, + uuid=this_uuid, + stream=stream, + finalize=False) + token_offset += this_token_hop_len + hop_len = min(self.token_max_hop_len, int(hop_len * self.stream_scale_factor)) + yield {'tts_speech': this_tts_speech.cpu()} + with self.lock: + if self.llm_end_dict[this_uuid] and len(self.tts_speech_token_dict[this_uuid]) - token_offset < needed: + break p.join() # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) @@ -386,9 +438,10 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze speed=speed) yield {'tts_speech': this_tts_speech.cpu()} with self.lock: - self.tts_speech_token_dict.pop(this_uuid) - self.llm_end_dict.pop(this_uuid) - self.hift_cache_dict.pop(this_uuid) + self.tts_speech_token_dict.pop(this_uuid, None) + self.llm_end_dict.pop(this_uuid, None) + self.hift_cache_dict.pop(this_uuid, None) + self.condition_dict.pop(this_uuid, None) if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.current_stream().synchronize() @@ -421,6 +474,8 @@ def __init__(self, self.hift_cache_dict = {} # FSQ silent and breath token self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323] + self.condition_dict = {} + self.first_token_hop_len = 25 def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): with torch.cuda.amp.autocast(self.fp16): diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index d3beb9ec2..6d0285899 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -98,28 +98,44 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False): t_in = torch.zeros([2], device=x.device, dtype=spks.dtype) spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype) cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype) - for step in range(1, len(t_span)): - # Classifier-Free Guidance inference introduced in VoiceBox - x_in[:] = x - mask_in[:] = mask - mu_in[0] = mu - t_in[:] = t.unsqueeze(0) - spks_in[0] = spks - cond_in[0] = cond - dphi_dt = self.forward_estimator( - x_in, mask_in, - mu_in, t_in, - spks_in, - cond_in, - streaming - ) - dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) - dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t + mask_in[:] = mask + mu_in[0] = mu + spks_in[0] = spks + cond_in[0] = cond + + trt_session = None + if not isinstance(self.estimator, torch.nn.Module): + [estimator, stream], trt_engine = self.estimator.acquire_estimator() + trt_session = (estimator, stream, trt_engine) + + try: + for step in range(1, len(t_span)): + # Classifier-Free Guidance inference introduced in VoiceBox + x_in[:] = x + t_in[:] = t.unsqueeze(0) + if isinstance(self.estimator, torch.nn.Module): + dphi_dt = self.forward_estimator( + x_in, mask_in, + mu_in, t_in, + spks_in, + cond_in, + streaming + ) + else: + dphi_dt = self._forward_estimator_trt( + trt_session[0], trt_session[1], trt_session[2], + x_in, mask_in, mu_in, t_in, spks_in, cond_in, + ) + dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + finally: + if trt_session is not None: + self.estimator.release_estimator(trt_session[0], trt_session[1]) return sol[-1].float() @@ -128,29 +144,33 @@ def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False): return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming) else: [estimator, stream], trt_engine = self.estimator.acquire_estimator() - # NOTE need to synchronize when switching stream - torch.cuda.current_stream().synchronize() - with stream: - estimator.set_input_shape('x', (2, 80, x.size(2))) - estimator.set_input_shape('mask', (2, 1, x.size(2))) - estimator.set_input_shape('mu', (2, 80, x.size(2))) - estimator.set_input_shape('t', (2,)) - estimator.set_input_shape('spks', (2, 80)) - estimator.set_input_shape('cond', (2, 80, x.size(2))) - data_ptrs = [x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - x.data_ptr()] - for i, j in enumerate(data_ptrs): - estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) - # run trt engine - assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True - torch.cuda.current_stream().synchronize() - self.estimator.release_estimator(estimator, stream) - return x + try: + return self._forward_estimator_trt(estimator, stream, trt_engine, x, mask, mu, t, spks, cond) + finally: + self.estimator.release_estimator(estimator, stream) + + def _forward_estimator_trt(self, estimator, stream, trt_engine, x, mask, mu, t, spks, cond): + producer_stream = torch.cuda.current_stream() + with torch.cuda.stream(stream): + stream.wait_stream(producer_stream) + estimator.set_input_shape('x', (2, 80, x.size(2))) + estimator.set_input_shape('mask', (2, 1, x.size(2))) + estimator.set_input_shape('mu', (2, 80, x.size(2))) + estimator.set_input_shape('t', (2,)) + estimator.set_input_shape('spks', (2, 80)) + estimator.set_input_shape('cond', (2, 80, x.size(2))) + data_ptrs = [x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()] + for i, j in enumerate(data_ptrs): + estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) + assert estimator.execute_async_v3(stream.cuda_stream) is True + producer_stream.wait_stream(stream) + return x def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False): """Computes diffusion loss diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index 3f235a62e..f40ae03f1 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -202,7 +202,7 @@ def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): self.trt_engine = trt_engine for _ in range(trt_concurrent): trt_context = trt_engine.create_execution_context() - trt_stream = torch.cuda.stream(torch.cuda.Stream(device)) + trt_stream = torch.cuda.Stream(device) assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) self.trt_context_pool.put([trt_context, trt_stream]) assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' From d65b01e187f71c6a2f0c861b7b3ddc06373d3560 Mon Sep 17 00:00:00 2001 From: BeckYang26 Date: Thu, 21 May 2026 08:09:58 +0000 Subject: [PATCH 2/2] =?UTF-8?q?feat(inference):=20Flow=20DiT=20TensorRT=20?= =?UTF-8?q?=E5=88=86=E6=A1=B6=E6=8E=A8=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 TrtBucketedContextWrapper,按 seq_len 路由 256/768/1536/3000 四档 engine - load_trt 支持 trt_bucket 参数,缺失 plan 时自动从 optimize.onnx 构建 - get_trt_kwargs 对齐 export_onnx 六个输入,支持 max_len 分桶 profile - CosyVoice/CosyVoice2/3 新增 trt_bucket 构造参数,优先使用 optimize.onnx - flow_matching 调用 acquire_estimator(seq_len=...) 完成运行时选桶 --- cosyvoice/cli/cosyvoice.py | 27 ++++++++----- cosyvoice/cli/model.py | 49 ++++++++++++++++++----- cosyvoice/flow/flow_matching.py | 4 +- cosyvoice/utils/common.py | 69 ++++++++++++++++++++++++++++++++- cosyvoice/utils/file_utils.py | 17 ++++++++ 5 files changed, 142 insertions(+), 24 deletions(-) diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 7ab04a70f..a885f8991 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -20,13 +20,13 @@ import torch from cosyvoice.cli.frontend import CosyVoiceFrontEnd from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model -from cosyvoice.utils.file_utils import logging +from cosyvoice.utils.file_utils import logging, flow_decoder_estimator_onnx_model from cosyvoice.utils.class_utils import get_model_type class CosyVoice: - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1): + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, trt_bucket=False): self.model_dir = model_dir self.fp16 = fp16 if not os.path.exists(model_dir): @@ -57,9 +57,11 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_co '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) if load_trt: self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), - '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + flow_decoder_estimator_onnx_model(model_dir), trt_concurrent, - self.fp16) + self.fp16, + trt_bucket=trt_bucket, + model_dir=model_dir) del configs def list_available_spks(self): @@ -138,7 +140,8 @@ def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0): class CosyVoice2(CosyVoice): - def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1): + def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1, + trt_bucket=False): self.model_dir = model_dir self.fp16 = fp16 if not os.path.exists(model_dir): @@ -169,9 +172,11 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, f self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) if load_trt: self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), - '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + flow_decoder_estimator_onnx_model(model_dir), trt_concurrent, - self.fp16) + self.fp16, + trt_bucket=trt_bucket, + model_dir=model_dir) del configs def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True): @@ -188,7 +193,7 @@ def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk class CosyVoice3(CosyVoice2): - def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1): + def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1, trt_bucket=False): self.model_dir = model_dir self.fp16 = fp16 if not os.path.exists(model_dir): @@ -219,9 +224,11 @@ def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_c if self.fp16 is True: logging.warning('DiT tensorRT fp16 engine have some performance issue, use at caution!') self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), - '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + flow_decoder_estimator_onnx_model(model_dir), trt_concurrent, - self.fp16) + self.fp16, + trt_bucket=trt_bucket, + model_dir=model_dir) del configs diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index f302f0a76..cfe576969 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -22,12 +22,15 @@ from contextlib import nullcontext import uuid from cosyvoice.utils.common import fade_in_out -from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm -from cosyvoice.utils.common import TrtContextWrapper +from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm, flow_decoder_estimator_bucket_plan +from cosyvoice.utils.common import TrtContextWrapper, TrtBucketedContextWrapper class CosyVoiceModel: + TRT_BUCKET_MAX_LENS = (256, 768, 1536, 3000) + TRT_BUCKET_MEM_MB = {256: 1800.0, 768: 2600.0, 1536: 3400.0, 3000: 4600.0} + def __init__(self, llm: torch.nn.Module, flow: torch.nn.Module, @@ -80,22 +83,48 @@ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder - def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16): + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16, + trt_bucket=False, model_dir=''): assert torch.cuda.is_available(), 'tensorrt only supports gpu!' - if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: - convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) del self.flow.decoder.estimator import tensorrt as trt + if trt_bucket is True: + assert model_dir, 'model_dir is required when trt_bucket is True' + bucket_engines = [] + os.makedirs(os.path.join(model_dir, 'trt_bucket_plans'), exist_ok=True) + for max_len in self.TRT_BUCKET_MAX_LENS: + plan_path = flow_decoder_estimator_bucket_plan(model_dir, max_len) + if not os.path.exists(plan_path) or os.path.getsize(plan_path) == 0: + convert_onnx_to_trt(plan_path, self.get_trt_kwargs(max_len=max_len), flow_decoder_onnx_model, fp16) + with open(plan_path, 'rb') as f: + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(plan_path) + bucket_engines.append({ + 'max_seq_len': max_len, + 'engine': estimator_engine, + 'plan_path': plan_path, + 'estimated_mem_mb': self.TRT_BUCKET_MEM_MB.get(max_len, 0.0), + }) + self.flow.decoder.estimator = TrtBucketedContextWrapper( + bucket_engines=bucket_engines, + trt_concurrent=trt_concurrent, + device=self.device, + ) + return + if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: + convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) with open(flow_decoder_estimator_model, 'rb') as f: estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) - def get_trt_kwargs(self): - min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] - opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)] - max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)] - input_names = ["x", "mask", "mu", "cond"] + def get_trt_kwargs(self, max_len=3000): + max_len = int(max(4, max_len)) + opt_len = int(min(max_len, max(64, max_len * 3 // 5))) + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)] + opt_shape = [(2, 80, opt_len), (2, 1, opt_len), (2, 80, opt_len), (2,), (2, 80), (2, 80, opt_len)] + max_shape = [(2, 80, max_len), (2, 1, max_len), (2, 80, max_len), (2,), (2, 80), (2, 80, max_len)] + input_names = ["x", "mask", "mu", "t", "spks", "cond"] return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 6d0285899..4506365f0 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -105,7 +105,7 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False): trt_session = None if not isinstance(self.estimator, torch.nn.Module): - [estimator, stream], trt_engine = self.estimator.acquire_estimator() + [estimator, stream], trt_engine = self.estimator.acquire_estimator(seq_len=x_in.size(2)) trt_session = (estimator, stream, trt_engine) try: @@ -143,7 +143,7 @@ def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False): if isinstance(self.estimator, torch.nn.Module): return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming) else: - [estimator, stream], trt_engine = self.estimator.acquire_estimator() + [estimator, stream], trt_engine = self.estimator.acquire_estimator(seq_len=x.size(2)) try: return self._forward_estimator_trt(estimator, stream, trt_engine, x, mask, mu, t, spks, cond) finally: diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index f40ae03f1..e81172a98 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -18,7 +18,7 @@ import queue import random -from typing import List +from typing import Dict, List import numpy as np import torch @@ -207,8 +207,73 @@ def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): self.trt_context_pool.put([trt_context, trt_stream]) assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' - def acquire_estimator(self): + def acquire_estimator(self, seq_len=None): + del seq_len return self.trt_context_pool.get(), self.trt_engine def release_estimator(self, context, stream): self.trt_context_pool.put([context, stream]) + + +class TrtBucketedContextWrapper: + def __init__(self, bucket_engines: List[Dict], trt_concurrent=1, device='cuda:0', + min_free_mem_mb=2048, reserve_free_mem_ratio=0.12): + self.trt_concurrent = trt_concurrent + self.device = device + self.min_free_mem_mb = float(min_free_mem_mb) + self.reserve_free_mem_ratio = float(reserve_free_mem_ratio) + self.buckets = [] + self._ctx_to_bucket = {} + for idx, item in enumerate(sorted(bucket_engines, key=lambda x: int(x['max_seq_len']))): + trt_engine = item['engine'] + max_seq_len = int(item['max_seq_len']) + plan_path = item.get('plan_path', 'bucket_{}'.format(idx)) + est_mem_mb = float(item.get('estimated_mem_mb', 0.0)) + trt_context_pool = queue.Queue(maxsize=trt_concurrent) + for _ in range(trt_concurrent): + trt_context = trt_engine.create_execution_context() + trt_stream = torch.cuda.Stream(device) + assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) + trt_context_pool.put([trt_context, trt_stream]) + self._ctx_to_bucket[id(trt_context)] = idx + self.buckets.append({ + 'idx': idx, + 'max_seq_len': max_seq_len, + 'plan_path': plan_path, + 'estimated_mem_mb': est_mem_mb, + 'queue': trt_context_pool, + 'engine': trt_engine, + }) + assert len(self.buckets) > 0, 'no available trt bucket engine' + + def _memory_budget_ok(self, estimated_mem_mb): + if estimated_mem_mb <= 0: + return True + free_mem, total_mem = torch.cuda.mem_get_info(self.device) + free_mb = free_mem / (1024 * 1024) + total_mb = total_mem / (1024 * 1024) + reserve_mb = max(self.min_free_mem_mb, total_mb * self.reserve_free_mem_ratio) + return (free_mb - estimated_mem_mb) >= reserve_mb + + def _choose_bucket(self, seq_len): + candidates = [b for b in self.buckets if seq_len <= b['max_seq_len']] + if len(candidates) == 0: + raise RuntimeError('seq_len={} exceeds all TRT buckets (max={})'.format( + seq_len, self.buckets[-1]['max_seq_len'])) + for bucket in candidates: + if self._memory_budget_ok(bucket['estimated_mem_mb']): + return bucket + return candidates[0] + + def acquire_estimator(self, seq_len=None): + if seq_len is None: + bucket = self.buckets[-1] + else: + bucket = self._choose_bucket(int(seq_len)) + return bucket['queue'].get(), bucket['engine'] + + def release_estimator(self, context, stream): + idx = self._ctx_to_bucket.get(id(context), None) + if idx is None: + raise RuntimeError('unknown TRT context: cannot map back to bucket') + self.buckets[idx]['queue'].put([context, stream]) diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index b173ef201..0ccf340c8 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -50,6 +50,23 @@ def load_wav(wav, target_sr, min_sr=16000): return speech +def flow_decoder_estimator_onnx_model(model_dir): + """Prefer optimized Flow DiT ONNX, fall back to the official export.""" + for name in ( + 'flow.decoder.estimator.fp32.optimize.onnx', + 'flow.decoder.estimator.fp32.onnx', + ): + path = os.path.join(model_dir, name) + if os.path.isfile(path): + return path + return os.path.join(model_dir, 'flow.decoder.estimator.fp32.onnx') + + +def flow_decoder_estimator_bucket_plan(model_dir, max_len): + plan_dir = os.path.join(model_dir, 'trt_bucket_plans') + return os.path.join(plan_dir, 'flow.decoder.estimator.fp32.optimize.b{}.plan'.format(int(max_len))) + + def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): import tensorrt as trt logging.info("Converting onnx to trt...")