diff --git a/cosyvoice/flow/DiT/dit.py b/cosyvoice/flow/DiT/dit.py
index 0d637e4ad..c68d35c7f 100644
--- a/cosyvoice/flow/DiT/dit.py
+++ b/cosyvoice/flow/DiT/dit.py
@@ -87,6 +87,7 @@ def forward(
         cond: float["b n d"],
         text_embed: float["b n d"],
         spks: float["b d"],
+        conv_cache: torch.Tensor = None,
     ):
         to_cat = [x, cond, text_embed]
         if self.spk_dim > 0:
@@ -94,8 +95,10 @@
             to_cat.append(spks)

         x = self.proj(torch.cat(to_cat, dim=-1))
-        x = self.conv_pos_embed(x) + x
-        return x
+
+        conv, new_conv_cache = self.conv_pos_embed(x, conv_cache=conv_cache)
+        x = conv + x
+        return x, new_conv_cache


 # Transformer backbone using DiT blocks
@@ -153,7 +156,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):

         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(t)
-        x = self.input_embed(x, cond, mu, spks.squeeze(1))
+        x, _ = self.input_embed(x, cond, mu, spks.squeeze(1))

         rope = self.rotary_embed.forward_from_seq_len(seq_len)

@@ -173,4 +176,119 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         x = self.norm_out(x, t)

         output = self.proj_out(x).transpose(1, 2)
-        return output
+        return output, None
+
+    def forward_chunk(self, x, x_offset, mask, mu, t, spks=None, cond=None, streaming=False, conv_cache: torch.Tensor = None, att_cache: torch.Tensor = None):
+        x = x.transpose(1, 2)
+        mu = mu.transpose(1, 2)
+        cond = cond.transpose(1, 2)
+        spks = spks.unsqueeze(dim=1)
+        batch, seq_len = x.shape[0], x.shape[1]
+        if t.ndim == 0:
+            t = t.repeat(batch)
+
+        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
+        t = self.time_embed(t)
+        x, new_conv_cache = self.input_embed(x, cond, mu, spks.squeeze(1), conv_cache=conv_cache)
+
+        x_off = int(x_offset.item()) if isinstance(x_offset, torch.Tensor) else int(x_offset)
+        x_partial = x[:, x_off:, :]
+
+        rope = self.rotary_embed.forward_from_seq_len(seq_len)
+
+        if self.long_skip_connection is not None:
+            residual = x_partial
+
+        attn_mask = add_optional_chunk_mask(x_partial, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
+
+        new_att_cache = []
+        for i, block in enumerate(self.transformer_blocks):
+            x_partial, new_att_cache_i = block.forward_chunk(x_partial, t, x_offset=x_off, mask=attn_mask.bool(), rope=rope, att_cache=att_cache[i])
+            new_att_cache.append(new_att_cache_i)
+
+        if self.long_skip_connection is not None:
+            x_partial = self.long_skip_connection(torch.cat((x_partial, residual), dim=-1))
+
+        x_partial = self.norm_out(x_partial, t)
+
+        output = self.proj_out(x_partial).transpose(1, 2)
+        return output, new_conv_cache, torch.stack(new_att_cache)
+
+
+class DiTWithCache(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=8,
+        heads=8,
+        dim_head=64,
+        dropout=0.1,
+        ff_mult=4,
+        mel_dim=80,
+        mu_dim=None,
+        long_skip_connection=False,
+        spk_dim=None,
+        out_channels=None,
+        static_chunk_size=50,
+        num_decoding_left_chunks=2
+    ):
+        super().__init__()
+
+        self.time_embed = TimestepEmbedding(dim)
+        if mu_dim is None:
+            mu_dim = mel_dim
+        self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)
+
+        self.rotary_embed = RotaryEmbedding(dim_head)
+
+        self.dim = dim
+        self.depth = depth
+
+        self.transformer_blocks = nn.ModuleList(
+            [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
+        )
+        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
+
+        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
+        self.proj_out = nn.Linear(dim, mel_dim)
+        self.out_channels = out_channels
+        self.static_chunk_size = static_chunk_size
+        self.num_decoding_left_chunks = num_decoding_left_chunks
+
+    def forward(self, x, x_offset, mask, mu, t, spks=None, cond=None, conv_cache: torch.Tensor = None, att_cache: torch.Tensor = None, streaming=False):
+        x = x.transpose(1, 2)
+        mu = mu.transpose(1, 2)
+        cond = cond.transpose(1, 2)
+        spks = spks.unsqueeze(dim=1)
+        batch, seq_len = x.shape[0], x.shape[1]
+        seq_len = torch.tensor(seq_len, dtype=torch.int64)
+        x_offset = x_offset.to(dtype=torch.int64)
+        if t.ndim == 0:
+            t = t.repeat(batch)
+
+        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
+        t = self.time_embed(t)
+        x, new_conv_cache = self.input_embed(x, cond, mu, spks.squeeze(1), conv_cache=conv_cache)
+
+        # ONNX cannot capture a plain Python int, so seq_len and x_offset are kept as int64 tensors
+        # dummy_x = torch.zeros((seq_len))
+        # dummy_x = F.pad(dummy_x, (0, x_offset))
+        rope = self.rotary_embed.forward_from_seq_len(seq_len + x_offset)
+
+        if self.long_skip_connection is not None:
+            residual = x
+
+        attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
+
+        new_att_cache = []
+        for i, block in enumerate(self.transformer_blocks):
+            x, new_att_cache_i = block.forward_chunk(x, t, x_offset=x_offset, mask=attn_mask.bool(), rope=rope, att_cache=att_cache[:, i, :, :])
+            new_att_cache.append(new_att_cache_i)
+
+        if self.long_skip_connection is not None:
+            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
+
+        x = self.norm_out(x, t)
+
+        output = self.proj_out(x).transpose(1, 2)
+        return output, new_conv_cache, torch.stack(new_att_cache, dim=1)
\ No newline at end of file
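
Note on the cache layout: neither `forward_chunk` nor `DiTWithCache.forward` allocates its own caches; `solve_euler_chunk` in flow_matching.py (later in this diff) allocates one pair per ODE step. A shape summary of that convention, assuming the 22-block, 1024-dim, kernel-31 configuration hard-coded there:

    import torch

    # One cache pair per ODE step (see solve_euler_chunk below).
    att_cache = torch.zeros(0, 22, 4, 1024)   # (cached_frames, n_blocks, 2 * cfg_batch, attn_dim)
    conv_cache = torch.zeros(2, 2, 1024, 30)  # (cfg_batch, n_convs, hidden_dim, kernel_size - 1)
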
diff --git a/cosyvoice/flow/DiT/modules.py b/cosyvoice/flow/DiT/modules.py
index be8caecb8..e39ac51f2 100644
--- a/cosyvoice/flow/DiT/modules.py
+++ b/cosyvoice/flow/DiT/modules.py
@@ -17,8 +17,35 @@
 import torch.nn.functional as F
 import torchaudio

-from x_transformers.x_transformers import apply_rotary_pos_emb
+# from x_transformers.x_transformers import apply_rotary_pos_emb
+import einops
+
+def rotate_half(x):
+    x = einops.rearrange(x, '... (d r) -> ... d r', r=2)
+    x1, x2 = x.unbind(dim=-1)
+    x = torch.stack((-x2, x1), dim=-1)
+    return einops.rearrange(x, '... d r -> ... (d r)')
+
+@torch.amp.autocast('cuda', enabled=False)
+def apply_rotary_pos_emb(t, freqs, scale=1, offset=0):
+    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
+
+    freqs = freqs[:, offset:, :]
+    scale = scale[:, offset:, :] if torch.is_tensor(scale) else scale
+
+    if t.ndim == 4 and freqs.ndim == 3:
+        freqs = einops.rearrange(freqs, 'b n d -> b 1 n d')
+
+        if torch.is_tensor(scale):
+            scale = einops.rearrange(scale, 'b n d -> b 1 n d')
+
+    # partial rotary embeddings, Wang et al. GPT-J
+    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
+    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
+    out = torch.cat((t, t_unrotated), dim=-1)
+
+    return out.type(orig_dtype)


 # raw wav to mel spec
 class MelSpec(nn.Module):
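
The `offset` argument added above slices the precomputed frequencies at a chunk's absolute start position, so rotating a chunk with `offset=k` should match rotating the full sequence and keeping the tail. A minimal self-contained check of that property (toy shapes and random stand-in frequencies; independent of the classes in this file):

    import torch

    def rotate_half(x):
        # interleaved pairs (a0, b0, a1, b1, ...) -> (-b0, a0, -b1, a1, ...)
        x = x.view(*x.shape[:-1], -1, 2)
        x1, x2 = x.unbind(dim=-1)
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    def apply_rope(t, freqs, offset=0):
        f = freqs[:, offset:offset + t.shape[1], :]
        return t * f.cos() + rotate_half(t) * f.sin()

    b, n, d, off = 1, 8, 16, 5
    freqs = torch.randn(1, n + off, d)   # stand-in for RotaryEmbedding output
    t_full = torch.randn(b, n + off, d)

    full = apply_rope(t_full, freqs)                            # whole sequence at once
    chunk = apply_rope(t_full[:, off:, :], freqs, offset=off)   # only the tail chunk
    assert torch.allclose(full[:, off:, :], chunk)
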
@@ -126,22 +153,33 @@ def __init__(self, dim, kernel_size=31, groups=16):
             nn.Mish(),
         )

-    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
+    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None, conv_cache: torch.Tensor = None):  # noqa: F722
         if mask is not None:
             mask = mask[..., None]
             x = x.masked_fill(~mask, 0.0)

         x = x.permute(0, 2, 1)
-        x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
+
+        if conv_cache is not None:
+            x = torch.cat((conv_cache[:, 0, :, :], x), dim=-1)
+            conv_cache[:, 0, :, :] = x[:, :, -self.kernel_size + 1:]
+        else:
+            x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
+
         x = self.conv1(x)
-        x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
+
+        if conv_cache is not None:
+            x = torch.cat((conv_cache[:, 1, :, :], x), dim=-1)
+            conv_cache[:, 1, :, :] = x[:, :, -self.kernel_size + 1:]
+        else:
+            x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
         x = self.conv2(x)
         out = x.permute(0, 2, 1)

         if mask is not None:
             out = out.masked_fill(~mask, 0.0)

-        return out
+        return out, conv_cache


 # rotary positional embedding related
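
The two cache slots written above each hold the last `kernel_size - 1` input frames of one conv, which is exactly the left padding the non-streaming path applies; with that prefix, chunked processing reproduces the one-shot output. A minimal sketch of the invariant with a single causal conv (toy sizes, not the module above):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    k, ch = 31, 8
    conv = nn.Conv1d(ch, ch, k)                    # no built-in padding; we pad manually
    x = torch.randn(1, ch, 100)

    full = conv(F.pad(x, (k - 1, 0)))              # one pass over the whole signal

    cache = x.new_zeros(1, ch, k - 1)              # zeros == the initial left padding
    outs = []
    for chunk in torch.split(x, 25, dim=-1):
        buf = torch.cat((cache, chunk), dim=-1)    # prepend cached left context
        outs.append(conv(buf))
        cache = buf[:, :, -(k - 1):]               # carry last k-1 input frames forward
    assert torch.allclose(full, torch.cat(outs, dim=-1), atol=1e-5)
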
@@ -342,6 +380,21 @@ def forward(
         else:
             return self.processor(self, x, mask=mask, rope=rope)

+    def forward_chunk(
+        self,
+        x: float["b n d"],  # noised input x  # noqa: F722
+        x_offset: int = 0,
+        c: float["b n d"] = None,  # context c  # noqa: F722
+        mask: bool["b n"] | None = None,  # noqa: F722
+        rope=None,  # rotary position embedding for x
+        c_rope=None,  # rotary position embedding for c
+        att_cache: torch.Tensor = None
+    ) -> tuple:
+        if c is not None:
+            return self.processor(self, x, x_offset=x_offset, c=c, mask=mask, rope=rope, c_rope=c_rope, att_cache=att_cache)
+        else:
+            return self.processor(self, x, x_offset=x_offset, mask=mask, rope=rope, att_cache=att_cache)
+

 # Attention processor
@@ -354,23 +407,31 @@ def __call__(
         self,
         attn: Attention,
         x: float["b n d"],  # noised input x  # noqa: F722
+        x_offset: int = 0,
         mask: bool["b n"] | None = None,  # noqa: F722
         rope=None,  # rotary position embedding
+        att_cache: torch.Tensor = None
     ) -> torch.FloatTensor:
         batch_size = x.shape[0]

         # `sample` projections.
         query = attn.to_q(x)
         key = attn.to_k(x)
         value = attn.to_v(x)
-
         # apply rotary position embedding
         if rope is not None:
             freqs, xpos_scale = rope
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)

-            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
-            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
+            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale, offset=x_offset)
+            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale, offset=x_offset)
+
+        new_att_cache = None
+        if att_cache is not None:
+            att_cache = att_cache.transpose(0, 1)
+            k_cache, v_cache = torch.chunk(att_cache, 2, dim=0)
+            new_att_cache = torch.cat((key, value), dim=0).transpose(0, 1)
+            key = torch.cat((k_cache, key), dim=1)
+            value = torch.cat((v_cache, value), dim=1)

         # attention
         inner_dim = key.shape[-1]
@@ -384,7 +445,11 @@ def __call__(
             attn_mask = mask
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
-            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
+            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], mask.shape[-1])
+
+            if key.shape[-2] > mask.shape[-1]:
+                pad = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], key.shape[-2] - mask.shape[-1])
+                attn_mask = torch.cat([pad, attn_mask], dim=-1)  # cached positions stay fully visible: (b, h, q_len, key_len)
         else:
             attn_mask = None
@@ -404,8 +469,7 @@ def __call__(
             mask = mask[:, 0, -1].unsqueeze(-1)
             x = x.masked_fill(~mask, 0.0)

-        return x
-
+        return x, new_att_cache

 # Joint Attention processor for MM-DiT
 # modified from diffusers/src/diffusers/models/attention_processor.py
@@ -518,7 +582,7 @@ def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embed
         norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

         # attention
-        attn_output = self.attn(x=norm, mask=mask, rope=rope)
+        attn_output, _ = self.attn(x=norm, mask=mask, rope=rope)

         # process attention output for input x
         x = x + gate_msa.unsqueeze(1) * attn_output
@@ -529,6 +593,22 @@ def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embed

         return x

+    def forward_chunk(self, x, t, x_offset=0, mask=None, rope=None, att_cache: torch.Tensor = None):  # x: noised input, t: time embedding
+        # pre-norm & modulation for attention input
+        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
+
+        # attention
+        attn_output, new_att_cache = self.attn.forward_chunk(x=norm, x_offset=x_offset, mask=mask, rope=rope, att_cache=att_cache)
+
+        # process attention output for input x
+        x = x + gate_msa.unsqueeze(1) * attn_output
+
+        ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        ff_output = self.ff(ff_norm)
+        x = x + gate_mlp.unsqueeze(1) * ff_output
+
+        return x, new_att_cache
+

 # MMDiT Block https://arxiv.org/abs/2403.03206
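
Note that `new_att_cache` above holds only the current chunk's keys and values, while the incoming `att_cache` is prepended at attention time; the caller is therefore expected to accumulate caches across chunks. A toy single-head sketch of that contract (hypothetical shapes; the real layout in this diff is `(frames, depth, 2 * batch, dim)`):

    import torch

    d = 16
    wq, wk, wv = (torch.randn(d, d) for _ in range(3))

    def attend_chunk(x_chunk, kv_cache):
        # kv_cache: (2, n_cached, d) -- stacked keys and values from past chunks
        q, k_new, v_new = x_chunk @ wq, x_chunk @ wk, x_chunk @ wv
        k = torch.cat((kv_cache[0], k_new), dim=0)
        v = torch.cat((kv_cache[1], v_new), dim=0)
        out = torch.softmax(q @ k.T / d ** 0.5, dim=-1) @ v
        return out, torch.stack((k_new, v_new))    # return this chunk's KV only

    x, cache, outs = torch.randn(40, d), torch.zeros(2, 0, d), []
    for chunk in torch.split(x, 10, dim=0):
        out, new_cache = attend_chunk(chunk, cache)
        cache = torch.cat((cache, new_cache), dim=1)   # accumulation happens outside
        outs.append(out)

    # equivalent to block-causal attention over the full sequence
    i = torch.arange(40)
    allowed = (i[None, :] // 10) <= (i[:, None] // 10)
    scores = (x @ wq) @ (x @ wk).T / d ** 0.5
    ref = torch.softmax(scores.masked_fill(~allowed, float('-inf')), dim=-1) @ (x @ wv)
    assert torch.allclose(torch.cat(outs), ref, atol=1e-5)
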
diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index c25518621..790ba9d8d 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -401,18 +401,92 @@ def inference(self,
         conds = conds.transpose(1, 2)

         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat, _ = self.decoder(
+        feat, _ = self.decoder.forward(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
             n_timesteps=10,
-            streaming=streaming
+            streaming=streaming,
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
         return feat.float(), None

+    @torch.inference_mode()
+    def inference_chunk(self,
+                        token,
+                        token_offset,
+                        token_len,
+                        prompt_token,
+                        prompt_token_len,
+                        prompt_feat,
+                        prompt_feat_len,
+                        conv_cache,
+                        att_cache,
+                        embedding,
+                        streaming,
+                        finalize,
+                        init_cache=False,
+                        chunk_size=25,
+                        n_timesteps=10):
+        assert token.shape[0] == 1
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        token, real_token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+        if token_len.item() == 0:
+            real_token_len = (real_token_len - 3) // chunk_size * chunk_size + 3
+            token = token[:, :real_token_len]
+        assert finalize or real_token_len.item() % chunk_size == 3
+
+        mask = (~make_pad_mask(real_token_len)).unsqueeze(-1).to(embedding)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        if finalize is True or init_cache is True:
+            h = self.pre_lookahead_layer(token)
+        else:
+            h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
+        h = h.repeat_interleave(self.token_mel_ratio, dim=1)
+
+        h_offset = 0
+        if att_cache is not None:
+            h_offset = att_cache[0].shape[0]
+
+        if token_len.item() == 0:
+            conds = torch.zeros([1, h.shape[1], self.output_size], device=token.device).to(h.dtype)
+            conds = conds.transpose(1, 2)
+            mask = (~make_pad_mask(torch.tensor([h.shape[1]]))).to(h)
+        else:
+            h = h[:, h_offset:, :]
+            conds = torch.zeros([1, h.shape[1], self.output_size], device=token.device).to(h.dtype)
+            if h_offset < prompt_token_len.item() * self.token_mel_ratio:
+                left = prompt_token_len.item() * self.token_mel_ratio - h_offset
+                conds[:, :left, :] = prompt_feat[:, -left:, :]
+            conds = conds.transpose(1, 2)
+            mask = (~make_pad_mask(torch.tensor([h.shape[1]]))).to(h)
+
+        x_offset = torch.tensor(h_offset, dtype=torch.int32, device=token.device)
+        feat, new_conv_cache, new_att_cache = self.decoder.forward_chunk(
+            mu=h.transpose(1, 2).contiguous(),
+            x_offset=x_offset,
+            mask=mask.unsqueeze(1),
+            spks=embedding,
+            cond=conds,
+            n_timesteps=n_timesteps,
+            streaming=streaming,
+            conv_cache=conv_cache,
+            att_cache=att_cache,
+        )
+        if init_cache is True:
+            feat = feat[:, :, :0]
+        else:
+            estimated_token_offset = x_offset - prompt_token_len * self.token_mel_ratio
+            if token_offset * self.token_mel_ratio - estimated_token_offset > 0:
+                feat = feat[:, :, (token_offset * self.token_mel_ratio - estimated_token_offset):]
+
+        return feat.float(), new_conv_cache, new_att_cache
+

 if __name__ == '__main__':
     torch.backends.cudnn.deterministic = True
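
Under this calling convention, a consumer would first warm the caches with a prompt-only pass (`init_cache=True`, `token_len == 0`), then stream token chunks while threading the caches through. A hypothetical driver loop (a sketch only: the cumulative `token` history, the `token_offset` bookkeeping, and the per-step attention-cache accumulation are assumptions; the real orchestration lives outside this diff, e.g. in model.py):

    import torch

    def stream_mels(flow, prompt_token, prompt_feat, embedding, token_chunks, chunk_size=25):
        common = dict(prompt_token=prompt_token,
                      prompt_token_len=torch.tensor([prompt_token.size(1)]),
                      prompt_feat=prompt_feat,
                      prompt_feat_len=torch.tensor([prompt_feat.size(1)]),
                      embedding=embedding, streaming=True, chunk_size=chunk_size)
        # 1) prompt-only pass: builds conv/attention caches, emits no frames
        empty = prompt_token.new_zeros(1, 0)
        _, conv_cache, att_cache = flow.inference_chunk(
            token=empty, token_offset=0, token_len=torch.tensor([0]),
            conv_cache=None, att_cache=None, finalize=False, init_cache=True, **common)
        # 2) stream: pass the cumulative token history, carry caches forward
        seen, token_offset = empty, 0
        for chunk, is_last in token_chunks:
            seen = torch.cat((seen, chunk), dim=1)
            feat, conv_cache, new_att = flow.inference_chunk(
                token=seen, token_offset=token_offset, token_len=torch.tensor([seen.size(1)]),
                conv_cache=conv_cache, att_cache=att_cache,
                finalize=is_last, init_cache=False, **common)
            att_cache = [torch.cat((a, b), dim=0) for a, b in zip(att_cache, new_att)]
            token_offset = seen.size(1)
            yield feat
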
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index d3beb9ec2..d67fc33e2 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -106,12 +106,12 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
             t_in[:] = t.unsqueeze(0)
             spks_in[0] = spks
             cond_in[0] = cond
-            dphi_dt = self.forward_estimator(
+            dphi_dt, _ = self.forward_estimator(
                 x_in, mask_in,
                 mu_in, t_in,
                 spks_in,
                 cond_in,
-                streaming
+                streaming,
             )
             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
@@ -123,13 +123,84 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):

         return sol[-1].float()

+    def solve_euler_chunk(self, x, x_offset, t_span, mu, mask, spks, cond, streaming=False, conv_cache=None, att_cache=None):
+        """
+        Fixed euler solver for ODEs, running the cached chunk estimator.
+        Args:
+            x (torch.Tensor): random noise
+            t_span (torch.Tensor): n_timesteps interpolated
+                shape: (n_timesteps + 1,)
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond (torch.Tensor): conditioning features (prompt mel), same shape as mu
+        """
+        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+        t = t.unsqueeze(dim=0)
+
+        cache_dtype = spks.dtype
+        if att_cache is None:
+            att_cache = []
+            for i in range(len(t_span) - 1):
+                att_cache.append(torch.zeros(0, 22, x.shape[0] * 4, 1024, dtype=cache_dtype, device=x.device))
+        if conv_cache is None:
+            conv_cache = torch.zeros(len(t_span) - 1, 2, 2, 1024, 30, dtype=cache_dtype, device=x.device)
+        new_att_cache = []
+        new_conv_cache = []
+        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+        # Or in future might add like a return_all_steps flag
+        sol = []
+
+        # Do not use concat; it may change the memory format and make TRT infer wrong results!
+        # NOTE when flow runs in amp mode, x.dtype is float32, which causes NaN in TRT fp16 inference, so set dtype=spks.dtype
+        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
+        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
+        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
+        t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
+        spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
+        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
+        for step in range(1, len(t_span)):
+            # Classifier-Free Guidance inference introduced in VoiceBox
+            x_in[:] = x
+            mask_in[:] = mask
+            mu_in[0] = mu
+            t_in[:] = t.unsqueeze(0)
+            spks_in[0] = spks
+            cond_in[0] = cond
+            dphi_dt, new_conv_cache_step, new_att_cache_step = self.forward_estimator_with_cache(
+                x_in, x_offset, mask_in,
+                mu_in, t_in,
+                spks_in,
+                cond_in,
+                streaming,
+                conv_cache=conv_cache[step - 1],
+                att_cache=att_cache[step - 1]
+            )
+            new_conv_cache.append(new_conv_cache_step)
+            new_att_cache.append(new_att_cache_step)
+            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+            x = x + dt * dphi_dt
+            t = t + dt
+            sol.append(x)
+            if step < len(t_span) - 1:
+                dt = t_span[step + 1] - t
+
+        return sol[-1].float(), torch.stack(tuple(new_conv_cache), dim=0), new_att_cache
+
     def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
         if isinstance(self.estimator, torch.nn.Module):
-            return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
+            out = self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
+            if isinstance(out, tuple):
+                return out[0], out[1]
+            return out, None
         else:
-            [estimator, stream], trt_engine = self.estimator.acquire_estimator()
-            # NOTE need to synchronize when switching stream
+            # synchronize before running, to avoid data corruption when multiple threads (e.g. under Triton serving) operate on the same x object concurrently
             torch.cuda.current_stream().synchronize()
+            [estimator, stream], trt_engine = self.estimator.acquire_estimator()
             with stream:
                 estimator.set_input_shape('x', (2, 80, x.size(2)))
                 estimator.set_input_shape('mask', (2, 1, x.size(2)))
@@ -150,7 +221,57 @@ def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
                 assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
                 torch.cuda.current_stream().synchronize()
             self.estimator.release_estimator(estimator, stream)
-            return x
+            return x, None
+
+    def forward_estimator_with_cache(self, x, x_offset, mask, mu, t, spks, cond, streaming=False, conv_cache=None, att_cache=None):
+        if isinstance(self.estimator, torch.nn.Module):
+            return self.estimator.forward_chunk(
+                x, x_offset, mask, mu, t, spks, cond,
+                streaming=streaming, conv_cache=conv_cache, att_cache=att_cache)
+        else:
+            # synchronize before running, to avoid data corruption when multiple threads (e.g. under Triton serving) operate on the same x object concurrently
+            x_offset = x_offset.to(dtype=torch.int64).contiguous()
+            x_offset = x_offset.cpu()  # TRT reads shape tensors from host memory
+            [estimator, stream], trt_engine = self.estimator.acquire_estimator()
+            new_conv_cache = torch.empty_like(conv_cache)
+            new_att_cache = torch.empty((x.size(2), 22, 4, 1024), device=x.device, dtype=conv_cache.dtype)
+
+            # Avoid a null pointer error in the trt engine when the cache is still empty
+            if att_cache.size(0) == 0:
+                att_cache_for_trt = torch.zeros((1, 22, 4, 1024), device=x.device, dtype=conv_cache.dtype)
+            else:
+                att_cache_for_trt = att_cache
+
+            with stream:
+                torch.cuda.current_stream().synchronize()
+                estimator.set_input_shape('x', (2, 80, x.size(2)))
+                estimator.set_input_shape('x_offset', ())
+                estimator.set_input_shape('mask', (2, 1, x.size(2)))
+                estimator.set_input_shape('mu', (2, 80, x.size(2)))
+                estimator.set_input_shape('t', (2,))
+                estimator.set_input_shape('spks', (2, 80))
+                estimator.set_input_shape('cond', (2, 80, x.size(2)))
+                estimator.set_input_shape('conv_cache', (2, 2, 1024, 30))
+                estimator.set_input_shape('att_cache', (att_cache.size(0), 22, 4, 1024))
+                data_ptrs = [x.contiguous().data_ptr(),
+                             x_offset.contiguous().data_ptr(),
+                             mask.contiguous().data_ptr(),
+                             mu.contiguous().data_ptr(),
+                             t.contiguous().data_ptr(),
+                             spks.contiguous().data_ptr(),
+                             cond.contiguous().data_ptr(),
+                             conv_cache.contiguous().data_ptr(),
+                             att_cache_for_trt.contiguous().data_ptr(),
+                             x.data_ptr(),
+                             new_conv_cache.data_ptr(),
+                             new_att_cache.data_ptr()]
+                for i, j in enumerate(data_ptrs):
+                    estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
+                # run trt engine
+                assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                torch.cuda.current_stream().synchronize()
+            self.estimator.release_estimator(estimator, stream)
+            return x, new_conv_cache, new_att_cache

     def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
         """Computes diffusion loss
@@ -189,6 +310,8 @@ def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
             cond = cond * cfg_mask.view(-1, 1, 1)

         pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
+        if isinstance(pred, tuple):
+            pred = pred[0]
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
         return loss, y
@@ -201,27 +324,25 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:

     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
-        """Forward diffusion
+        """Forward diffusion (CosyVoice2 UNet path, unchanged from upstream)."""
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+        if self.t_scheduler == 'cosine':
+            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None

-        Args:
-            mu (torch.Tensor): output of encoder
-                shape: (batch_size, n_feats, mel_timesteps)
-            mask (torch.Tensor): output_mask
-                shape: (batch_size, 1, mel_timesteps)
-            n_timesteps (int): number of diffusion steps
-            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
-            spks (torch.Tensor, optional): speaker ids. Defaults to None.
-                shape: (batch_size, spk_emb_dim)
-            cond: Not used but kept for future purposes
+    @torch.inference_mode()
+    def forward_chunk(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False,
+                      x_offset=0, conv_cache=None, att_cache=None):
+        """Chunked / KV-cached diffusion path for the DiT estimator (torch.nn.Module or TRT-wrapped).

-        Returns:
-            sample: generated mel-spectrogram
-                shape: (batch_size, n_feats, mel_timesteps)
+        Returns: (mel, new_conv_cache, new_att_cache)
         """
-        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
-        # fix prompt and overlap part mu and z
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
+        return self.solve_euler_chunk(
+            z, x_offset, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond,
+            streaming=streaming, conv_cache=conv_cache, att_cache=att_cache)
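
Both `forward` and `forward_chunk` share the `t_span` construction; with the cosine scheduler the integration points are spaced non-uniformly, with smaller Euler steps near `t = 0`. A quick runnable illustration:

    import torch

    n_timesteps = 10
    t_uniform = torch.linspace(0, 1, n_timesteps + 1)
    t_cosine = 1 - torch.cos(t_uniform * 0.5 * torch.pi)
    print(t_uniform)  # 0.00, 0.10, 0.20, ...
    print(t_cosine)   # 0.000, 0.012, 0.049, 0.109, ... (denser near t = 0)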