Skip to content

[training] fix: use int64 for TrainState counters to prevent overflow#3312

Merged
yaoyu-33 merged 3 commits intomainfrom
yuya/fix-train-state-int32-overflow
Apr 16, 2026
Merged

[training] fix: use int64 for TrainState counters to prevent overflow#3312
yaoyu-33 merged 3 commits intomainfrom
yuya/fix-train-state-int32-overflow

Conversation

@yaoyu-33
Copy link
Copy Markdown
Contributor

@yaoyu-33 yaoyu-33 commented Apr 14, 2026

Summary

  • TrainState.state_dict() serialized step, consumed_train_samples, skipped_train_samples, and consumed_valid_samples as torch.int32 tensors, which overflow at ~2.1B
  • Long training runs with large global batch sizes exceed this limit (e.g. 700K iterations × GBS 3072 = 2.15B consumed samples), causing RuntimeError: value cannot be converted to type int32 without overflow at checkpoint save time
  • Fix: switch to torch.int64 (max ~9.2×10¹⁸)
  • Backward compatible — load_state_dict() uses .item() which returns a Python scalar regardless of the source tensor dtype, so existing int32 checkpoints load correctly

Note: Megatron-LM does not have this issue because it pickles the entire args namespace where counters are plain Python int (arbitrary precision). Bridge introduced int32 tensors as part of its Stateful checkpoint design.

Closes NFS-711

Test plan

  • Verify load_state_dict() works with both int32 (old checkpoints) and int64 (new checkpoints) via .item()
  • CI unit tests pass

Made with Cursor

Summary by CodeRabbit

  • Changes
    • Modified training state checkpoint serialization format for progress counters.

TrainState.state_dict() serialized step and sample counters as
torch.int32 tensors, which overflow at ~2.1B. Long training runs with
large batch sizes (e.g. 700K iters × GBS 3072 = 2.15B samples) hit
RuntimeError at checkpoint save time.

Switch to torch.int64, which supports up to ~9.2e18. Backward
compatible: load_state_dict() uses .item() which returns a Python
scalar regardless of the source tensor dtype.

Signed-off-by: Yu Yao <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Made-with: Cursor
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yaoyu-33
Copy link
Copy Markdown
Contributor Author

/ok to test a74a3ed

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 14, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: ba838a5c-93ca-4748-9aae-9d11e3117422

📥 Commits

Reviewing files that changed from the base of the PR and between ad27e2c and a74a3ed.

📒 Files selected for processing (1)
  • src/megatron/bridge/training/state.py

📝 Walkthrough

Walkthrough

The TrainState.state_dict() method now serializes four integer progress counters (step, consumed_train_samples, skipped_train_samples, consumed_valid_samples) as torch.int64 tensors instead of torch.int32, maintaining consistency across the state serialization.

Changes

Cohort / File(s) Summary
Integer Tensor Dtype Update
src/megatron/bridge/training/state.py
Changed serialization of four progress counter tensors from torch.int32 to torch.int64 in the state_dict() method: step, consumed_train_samples, skipped_train_samples, and consumed_valid_samples.

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~3 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main change: converting TrainState counters from int32 to int64 to prevent overflow, which is the core modification in this PR.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Test Results For Major Changes ✅ Passed Minor bug fix preventing integer overflow in serialization by changing tensor dtype from int32 to int64 with maintained backward compatibility and verified through existing unit tests.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch yuya/fix-train-state-int32-overflow

Comment @coderabbitai help to get the list of available commands and usage tips.

The TrainState counters were changed from int32 to int64 to prevent
overflow on long training runs, but the unit test still asserted int32.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Copy link
Copy Markdown
Contributor

@maanug-nv maanug-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about uint32 or uint64? these values shouldn't ever be negative i think

@yaoyu-33 yaoyu-33 added bug Something isn't working area:training Training loop, callbacks, and runtime integration needs-review PR is ready for code review and waiting on a reviewer labels Apr 16, 2026
cuichenx
cuichenx previously approved these changes Apr 16, 2026
Unsigned 64-bit doubles the positive range to ~1.8e19 and better
reflects the semantics of non-negative counters (step, consumed/skipped
samples). Backward compatible via .item().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33
Copy link
Copy Markdown
Contributor Author

/ok to test d4f7fe3

@yaoyu-33 yaoyu-33 merged commit bcf9c6e into main Apr 16, 2026
53 of 55 checks passed
@yaoyu-33 yaoyu-33 deleted the yuya/fix-train-state-int32-overflow branch April 16, 2026 23:41
cuichenx added a commit that referenced this pull request Apr 17, 2026
PR #3312 intended to bump the TrainState counters from int32 to int64 to
prevent overflow on long training runs, but the change landed as
torch.uint64. Unlike int64, torch.uint64 has no corresponding TypedStorage
subclass, so PyTorch's legacy pickle path (triggered from
torch.distributed.broadcast_object_list during
_validate_common_state_dict in MCore dist-checkpoint save) falls through
to bare torch.UntypedStorage and crashes with:

  AttributeError: type object 'torch.storage.UntypedStorage' has no
  attribute 'dtype'

That kills one rank mid-collective, causing the observed NCCL
watchdog timeout + segfault in the local checkpointing functional tests
(test_local_checkpoint_save_and_resume and the most_recent_k variant).

Switch the counters to torch.int64, which still avoids int32 overflow
(max ~9.2e18 samples) and goes through the standard LongStorage pickle
path, unblocking dist-checkpoint save. Update the matching unit-test
dtype assertions.

Signed-off-by: Chen Cui <chcui@nvidia.com>
@pruprakash
Copy link
Copy Markdown

QA RCCA Analysis

1. Fix Reference

2. Root Cause

TrainState.state_dict() serialized counters (step, consumed_train_samples, etc.) as torch.int32 tensors, which overflow at ~2.1B. Long training runs with large global batch sizes exceed this limit (e.g. 700K iterations × GBS 3072 = 2.15B samples), causing RuntimeError: value cannot be converted to type int32 without overflow.

3. Trigger Configuration

  • Long training runs with large global batch sizes
  • Checkpoint save after >2.1B consumed samples

4. Nature of the Bug

Classification: CODE BUG - Integer overflow in checkpoint serialization

5. Existing Test Coverage

In Fix PR: YES - 1 test file:

  • tests/unit_tests/training/test_state.py

In NMFW Tests: Not specifically covering large sample counts

6. Coverage Assessment

Test Type Exists Covers Bug
Fix PR unit tests YES YES
NMFW regression tests NO N/A

7. New Regression Test

NOT NEEDED - Fix PR includes unit tests for TrainState serialization.

8. Conclusion

Verdict: ADEQUATE COVERAGE - Fix PR includes unit tests for int64 counter serialization.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:training Training loop, callbacks, and runtime integration bug Something isn't working needs-review PR is ready for code review and waiting on a reviewer qa_rcca_done

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants