Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions cosyvoice/flow/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,14 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):

# Do not use concat: it may change the memory format and cause TensorRT inference to produce wrong results!
# NOTE: when the flow runs in AMP mode, x.dtype is float32, which causes NaN in TensorRT fp16 inference, so set dtype=spks.dtype
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
# Fix: use x.dtype as fallback when spks is None to avoid AttributeError
dtype = spks.dtype if spks is not None else x.dtype
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=dtype)
t_in = torch.zeros([2], device=x.device, dtype=dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=dtype)
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x
Expand Down
9 changes: 5 additions & 4 deletions cosyvoice/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,11 @@ def forward_dpo(
acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)

# 5. calculate dpo logits
chosen_lm_mask = chosen_lm_target == IGNORE_ID
rejected_lm_mask = rejected_lm_target == IGNORE_ID
chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
# Fix: mask should select NON-IGNORE_ID positions (True for valid tokens)
chosen_lm_mask = chosen_lm_target != IGNORE_ID
rejected_lm_mask = rejected_lm_target != IGNORE_ID
chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(~chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(~rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
Expand Down