Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions cosyvoice/flow/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,23 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):

# Do not use concat: it may change the memory format and make TRT inference produce wrong results!
# NOTE: when flow runs in AMP mode, x.dtype is float32, which causes NaN in TRT fp16 inference, so set dtype=spks.dtype
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
bsz = x.size(0)
x_in = torch.zeros([2 * bsz, 80, x.size(2)], device=x.device, dtype=spks.dtype)
mask_in = torch.zeros([2 * bsz, 1, x.size(2)], device=x.device, dtype=spks.dtype)
mu_in = torch.zeros([2 * bsz, 80, x.size(2)], device=x.device, dtype=spks.dtype)
t_in = torch.zeros([2 * bsz], device=x.device, dtype=spks.dtype)
spks_in = torch.zeros([2 * bsz, 80], device=x.device, dtype=spks.dtype)
cond_in = torch.zeros([2 * bsz, 80, x.size(2)], device=x.device, dtype=spks.dtype)
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x
mask_in[:] = mask
mu_in[0] = mu
x_in[:bsz] = x
x_in[bsz:] = x
mask_in[:bsz] = mask
mask_in[bsz:] = mask
mu_in[:bsz] = mu
t_in[:] = t.unsqueeze(0)
spks_in[0] = spks
cond_in[0] = cond
spks_in[:bsz] = spks
cond_in[:bsz] = cond
dphi_dt = self.forward_estimator(
x_in, mask_in,
mu_in, t_in,
Expand Down