diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index d3beb9ec2..6b5bf869b 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -92,12 +92,14 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
 
         # Do not use concat, it may cause memory format changed and trt infer with wrong results!
         # NOTE when flow run in amp mode, x.dtype is float32, which cause nan in trt fp16 inference, so set dtype=spks.dtype
-        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
-        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
-        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
-        t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
-        spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
-        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
+        # Fix: use x.dtype as fallback when spks is None to avoid AttributeError
+        dtype = spks.dtype if spks is not None else x.dtype
+        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=dtype)
+        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=dtype)
+        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=dtype)
+        t_in = torch.zeros([2], device=x.device, dtype=dtype)
+        spks_in = torch.zeros([2, 80], device=x.device, dtype=dtype)
+        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=dtype)
         for step in range(1, len(t_span)):
             # Classifier-Free Guidance inference introduced in VoiceBox
             x_in[:] = x
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index e8e81d942..ddff71f21 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -447,10 +447,11 @@ def forward_dpo(
 
         acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
         # 5. calculate dpo logits
-        chosen_lm_mask = chosen_lm_target == IGNORE_ID
-        rejected_lm_mask = rejected_lm_target == IGNORE_ID
-        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
-        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
+        # Fix: mask should select NON-IGNORE_ID positions (True for valid tokens)
+        chosen_lm_mask = chosen_lm_target != IGNORE_ID
+        rejected_lm_mask = rejected_lm_target != IGNORE_ID
+        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(~chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
+        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(~rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
         chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
         rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
         return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}