[training] fix: use int64 for TrainState counters to prevent overflow #3312
Conversation
TrainState.state_dict() serialized step and sample counters as torch.int32 tensors, which overflow at ~2.1B. Long training runs with large batch sizes (e.g. 700K iters × GBS 3072 = 2.15B samples) hit a RuntimeError at checkpoint save time.

Switch to torch.int64, which supports up to ~9.2e18. Backward compatible: load_state_dict() uses .item(), which returns a Python scalar regardless of the source tensor dtype.

Signed-off-by: Yu Yao <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Made-with: Cursor
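The overflow threshold from the description can be checked directly. A minimal arithmetic sketch in plain Python (not the Bridge code), using the failing configuration reported above:

```python
# Fixed-width integer limits relevant to the counters.
INT32_MAX = 2**31 - 1   # 2_147_483_647 (~2.1e9)
INT64_MAX = 2**63 - 1   # ~9.2e18

# The reported failing run: 700K iterations at global batch size 3072.
consumed_samples = 700_000 * 3072  # 2_150_400_000

print(consumed_samples > INT32_MAX)  # True — an int32 counter overflows here
print(consumed_samples < INT64_MAX)  # True — int64 has ample headroom
```

This is why the error only surfaces on long runs with large batch sizes: the counter must actually cross 2,147,483,647 consumed samples before serialization fails.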
/ok to test a74a3ed
The TrainState counters were changed from int32 to int64 to prevent overflow on long training runs, but the unit test still asserted int32. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
maanug-nv left a comment
how about uint32 or uint64? these values shouldn't ever be negative i think
Unsigned 64-bit doubles the positive range to ~1.8e19 and better reflects the semantics of non-negative counters (step, consumed/skipped samples). Backward compatible via .item(). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
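The range claim in this commit message checks out arithmetically. A quick sketch (plain Python, independent of torch):

```python
# Signed vs. unsigned 64-bit ranges for a non-negative counter.
INT64_MAX = 2**63 - 1    # ~9.22e18
UINT64_MAX = 2**64 - 1   # ~1.84e19 — roughly double the positive range

print(f"int64 max:  {INT64_MAX:.3e}")
print(f"uint64 max: {UINT64_MAX:.3e}")
print(UINT64_MAX // INT64_MAX)  # 2
```

Either width is far beyond any realistic sample count; the tradeoff that matters (as the follow-up below shows) is serialization support, not range.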
/ok to test d4f7fe3
PR #3312 intended to bump the TrainState counters from int32 to int64 to prevent overflow on long training runs, but the change landed as torch.uint64. Unlike int64, torch.uint64 has no corresponding TypedStorage subclass, so PyTorch's legacy pickle path (triggered from torch.distributed.broadcast_object_list during _validate_common_state_dict in MCore dist-checkpoint save) falls through to bare torch.UntypedStorage and crashes with:

AttributeError: type object 'torch.storage.UntypedStorage' has no attribute 'dtype'

That kills one rank mid-collective, causing the observed NCCL watchdog timeout + segfault in the local checkpointing functional tests (test_local_checkpoint_save_and_resume and the most_recent_k variant).

Switch the counters to torch.int64, which still avoids int32 overflow (max ~9.2e18 samples) and goes through the standard LongStorage pickle path, unblocking dist-checkpoint save. Update the matching unit-test dtype assertions.

Signed-off-by: Chen Cui <chcui@nvidia.com>
QA RCCA Analysis

1. Fix Reference
2. Root Cause
3. Trigger Configuration
4. Nature of the Bug
   Classification: CODE BUG - Integer overflow in checkpoint serialization
5. Existing Test Coverage
   In Fix PR: YES - 1 test file
   In NMFW Tests: Not specifically covering large sample counts
6. Coverage Assessment
7. New Regression Test
   NOT NEEDED - Fix PR includes unit tests for TrainState serialization.
8. Conclusion
   Verdict: ADEQUATE COVERAGE - Fix PR includes unit tests for int64 counter serialization.
Summary
- `TrainState.state_dict()` serialized `step`, `consumed_train_samples`, `skipped_train_samples`, and `consumed_valid_samples` as `torch.int32` tensors, which overflow at ~2.1B
- `RuntimeError: value cannot be converted to type int32 without overflow` at checkpoint save time
- Switch to `torch.int64` (max ~9.2×10¹⁸)
- `load_state_dict()` uses `.item()`, which returns a Python scalar regardless of the source tensor dtype, so existing int32 checkpoints load correctly

Note: Megatron-LM does not have this issue because it pickles the entire `args` namespace, where counters are plain Python `int` (arbitrary precision). Bridge introduced int32 tensors as part of its `Stateful` checkpoint design.

Closes NFS-711
Test plan
- `load_state_dict()` works with both int32 (old checkpoints) and int64 (new checkpoints) via `.item()`

Made with Cursor
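Why the load path is dtype-agnostic can be illustrated without torch. Below is a minimal stand-in (not the Bridge API; `save_counter`/`load_counter` are hypothetical names) using Python's `struct` module: old checkpoints store a counter as a 4-byte int32, new ones as an 8-byte int64, and the load path converts to a plain Python `int` either way — analogous to `tensor.item()` returning a Python scalar regardless of the source dtype.

```python
import struct

def save_counter(value: int, fmt: str = "<q") -> bytes:
    """Pack a counter as little-endian int64 ("<q") or int32 ("<i")."""
    return struct.pack(fmt, value)

def load_counter(blob: bytes) -> int:
    """Infer width from payload size: 4 bytes -> int32, 8 bytes -> int64.
    Returns a plain Python int either way (the .item() analogy)."""
    fmt = "<i" if len(blob) == 4 else "<q"
    return struct.unpack(fmt, blob)[0]

old_blob = save_counter(123_456, "<i")        # pre-fix int32 checkpoint
new_blob = save_counter(2_150_400_000, "<q")  # post-fix int64 checkpoint

print(load_counter(old_blob))  # 123456 — old checkpoints still load
print(load_counter(new_blob))  # 2150400000 — beyond int32 range
```

The design point is that backward compatibility falls out of converting to an arbitrary-precision Python scalar at load time, so no version check is needed when reading old checkpoints.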