File tree Expand file tree Collapse file tree
megatron/core/transformer — Original file line number / Diff line number / Diff line change
@@ -672,8 +672,11 @@ def process_mtp_loss(
672672 mtp_loss = compute_language_model_loss(mtp_labels, mtp_logits)
673673 mtp_loss = loss_mask * mtp_loss
674674 if is_training:
675+ mtp_loss_for_log = (
676+     torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0)
677+ )
675678 MTPLossLoggingHelper.save_loss_to_tracker(
676- torch.sum(mtp_loss) / num_tokens,
679+ mtp_loss_for_log,
677680 mtp_layer_number,
678681 config.mtp_num_layers,
679682 avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
@@ -692,8 +695,9 @@ def process_mtp_loss(
692695 )
693696 hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_normalized)
694697 else:
698+ safe_num_tokens = num_tokens.clamp(min=1)
695699 hidden_states = MTPLossAutoScaler.apply(
696- hidden_states, mtp_loss_scale * mtp_loss / safe_num_tokens
697701 )
698702
699703 return hidden_states
You can’t perform that action at this time.
0 commit comments