diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index e8e81d942..ddff71f21 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -447,10 +447,11 @@ def forward_dpo( acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID) # 5. calculate dpo logits - chosen_lm_mask = chosen_lm_target == IGNORE_ID - rejected_lm_mask = rejected_lm_target == IGNORE_ID - chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1) - rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1) + # Fix: mask should select NON-IGNORE_ID positions (True for valid tokens) + chosen_lm_mask = chosen_lm_target != IGNORE_ID + rejected_lm_mask = rejected_lm_target != IGNORE_ID + chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(~chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1) + rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(~rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1) chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1) rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1) return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}