Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions bench_cosyvoice3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""QPS benchmark for CosyVoice3.

Usage:
python bench_cosyvoice3.py # vllm only
python bench_cosyvoice3.py --trt # vllm + trt
python bench_cosyvoice3.py --no-vllm # baseline (no acceleration)
"""
import sys, time, statistics, threading, queue, argparse
sys.path.append('third_party/Matcha-TTS')

PROMPT_TEXT = 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。'
PROMPT_WAV = './asset/zero_shot_prompt.wav'

TEXTS = {
'short': '你好,今天天气真不错。',
'medium': '收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。',
'long': '在人工智能技术飞速发展的今天,语音合成已经从早期生硬的拼接方式,进化到如今能够表达丰富情感、自然流畅的神经网络模型。CosyVoice 作为阿里达摩院推出的多语言语音生成模型,在零样本音色克隆、跨语种合成、多方言支持等方面都展现出了令人惊艳的能力,为众多应用场景带来了新的可能性。',
}


def run_once(model, text, seed=0):
    """Synthesize `text` once and return (wall_seconds, audio_seconds).

    Seeds all RNGs first so repeated calls with the same seed produce
    comparable timings across benchmark rounds.
    """
    from cosyvoice.utils.common import set_all_random_seed
    set_all_random_seed(seed)
    t0 = time.time()
    audio_sec = 0.0
    # inference_zero_shot yields one dict per chunk; enumerate() was unneeded
    # since the index was discarded. Each chunk carries a 'tts_speech' tensor
    # whose last dim is sample count.
    for chunk in model.inference_zero_shot(text, PROMPT_TEXT, PROMPT_WAV, stream=False):
        audio_sec += chunk['tts_speech'].shape[-1] / model.sample_rate
    return time.time() - t0, audio_sec


def bench_sequential(model, iters=5):
    """Run each benchmark text sequentially and report mean wall/audio time and RTF."""
    print('\n=== Sequential ===', flush=True)
    for name, text in TEXTS.items():
        # One throwaway synthesis per text so first-call warmup cost is excluded.
        run_once(model, text, seed=99)
        results = [run_once(model, text, seed=i) for i in range(iters)]
        walls = [wall for wall, _ in results]
        audios = [audio for _, audio in results]
        avg_w = statistics.mean(walls)
        avg_a = statistics.mean(audios)
        print(f'{name:>7} | chars={len(text):>3} | wall={avg_w:.2f}s audio={avg_a:.2f}s RTF={avg_w/avg_a:.3f}', flush=True)


def bench_concurrent(model, text_name='medium', concurrencies=(1, 2, 4, 8), per_round=4):
    """Fan `conc` worker threads over a shared work queue; report QPS and latency percentiles."""
    print(f'\n=== Concurrent (text={text_name}, per_round={per_round}) ===', flush=True)
    text = TEXTS[text_name]
    for conc in concurrencies:
        total = conc * per_round
        work_q = queue.Queue()
        for item in range(total):
            # Each queued integer doubles as the RNG seed for that request.
            work_q.put(item)
        latencies, audios = [], []
        lock = threading.Lock()

        def worker():
            # Pull work until the queue is drained, then exit.
            while True:
                try:
                    seed = work_q.get_nowait()
                except queue.Empty:
                    return
                wall_s, audio_s = run_once(model, text, seed=seed)
                with lock:
                    latencies.append(wall_s)
                    audios.append(audio_s)

        t0 = time.time()
        threads = [threading.Thread(target=worker) for _ in range(conc)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        wall = time.time() - t0

        if not latencies:
            continue
        latencies.sort()
        # Nearest-rank style percentiles over the sorted per-request latencies.
        p50 = latencies[len(latencies) // 2]
        p95 = latencies[int(len(latencies) * 0.95)]
        qps = total / wall
        rt = sum(audios) / wall
        print(f'conc={conc} n={total} | QPS={qps:.2f} audio_thru={rt:.2f}x | lat avg={statistics.mean(latencies):.2f}s p50={p50:.2f}s p95={p95:.2f}s', flush=True)


def main():
    """CLI entry point: parse acceleration flags, load the model, run benchmarks."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--trt', action='store_true')
    ap.add_argument('--no-vllm', action='store_true')
    ap.add_argument('--concurrent-only', action='store_true')
    args = ap.parse_args()

    use_vllm, use_trt = not args.no_vllm, args.trt

    if use_vllm:
        # vllm must know the CosyVoice2 LM class before AutoModel constructs it.
        from vllm import ModelRegistry
        from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
        ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)

    from cosyvoice.cli.cosyvoice import AutoModel

    print(f'Config: vllm={use_vllm} trt={use_trt}', flush=True)
    print('Loading...', flush=True)
    t0 = time.time()
    model = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B',
                      load_trt=use_trt, load_vllm=use_vllm, fp16=False)
    print(f'Loaded in {time.time()-t0:.2f}s', flush=True)

    # --concurrent-only skips the sequential pass; the concurrent pass always runs.
    if not args.concurrent_only:
        bench_sequential(model, iters=5)
    bench_concurrent(model, text_name='medium', concurrencies=(1, 2, 4, 8), per_round=4)


# Script entry point: run the benchmark suite when invoked directly.
if __name__ == '__main__':
    main()
22 changes: 22 additions & 0 deletions bench_push.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Push higher concurrency + short text benchmark."""
import sys
sys.path.append('third_party/Matcha-TTS')

# Register the CosyVoice2 LM with vllm before AutoModel is constructed,
# mirroring the setup in bench_cosyvoice3.main().
from vllm import ModelRegistry
from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
ModelRegistry.register_model('CosyVoice2ForCausalLM', CosyVoice2ForCausalLM)

from cosyvoice.cli.cosyvoice import AutoModel
# Reuse the benchmark helpers (bench_concurrent) from the main benchmark script.
import bench_cosyvoice3 as B


def main():
    """Load the TRT+vllm accelerated model and sweep higher concurrency levels
    than the base benchmark (up to 32 workers) on short and medium texts."""
    m = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B', load_trt=True, load_vllm=True, fp16=False)
    print('===SHORT TEXT, push concurrency===', flush=True)
    B.bench_concurrent(m, text_name='short', concurrencies=(4, 8, 16, 32), per_round=4)
    print('===MEDIUM TEXT, push concurrency===', flush=True)
    # Fewer rounds for medium text to keep total runtime bounded.
    B.bench_concurrent(m, text_name='medium', concurrencies=(8, 16, 32), per_round=2)


# Script entry point: run the push benchmark when invoked directly.
if __name__ == '__main__':
    main()
169 changes: 169 additions & 0 deletions cosyvoice/bin/export_hift_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Export the conv-only path of (Causal)HiFTGenerator.decode for TRT fp16.
#
# Split point:
# PyTorch (kept): f0_predictor -> sine source -> STFT(s)
# conv_pre (causal, takes 1 or 2 args by finalize flag)
# iSTFT, finalize-truncate, audio_limit clamp
# TRT (this export): leaky_relu + ups + (reflection_pad on last) + source_downs
# + source_resblocks + resblocks (Snake act) + conv_post
# + exp/sin to magnitude/phase -- the dense GPU work
#
# Inputs to the engine: x_post_conv_pre (B, base_channels, T_x), s_stft (B, n_fft+2, T_stft)
# Outputs: magnitude (B, n_fft//2+1, T_out), phase same shape
import argparse, os, sys, random
import torch
import torch.nn as nn
import torch.nn.functional as F
import onnxruntime
from tqdm import tqdm

ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(f'{ROOT}/../..')
sys.path.append(f'{ROOT}/../../third_party/Matcha-TTS')

from cosyvoice.cli.cosyvoice import AutoModel


def _strip_weight_norm(module: nn.Module):
"""Remove weight_norm regardless of legacy hook or new parametrize API."""
from torch.nn.utils import remove_weight_norm as _legacy
from torch.nn.utils.parametrize import remove_parametrizations
for m in module.modules():
# New parametrize style (PyTorch >=2.4)
if hasattr(m, 'parametrizations') and 'weight' in getattr(m, 'parametrizations', {}):
try:
remove_parametrizations(m, 'weight', leave_parametrized=True)
continue
except Exception:
pass
# Legacy hook style
for hook in list(getattr(m, '_forward_pre_hooks', {}).values()):
if hook.__class__.__name__ == 'WeightNorm':
try:
_legacy(m, 'weight')
except Exception:
pass
break


class HiftDecoderConvBlock(nn.Module):
    """The pure-conv post-conv_pre path of (Causal)HiFTGenerator.decode."""

    def __init__(self, hift):
        super().__init__()
        # Reference the generator's trained submodules directly (no copies),
        # so the exported graph carries the real weights.
        self.ups = hift.ups
        self.source_downs = hift.source_downs
        self.source_resblocks = hift.source_resblocks
        self.resblocks = hift.resblocks
        self.conv_post = hift.conv_post
        self.reflection_pad = hift.reflection_pad
        self.lrelu_slope = hift.lrelu_slope
        self.num_upsamples = hift.num_upsamples
        self.num_kernels = hift.num_kernels
        # Channel split point of conv_post's output: first half -> magnitude,
        # second half -> phase.
        self.n_fft_half_p1 = hift.istft_params['n_fft'] // 2 + 1

    def forward(self, x: torch.Tensor, s_stft: torch.Tensor):
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            # ups[i] is CausalConv1dUpsample (CausalHiFTGenerator) or ConvTranspose1d.
            # Both can be invoked with a single arg; default empty cache hits zero-pad path.
            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)
            # Mix the downsampled source excitation into the feature stream.
            source = self.source_resblocks[i](self.source_downs[i](s_stft))
            x = x + source
            # Average the parallel residual-block outputs for this scale.
            acc = self.resblocks[i * self.num_kernels](x)
            for j in range(1, self.num_kernels):
                acc = acc + self.resblocks[i * self.num_kernels + j](x)
            x = acc / self.num_kernels
        x = F.leaky_relu(x)  # default slope here, as in the original decode path
        spec = self.conv_post(x)
        magnitude = torch.exp(spec[:, :self.n_fft_half_p1, :])
        phase = torch.sin(spec[:, self.n_fft_half_p1:, :])
        return magnitude, phase


def _probe_shapes(hift, device):
# Build a dummy input by running the PyTorch path and snapshotting tensors at split points.
# T_x = mel chunk length post conv_pre (causal pad shrinks input by causal_padding).
# Use a representative chunk size: 25 tokens * 2 mel_ratio = 50 mel frames; conv_pre w/ pad=3 keeps T.
dummy_mel = torch.randn(1, 80, 80, device=device, dtype=torch.float32)
# f0 -> source -> STFT path mirrors CausalHiFTGenerator.inference (needs float64 f0 predictor)
hift.f0_predictor.to(torch.float64)
f0 = hift.f0_predictor(dummy_mel.to(torch.float64), finalize=True).to(dummy_mel)
s = hift.f0_upsamp(f0[:, None]).transpose(1, 2)
s, _, _ = hift.m_source(s)
s = s.transpose(1, 2)
# decode() preamble:
x = hift.conv_pre(dummy_mel)
s_real, s_imag = hift._stft(s.squeeze(1))
s_stft = torch.cat([s_real, s_imag], dim=1)
return x, s_stft


@torch.no_grad()
def main():
    """Export hift's conv-only decoder path to ONNX and sanity-check it.

    Steps: load the full model (no TRT/vllm needed just to read weights),
    strip weight_norm, probe realistic tensor shapes with one PyTorch
    forward, export `HiftDecoderConvBlock` with dynamic time axes, then
    compare onnxruntime output against PyTorch on the probed inputs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', default='pretrained_models/Fun-CosyVoice3-0.5B')
    args = parser.parse_args()

    print(f'[export] loading {args.model_dir} ...', flush=True)
    auto = AutoModel(model_dir=args.model_dir, load_trt=False, load_vllm=False, fp16=False)
    hift = auto.model.hift
    # Export on whatever device the checkpoint loaded to.
    device = next(hift.parameters()).device

    print('[export] removing weight_norm on hift (new+legacy APIs) ...', flush=True)
    _strip_weight_norm(hift)
    hift.eval()

    block = HiftDecoderConvBlock(hift).eval().to(device)

    print('[export] probing tensor shapes via PyTorch fwd ...', flush=True)
    x_dummy, s_stft_dummy = _probe_shapes(hift, device)
    print(f' x={tuple(x_dummy.shape)} s_stft={tuple(s_stft_dummy.shape)}', flush=True)

    # The engine builder downstream expects this exact filename next to the
    # model (see the LOAD_TRT_HIFT path in cosyvoice/cli/cosyvoice.py).
    onnx_path = os.path.join(args.model_dir, 'hift.decoder.fp32.onnx')
    print(f'[export] torch.onnx.export -> {onnx_path}', flush=True)
    torch.onnx.export(
        block,
        (x_dummy, s_stft_dummy),
        onnx_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=['x', 's_stft'],
        output_names=['magnitude', 'phase'],
        # Only the time axes vary at runtime; batch stays at the traced size.
        dynamic_axes={
            'x': {2: 'T_x'},
            's_stft': {2: 'T_stft'},
            'magnitude': {2: 'T_out'},
            'phase': {2: 'T_out'},
        },
    )

    # Sanity check: run via onnxruntime and compare to PyTorch.
    print('[export] sanity check via onnxruntime CUDA EP ...', flush=True)
    sess = onnxruntime.InferenceSession(
        onnx_path,
        providers=['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'],
    )
    # Sanity-check on the actual probed shapes (the only ones for which we know
    # the exact T_stft / T_x relationship; the source-downs Conv1d ratios make
    # arbitrary T_x impossible to test with random stub tensors).
    out_pt = block(x_dummy, s_stft_dummy)
    out_ort = sess.run(None, {'x': x_dummy.cpu().numpy(), 's_stft': s_stft_dummy.cpu().numpy()})
    # Report max/mean absolute deviation per output; no hard threshold is
    # enforced — the numbers are for eyeballing.
    for name, pt, ort in zip(['magnitude', 'phase'], out_pt, out_ort):
        ort_t = torch.from_numpy(ort).to(device)
        diff = (pt - ort_t).abs()
        print(f' ort vs torch {name}: max_abs={diff.max().item():.3e} mean_abs={diff.mean().item():.3e} '
              f'shape={tuple(ort_t.shape)}')

    print(f'[export] done. ONNX size = {os.path.getsize(onnx_path) / 1e6:.1f} MB', flush=True)

# Script entry point: run the export when invoked directly.
if __name__ == '__main__':
    main()
20 changes: 19 additions & 1 deletion cosyvoice/cli/cosyvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk

class CosyVoice3(CosyVoice2):

def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False,
trt_concurrent=int(os.environ.get('FLOW_TRT_CONCURRENT', '4'))):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
Expand Down Expand Up @@ -222,6 +223,23 @@ def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_c
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
# HiFi-GAN decoder (post conv_pre) -> TRT, opt-in via env LOAD_TRT_HIFT=1.
# As of Round 13, hift TRT is fp16 by default. The Snake activation
# in cosyvoice/transformer/activation.py was patched to clamp
# inv_alpha at the source (max=6e4), which fixes the fp16 overflow
# that previously saturated audio (Round 6 regression). Pure fp16
# is now safe AND fastest. Set HIFT_TRT_FP16=0 to revert to fp32.
if os.environ.get('LOAD_TRT_HIFT', '0') == '1':
hift_fp16 = os.environ.get('HIFT_TRT_FP16', '1') == '1'
hift_onnx = '{}/hift.decoder.fp32.onnx'.format(model_dir)
hift_engine = '{}/hift.decoder.{}.mygpu.plan'.format(
model_dir, 'fp16' if hift_fp16 else 'fp32')
if os.path.exists(hift_onnx):
self.model.load_trt_hift(hift_engine, hift_onnx, hift_fp16)
logging.info('hift TRT engine loaded ({}); decode patched'.format(
'fp16' if hift_fp16 else 'fp32'))
else:
logging.warning('LOAD_TRT_HIFT=1 but {} not found; skipping'.format(hift_onnx))
del configs


Expand Down
Loading