126 changes: 122 additions & 4 deletions cosyvoice/flow/DiT/dit.py
@@ -87,15 +87,18 @@ def forward(
cond: float["b n d"],
text_embed: float["b n d"],
spks: float["b d"],
conv_cache: torch.Tensor = None,
):
to_cat = [x, cond, text_embed]
if self.spk_dim > 0:
spks = repeat(spks, "b c -> b t c", t=x.shape[1])
to_cat.append(spks)

x = self.proj(torch.cat(to_cat, dim=-1))
x = self.conv_pos_embed(x) + x
return x

conv, new_conv_cache = self.conv_pos_embed(x, conv_cache=conv_cache)
x = conv + x
return x, new_conv_cache


# Transformer backbone using DiT blocks
@@ -153,7 +156,7 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):

# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(t)
x = self.input_embed(x, cond, mu, spks.squeeze(1))
x, _ = self.input_embed(x, cond, mu, spks.squeeze(1))

rope = self.rotary_embed.forward_from_seq_len(seq_len)

@@ -173,4 +176,119 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):

x = self.norm_out(x, t)
output = self.proj_out(x).transpose(1, 2)
return output
return output, None

def forward_chunk(self, x, x_offset, mask, mu, t, spks=None, cond=None, streaming=False, conv_cache:torch.Tensor=None, att_cache:torch.Tensor=None):
x = x.transpose(1, 2)
mu = mu.transpose(1, 2)
cond = cond.transpose(1, 2)
spks = spks.unsqueeze(dim=1)
batch, seq_len = x.shape[0], x.shape[1]
if t.ndim == 0:
t = t.repeat(batch)

# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(t)
x, new_conv_cache = self.input_embed(x, cond, mu, spks.squeeze(1), conv_cache=conv_cache)

x_off = int(x_offset.item()) if isinstance(x_offset, torch.Tensor) else int(x_offset)
x_partial = x[:, x_off:, :]

rope = self.rotary_embed.forward_from_seq_len(seq_len)

if self.long_skip_connection is not None:
residual = x_partial

attn_mask = add_optional_chunk_mask(x_partial, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)

new_att_cache = []
for i, block in enumerate(self.transformer_blocks):
x_partial, new_att_cache_i = block.forward_chunk(x_partial, t, x_offset=x_off, mask=attn_mask.bool(), rope=rope, att_cache=att_cache[i])
new_att_cache.append(new_att_cache_i)

if self.long_skip_connection is not None:
x_partial = self.long_skip_connection(torch.cat((x_partial, residual), dim=-1))

x_partial = self.norm_out(x_partial, t)

output = self.proj_out(x_partial).transpose(1, 2)
return output, new_conv_cache, torch.stack(new_att_cache)

class DiTWithCache(nn.Module):
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
mel_dim=80,
mu_dim=None,
long_skip_connection=False,
spk_dim=None,
out_channels=None,
static_chunk_size=50,
num_decoding_left_chunks=2
):
super().__init__()

self.time_embed = TimestepEmbedding(dim)
if mu_dim is None:
mu_dim = mel_dim
self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)

self.rotary_embed = RotaryEmbedding(dim_head)

self.dim = dim
self.depth = depth

self.transformer_blocks = nn.ModuleList(
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
)
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None

self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.out_channels = out_channels
self.static_chunk_size = static_chunk_size
self.num_decoding_left_chunks = num_decoding_left_chunks

def forward(self, x, x_offset, mask, mu, t, spks=None, cond=None, conv_cache:torch.Tensor=None, att_cache:torch.Tensor=None, streaming=False):
x = x.transpose(1, 2)
mu = mu.transpose(1, 2)
cond = cond.transpose(1, 2)
spks = spks.unsqueeze(dim=1)
batch, seq_len = x.shape[0], x.shape[1]
seq_len = torch.tensor(seq_len, dtype=torch.int64)
x_offset = x_offset.to(dtype=torch.int64)
if t.ndim == 0:
t = t.repeat(batch)

# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(t)
x, new_conv_cache = self.input_embed(x, cond, mu, spks.squeeze(1), conv_cache=conv_cache)

# ONNX export cannot capture Python ints, so seq_len and x_offset are kept as int64 tensors
#dummy_x = torch.zeros((seq_len))
#dummy_x = F.pad(dummy_x, (0, x_offset))
rope = self.rotary_embed.forward_from_seq_len(seq_len + x_offset)

if self.long_skip_connection is not None:
residual = x

attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)

new_att_cache = []
for i, block in enumerate(self.transformer_blocks):
x, new_att_cache_i = block.forward_chunk(x, t, x_offset=x_offset, mask=attn_mask.bool(), rope=rope, att_cache=att_cache[:, i, :, :])
new_att_cache.append(new_att_cache_i)

if self.long_skip_connection is not None:
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

x = self.norm_out(x, t)

output = self.proj_out(x).transpose(1, 2)
return output, new_conv_cache, torch.stack(new_att_cache, dim=1)
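
For reference, a minimal sketch (editor's addition, not part of this PR) of how the cached forward above might be driven chunk by chunk. The model handle, speaker embedding spks, timestep t, masks, and the chunk iterator are assumed to exist elsewhere; the cache shapes are inferred from the indexing in the code above.

import torch

# Hypothetical sizes; depth, heads, dim_head and the conv kernel match the defaults above.
batch, dim, depth, heads, dim_head, kernel_size = 1, 512, 8, 8, 64, 31

# Both caches must be pre-allocated: this forward always indexes att_cache[:, i, :, :],
# and ConvPositionEmbedding only updates a cache that was actually passed in.
conv_cache = torch.zeros(batch, 2, dim, kernel_size - 1)          # inferred layout
att_cache = torch.zeros(0, depth, 2 * batch, heads * dim_head)    # empty KV cache
x_offset = torch.tensor(0, dtype=torch.int64)

for x_chunk, mu_chunk, cond_chunk, mask_chunk in chunks:  # hypothetical chunk iterator
    # x_chunk: (batch, mel_dim, chunk_len); t: flow-matching timestep tensor
    out, conv_cache, att_cache = model(
        x_chunk, x_offset, mask_chunk, mu_chunk, t,
        spks=spks, cond=cond_chunk,
        conv_cache=conv_cache, att_cache=att_cache,
    )
    x_offset = x_offset + x_chunk.shape[-1]  # advance the rotary position offset
    # A full flow-matching solver runs several timesteps per chunk and would keep
    # a separate (conv_cache, att_cache) pair for each timestep.
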
106 changes: 93 additions & 13 deletions cosyvoice/flow/DiT/modules.py
@@ -17,8 +17,35 @@
import torch.nn.functional as F
import torchaudio

from x_transformers.x_transformers import apply_rotary_pos_emb
#from x_transformers.x_transformers import apply_rotary_pos_emb

import einops

def rotate_half(x):
x = einops.rearrange(x, '... (d r) -> ... d r', r = 2)
x1, x2 = x.unbind(dim = -1)
x = torch.stack((-x2, x1), dim = -1)
return einops.rearrange(x, '... d r -> ... (d r)')

@torch.amp.autocast('cuda', enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1, offset = 0):
rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype

freqs = freqs[:, offset:, :]
scale = scale[:, offset:, :] if torch.is_tensor(scale) else scale

if t.ndim == 4 and freqs.ndim == 3:
freqs = einops.rearrange(freqs, 'b n d -> b 1 n d')

if torch.is_tensor(scale):
scale = einops.rearrange(scale, 'b n d -> b 1 n d')

# partial rotary embeddings, Wang et al. GPT-J
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
out = torch.cat((t, t_unrotated), dim = -1)

return out.type(orig_dtype)
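
# Editor's sketch (not part of this file): the new `offset` argument selects which rows
# of the rotary table are applied, so rotating a chunk that starts at position 4 matches
# rotating the full sequence and slicing out the same positions.
_q = torch.randn(1, 8, 12, 64)                   # (batch, heads, seq, head_dim)
_freqs = torch.randn(1, 12, 64)                  # stand-in rotary table for 12 positions
_full = apply_rotary_pos_emb(_q, _freqs)[:, :, 4:, :]
_part = apply_rotary_pos_emb(_q[:, :, 4:, :], _freqs, offset=4)
assert torch.allclose(_full, _part)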

# raw wav to mel spec
class MelSpec(nn.Module):
@@ -126,22 +153,33 @@ def __init__(self, dim, kernel_size=31, groups=16):
nn.Mish(),
)

def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None, conv_cache: torch.Tensor = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)

x = x.permute(0, 2, 1)
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))

if conv_cache is not None:
x = torch.cat((conv_cache[:, 0, :, :], x), dim=-1)
conv_cache[:, 0, :, :] = x[:, :, -self.kernel_size + 1:]
else:
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))

x = self.conv1(x)
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))

if conv_cache is not None:
x = torch.cat((conv_cache[:, 1, :, :], x), dim=-1)
conv_cache[:, 1, :, :] = x[:, :, -self.kernel_size + 1:]
else:
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
x = self.conv2(x)
out = x.permute(0, 2, 1)

if mask is not None:
out = out.masked_fill(~mask, 0.0)

return out
return out, conv_cache
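
# Editor's sketch (not part of this file): with a zero-initialised cache the chunked
# path reproduces the non-streaming causal left-padding, since each cache slot keeps
# the last kernel_size - 1 frames fed to its conv. Cache layout inferred from the
# indexing above: (batch, 2, dim, kernel_size - 1).
_cpe = ConvPositionEmbedding(dim=512)
_c1, _c2 = torch.randn(2, 40, 512), torch.randn(2, 35, 512)
_cache = torch.zeros(2, 2, 512, _cpe.kernel_size - 1)
_y1, _cache = _cpe(_c1, conv_cache=_cache)
_y2, _cache = _cpe(_c2, conv_cache=_cache)
_y_full, _ = _cpe(torch.cat((_c1, _c2), dim=1))
assert torch.allclose(torch.cat((_y1, _y2), dim=1), _y_full, atol=1e-5)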


# rotary positional embedding related
@@ -342,6 +380,21 @@ def forward(
else:
return self.processor(self, x, mask=mask, rope=rope)

def forward_chunk(
self,
x: float["b n d"], # noised input x # noqa: F722
x_offset: int = 0,
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
att_cache:torch.Tensor=None
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, x_offset=x_offset, c=c, mask=mask, rope=rope, c_rope=c_rope, att_cache=att_cache)
else:
return self.processor(self, x, x_offset=x_offset, mask=mask, rope=rope, att_cache=att_cache)


# Attention processor

@@ -354,23 +407,31 @@ def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
x_offset: int = 0,
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding
att_cache:torch.Tensor=None
) -> torch.FloatTensor:
batch_size = x.shape[0]

# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)

# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)

query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale, offset=x_offset)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale, offset=x_offset)

new_att_cache = None
if att_cache is not None:
att_cache = att_cache.transpose(0, 1)
k_cache, v_cache = torch.chunk(att_cache, 2, dim=0)
new_att_cache = torch.cat((key, value), dim=0).transpose(0, 1)
key = torch.cat((k_cache, key), dim=1)
value = torch.cat((v_cache, value), dim=1)

# attention
inner_dim = key.shape[-1]
@@ -384,7 +445,11 @@ def __call__(
attn_mask = mask
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], mask.shape[-1])

if key.shape[-2] > mask.shape[-1]:
pad = attn_mask.new_ones(*attn_mask.shape[:-1], key.shape[-2] - mask.shape[-1])
attn_mask = torch.cat([pad, attn_mask], dim=-1) # (b, heads, q_len, key_len); cached positions stay attendable
else:
attn_mask = None

@@ -404,8 +469,7 @@ def __call__(
mask = mask[:, 0, -1].unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)

return x

return x, new_att_cache
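
# Editor's sketch (not part of this file): new_att_cache packs the current chunk's keys
# and values along the batch axis, i.e. shape (seq_len, 2 * batch, inner_dim); the next
# call transposes and chunks it back into k_cache / v_cache before prepending them.
_b, _n, _inner = 2, 25, 512                                 # hypothetical sizes
_key, _value = torch.randn(_b, _n, _inner), torch.randn(_b, _n, _inner)
_packed = torch.cat((_key, _value), dim=0).transpose(0, 1)  # as returned above
_k, _v = torch.chunk(_packed.transpose(0, 1), 2, dim=0)     # as unpacked on the next call
assert torch.equal(_k, _key) and torch.equal(_v, _value)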

# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
@@ -518,7 +582,7 @@ def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embed
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

# attention
attn_output = self.attn(x=norm, mask=mask, rope=rope)
attn_output, _ = self.attn(x=norm, mask=mask, rope=rope)

# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
@@ -529,6 +593,22 @@ def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embed

return x

def forward_chunk(self, x, t, x_offset=0, mask=None, rope=None, att_cache:torch.Tensor=None): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

# attention
attn_output, new_att_cache = self.attn.forward_chunk(x=norm, x_offset=x_offset, mask=mask, rope=rope, att_cache=att_cache)

# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output

ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(ff_norm)
x = x + gate_mlp.unsqueeze(1) * ff_output

return x, new_att_cache


# MMDiT Block https://arxiv.org/abs/2403.03206
