[training] fix: use int64 (not uint64) for TrainState counters#3380

Merged
cuichenx merged 1 commit into main from chcui/fix-trainstate-uint64-to-int64 on Apr 17, 2026

Conversation


@cuichenx cuichenx commented Apr 17, 2026

Summary

  • Switch TrainState counters (step, consumed_train_samples, skipped_train_samples, consumed_valid_samples) from torch.uint64 back to torch.int64 to restore dist-checkpoint save on current PyTorch.
  • Update the matching dtype assertions in tests/unit_tests/training/test_state.py.

Root cause

PR #3312 intended to bump these counters from int32 to int64 to prevent overflow on long training runs, but the change landed as torch.uint64. Unlike int64, uint64 has no corresponding TypedStorage subclass in PyTorch. During save_checkpoint → MCoreTensorAwareStateDict.from_state_dict → _validate_common_state_dict, PyTorch invokes torch.distributed.broadcast_object_list, which pickles the common state dict on rank 0 and unpickles it on the other ranks via _tensor_to_object → torch.load → _legacy_load → persistent_load. That path executes:

dtype = storage_type.dtype

For normal dtypes, storage_type is a TypedStorage subclass exposing .dtype as a class attribute. For uint64, it falls through to bare torch.UntypedStorage, which has no class-level .dtype, producing:

AttributeError: type object 'torch.storage.UntypedStorage' has no attribute 'dtype'

One rank crashes mid-collective → remaining ranks hang on the next NCCL op → watchdog timeout + segfault in teardown. This matches the regressions in tests/functional_tests/test_groups/training/test_local_checkpointing.py::test_local_checkpoint_save_and_resume and ::test_local_checkpoint_save_resume_with_most_recent_k.
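The broadcast step round-trips the state dict through pickle's persistent-ID hooks, which serialize a small tag on the sending rank and resolve it back to a type on the receiving rank. A stdlib-only sketch of that mechanism (all class and registry names here are stand-ins for illustration, not real PyTorch types):

```python
import io
import pickle

# Hypothetical registry standing in for torch's typed-storage lookup.
REGISTRY = {"int64": "LongStorageStandIn"}

class DtypeMarker:
    def __init__(self, name):
        self.name = name

class Saver(pickle.Pickler):
    def persistent_id(self, obj):
        # Tag dtype markers with an out-of-band ID instead of pickling them.
        if isinstance(obj, DtypeMarker):
            return ("storage", obj.name)
        return None

class Loader(pickle.Unpickler):
    def persistent_load(self, pid):
        kind, name = pid
        # Mirrors the receiving side: resolve the tag back to a type.
        # A missing entry here plays the role of the AttributeError for uint64.
        return REGISTRY[name]

buf = io.BytesIO()
Saver(buf).dump({"step": DtypeMarker("int64")})
buf.seek(0)
restored = Loader(buf).load()
print(restored)  # {'step': 'LongStorageStandIn'}

buf2 = io.BytesIO()
Saver(buf2).dump({"step": DtypeMarker("uint64")})
buf2.seek(0)
try:
    Loader(buf2).load()
except KeyError as exc:
    print("lookup failed on the receiving rank:", exc)
```

The key point is that the failure happens only on the unpickling ranks, after the sending rank has already entered the collective, which is why the other ranks end up hanging.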

Full traceback excerpt from the failing functional test:

src/megatron/bridge/training/checkpointing.py:1021: in save_checkpoint
    state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
3rdparty/Megatron-LM/megatron/core/dist_checkpointing/validation.py:338: in _validate_common_state_dict
    torch.distributed.broadcast_object_list(object_list, src=0)
...
/usr/local/lib/python3.12/dist-packages/torch/serialization.py:1785: in persistent_load
    dtype = storage_type.dtype
E   AttributeError: type object 'torch.storage.UntypedStorage' has no attribute 'dtype'
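The failing line is a plain class-attribute access. Minimal stand-in classes (not the real PyTorch types) reproduce the shape of the error:

```python
class UntypedStorageStandIn:
    # Like torch.UntypedStorage: no class-level `dtype` attribute.
    pass

class LongStorageStandIn(UntypedStorageStandIn):
    # Like torch.LongStorage: exposes `dtype` on the class itself.
    dtype = "int64"

def load_dtype(storage_type):
    # Mirrors `dtype = storage_type.dtype` in torch/serialization.py.
    return storage_type.dtype

print(load_dtype(LongStorageStandIn))  # int64
try:
    load_dtype(UntypedStorageStandIn)  # uint64 falls through to this class
except AttributeError as exc:
    print(exc)
```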

Fix

Switch the four counter dtypes to torch.int64. This still satisfies the original intent of PR #3312 (max value ~9.2e18 samples, no realistic overflow risk) and goes through the standard LongStorage typed-storage pickle path, so dist-checkpoint save works again.
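A quick headroom check confirms signed int64 is enough (the batch size and step count below are made-up illustration figures, not values from this repo):

```python
# Signed 64-bit maximum, the cap on the int64 counters.
INT64_MAX = 2**63 - 1  # 9_223_372_036_854_775_807, about 9.2e18

# Assumed workload figures purely for illustration:
samples_per_step = 4096   # hypothetical global batch size
steps = 10_000_000        # a very long training run

consumed = samples_per_step * steps
print(consumed)           # 40960000000
assert consumed < INT64_MAX  # more than eight orders of magnitude of headroom
```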

Test plan

  • uv run pre-commit run passes on touched files.
  • uv run python -m pytest tests/unit_tests/training/test_state.py -v passes with updated int64 assertions.
  • Local-checkpoint functional tests pass on CI:
    • tests/functional_tests/test_groups/training/test_local_checkpointing.py::TestLocalCheckpointing::test_local_checkpoint_save_and_resume
    • tests/functional_tests/test_groups/training/test_local_checkpointing.py::TestLocalCheckpointing::test_local_checkpoint_save_resume_with_most_recent_k

Summary by CodeRabbit

  • Bug Fixes
    • Fixed training state serialization compatibility with distributed checkpoint operations.

PR #3312 intended to bump the TrainState counters from int32 to int64 to
prevent overflow on long training runs, but the change landed as
torch.uint64. Unlike int64, torch.uint64 has no corresponding TypedStorage
subclass, so PyTorch's legacy pickle path (triggered from
torch.distributed.broadcast_object_list during
_validate_common_state_dict in MCore dist-checkpoint save) falls through
to bare torch.UntypedStorage and crashes with:

  AttributeError: type object 'torch.storage.UntypedStorage' has no
  attribute 'dtype'

That kills one rank mid-collective, causing the observed NCCL
watchdog timeout + segfault in the local checkpointing functional tests
(test_local_checkpoint_save_and_resume and the most_recent_k variant).

Switch the counters to torch.int64, which still avoids int32 overflow
(max ~9.2e18 samples) and goes through the standard LongStorage pickle
path, unblocking dist-checkpoint save. Update the matching unit-test
dtype assertions.

Signed-off-by: Chen Cui <chcui@nvidia.com>

coderabbitai Bot commented Apr 17, 2026

No actionable comments were generated in the recent review. 🎉


📝 Walkthrough

This PR changes the tensor dtype of four progress counter fields (step, consumed_train_samples, skipped_train_samples, consumed_valid_samples) in TrainState.state_dict() from torch.uint64 to torch.int64 to resolve compatibility issues with legacy pickle serialization during distributed checkpoint saving. The corresponding test assertions are updated to reflect these dtype changes.

Changes

  • src/megatron/bridge/training/state.py: changed the serialized tensor dtypes from torch.uint64 to torch.int64 for the four integer progress counters in the state_dict() method.
  • tests/unit_tests/training/test_state.py: updated test assertions to expect torch.int64 instead of torch.uint64 for the same four counter fields; adjusted the comment explaining the incompatibility with the legacy pickle path.


@cuichenx cuichenx enabled auto-merge (squash) April 17, 2026 21:34
@cuichenx cuichenx merged commit 299b615 into main Apr 17, 2026
51 checks passed
@cuichenx cuichenx deleted the chcui/fix-trainstate-uint64-to-int64 branch April 17, 2026 21:45