From b4c68e7eca537d9160d413449d7b52713cdabd16 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 18 Mar 2026 14:13:38 +0800 Subject: [PATCH] fix: DPO implementation bug in forward_dpo (issue #1449) - Changed mask logic from '== IGNORE_ID' to '!= IGNORE_ID' - Fixed gather index to use '~mask' for padding positions - This ensures DPO loss is computed on valid tokens only, not IGNORE_ID positions --- cosyvoice/llm/llm.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index e8e81d942..ddff71f21 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -447,10 +447,11 @@ def forward_dpo( acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID) # 5. calculate dpo logits - chosen_lm_mask = chosen_lm_target == IGNORE_ID - rejected_lm_mask = rejected_lm_target == IGNORE_ID - chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1) - rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1) + # Fix: mask should select NON-IGNORE_ID positions (True for valid tokens) + chosen_lm_mask = chosen_lm_target != IGNORE_ID + rejected_lm_mask = rejected_lm_target != IGNORE_ID + chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(~chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1) + rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(~rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1) chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1) rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1) return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}