Skip to content

Commit d7500d4

Browse files
authored
[Bugfix] Fix nan loss caused by zero token in MTP (NVIDIA#3396)
Signed-off-by: lit <lit@nvidia.com>
1 parent a22c40e commit d7500d4

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

megatron/core/transformer/multi_token_prediction.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -672,8 +672,11 @@ def process_mtp_loss(
672672
mtp_loss = compute_language_model_loss(mtp_labels, mtp_logits)
673673
mtp_loss = loss_mask * mtp_loss
674674
if is_training:
675+
mtp_loss_for_log = (
676+
torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0)
677+
)
675678
MTPLossLoggingHelper.save_loss_to_tracker(
676-
torch.sum(mtp_loss) / num_tokens,
679+
mtp_loss_for_log,
677680
mtp_layer_number,
678681
config.mtp_num_layers,
679682
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
@@ -692,8 +695,9 @@ def process_mtp_loss(
692695
)
693696
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_normalized)
694697
else:
698+
safe_num_tokens = num_tokens.clamp(min=1)
695699
hidden_states = MTPLossAutoScaler.apply(
696-
hidden_states, mtp_loss_scale * mtp_loss / num_tokens
700+
hidden_states, mtp_loss_scale * mtp_loss / safe_num_tokens
697701
)
698702

699703
return hidden_states

0 commit comments

Comments (0)