
Fix NaN loss when all labels on a rank are label_ignore_index #657

Open
Suanmd wants to merge 1 commit into allenai:main from Suanmd:nan_loss_fix

@Suanmd Suanmd commented Apr 10, 2026

Summary

When training with packed SFT data across many devices, some ranks may receive batches in which every label is label_ignore_index (-100). On such a rank batch_num_tokens_for_loss is 0, so the loss computation divides the loss_reduction="sum" total (itself 0) by zero, which yields NaN. The NaN then propagates to all ranks via all_reduce and crashes the entire training run.

This is distinct from the instance-masking NaN fix, which guards against instance_mask filtering out all instances. In this case instance_mask is None: the data distribution simply places only all-padding or prompt-only sequences on a single rank, leaving zero tokens eligible for loss.

Root cause

batch_num_tokens_for_loss = (labels != label_ignore_index).sum()  # == 0
loss = cross_entropy(..., reduction="sum") / batch_num_tokens_for_loss  # 0/0 → NaN

The existing instance-mask guard (line 407) does not help because instance_mask is None in this scenario.
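
The failure mode can be illustrated with a minimal pure-Python sketch. It is a stand-in for the tensor arithmetic, not the trainer's code: `ieee_div`, `summed_nll`, and `rank_loss` are hypothetical helpers, the per-token loss values are made up, and the final average only imitates what an averaging all_reduce would do.

```python
import math

IGNORE_INDEX = -100  # label_ignore_index, as named in this PR


def ieee_div(num: float, den: float) -> float:
    """Mimic floating-point tensor division, where 0/0 is NaN rather than an exception."""
    if den == 0:
        return math.nan if num == 0 else math.copysign(math.inf, num)
    return num / den


def summed_nll(token_losses, labels):
    """Sum per-token losses, skipping ignored labels (the reduction="sum" behaviour)."""
    return sum(loss for loss, y in zip(token_losses, labels) if y != IGNORE_INDEX)


def rank_loss(token_losses, labels):
    """Normalize the summed loss by the number of tokens eligible for loss."""
    num_tokens_for_loss = sum(1 for y in labels if y != IGNORE_INDEX)
    return ieee_div(summed_nll(token_losses, labels), num_tokens_for_loss)


# Rank 0 has real targets; rank 1 drew a batch whose labels are all ignored.
rank0 = rank_loss([0.5, 1.5], [7, 3])
rank1 = rank_loss([0.9, 0.2], [IGNORE_INDEX, IGNORE_INDEX])

# Stand-in for an averaging all_reduce: one NaN poisons the combined value.
all_reduced = (rank0 + rank1) / 2
print(rank0, rank1, all_reduced)
```

A single affected rank is enough to make the reduced value NaN, which matches the crash described in the summary.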

Fix

Before entering the micro-batch forward loop, replace batch_num_tokens_for_loss with a non-zero fallback (batch_num_tokens) whenever it equals 0. This is safe because:

  1. When all labels are ignored, cross_entropy(..., reduction="sum") produces 0 (no valid targets → zero loss sum).
  2. Dividing 0 by any positive number yields 0, so the rank contributes zero loss — correct behavior that does not distort gradients on other ranks.
  3. A log.warning is emitted for observability.
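
The three points above could be sketched roughly as follows. This is an illustration under stated assumptions, not the actual diff: `resolve_loss_denominator` is a hypothetical helper name, plain integers stand in for tensors, and batch_num_tokens is assumed to be positive for any non-empty batch.

```python
import logging

log = logging.getLogger(__name__)


def resolve_loss_denominator(batch_num_tokens_for_loss: int, batch_num_tokens: int) -> int:
    """Return a non-zero denominator for loss normalization (hypothetical helper)."""
    if batch_num_tokens_for_loss == 0:
        # Every label on this rank is ignored; fall back to the total token count.
        log.warning(
            "All labels on this rank are ignored; using batch_num_tokens=%d as the loss denominator",
            batch_num_tokens,
        )
        return batch_num_tokens
    return batch_num_tokens_for_loss


# An all-masked rank: the summed loss is 0, so the normalized loss is 0 instead of NaN.
denominator = resolve_loss_denominator(0, 1024)
loss = 0.0 / denominator
```

Ranks with at least one eligible token are unaffected, since the helper returns their original count unchanged.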

Reproduction conditions

  • SFT training with packed sequences (NumpyPackedFSLDataset)
  • Small rank_microbatch_size (1 instance per rank, e.g., seq_len=1024, global_batch_size = num_gpus * 1024)
  • Dataset contains sequences with no completion tokens (all labels masked)
  • More likely with higher device counts (fewer instances per rank → higher probability of a rank drawing an all-masked batch)
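
The dependence on device count can be estimated with a back-of-the-envelope model. The sketch below assumes each instance is independently all-masked with probability `p_masked` and that ranks draw independently; both are simplifying assumptions, the 1% rate is a made-up example value, and `prob_any_rank_all_masked` is a hypothetical helper.

```python
def prob_any_rank_all_masked(p_masked: float, instances_per_rank: int, num_ranks: int) -> float:
    """Probability that at least one rank draws only all-masked instances in a step."""
    # A rank is affected only if every one of its instances is all-masked.
    p_rank = p_masked ** instances_per_rank
    # Complement of "no rank is affected".
    return 1.0 - (1.0 - p_rank) ** num_ranks


# Hypothetical 1% of sequences with no completion tokens, 1 instance per rank.
for ranks in (8, 16, 64):
    print(ranks, round(prob_any_rank_all_masked(0.01, 1, ranks), 4))
```

Under this model the per-step probability grows with the number of ranks and shrinks quickly as each rank receives more instances, consistent with the conditions listed above.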

Verified

Tested on GPU/NPU, 1 node, seq_len=1024. Before fix: sporadic NaN on 1-3 ranks per step. After fix: 100 steps completed without NaN.

Test plan

  • Reproduced NaN on 16-NPU cluster with packed SFT data
  • Reproduced NaN on 8-GPU cluster with packed SFT data
  • Confirmed via debug logging: batch_num_tokens_for_loss=0,
    instance_mask=None, all_labels_masked=True on affected ranks
  • Verified fix: 100-step training run completes with no NaN
  • Consider upstream: add unit test with synthetic all-masked batch
