[training] fix: use int64 (not uint64) for TrainState counters #3380
Conversation
PR #3312 intended to bump the `TrainState` counters from int32 to int64 to prevent overflow on long training runs, but the change landed as `torch.uint64`. Unlike `int64`, `torch.uint64` has no corresponding `TypedStorage` subclass, so PyTorch's legacy pickle path (triggered from `torch.distributed.broadcast_object_list` during `_validate_common_state_dict` in MCore dist-checkpoint save) falls through to bare `torch.UntypedStorage` and crashes with:

`AttributeError: type object 'torch.storage.UntypedStorage' has no attribute 'dtype'`

That kills one rank mid-collective, causing the observed NCCL watchdog timeout + segfault in the local checkpointing functional tests (`test_local_checkpoint_save_and_resume` and the `most_recent_k` variant).

Switch the counters to `torch.int64`, which still avoids int32 overflow (max ~9.2e18 samples) and goes through the standard `LongStorage` pickle path, unblocking dist-checkpoint save. Update the matching unit-test dtype assertions.

Signed-off-by: Chen Cui <chcui@nvidia.com>
Summary
- Switch the `TrainState` counters (`step`, `consumed_train_samples`, `skipped_train_samples`, `consumed_valid_samples`) from `torch.uint64` back to `torch.int64` to restore dist-checkpoint save on current PyTorch.
- Update the dtype assertions in `tests/unit_tests/training/test_state.py`.
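For context, a minimal sketch of the four fields after the change (field names are from this PR; the class shape and initialization are illustrative, not the actual Megatron code):

```python
import torch

class TrainState:
    """Illustrative subset: the four progress counters this PR touches."""

    def __init__(self):
        # int64 instead of uint64: same overflow headroom in practice,
        # but backed by a TypedStorage subclass (LongStorage) that the
        # legacy pickle path can resolve a dtype for.
        self.step = torch.tensor(0, dtype=torch.int64)
        self.consumed_train_samples = torch.tensor(0, dtype=torch.int64)
        self.skipped_train_samples = torch.tensor(0, dtype=torch.int64)
        self.consumed_valid_samples = torch.tensor(0, dtype=torch.int64)
```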
Root cause

PR #3312 intended to bump these counters from `int32` to `int64` to prevent overflow on long training runs, but the change landed as `torch.uint64`. Unlike `int64`, `uint64` has no corresponding `TypedStorage` subclass in PyTorch. During `save_checkpoint` → `MCoreTensorAwareStateDict.from_state_dict` → `_validate_common_state_dict`, PyTorch invokes `torch.distributed.broadcast_object_list`, which pickles the common state dict on rank 0 and unpickles it on the other ranks via `_tensor_to_object` → `torch.load` → `_legacy_load.persistent_load`. That path executes the storage-dtype lookup sketched below.
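A paraphrase of the relevant `persistent_load` logic in PyTorch's `torch/serialization.py` (simplified here; the exact surrounding code varies across PyTorch versions):

```python
def persistent_load(saved_id):
    typename, data = saved_id[0], saved_id[1:]
    if typename == "storage":
        storage_type = data[0]
        # Typed storages (e.g. torch.LongStorage) expose .dtype as a
        # class attribute; bare torch.UntypedStorage does not, so a
        # uint64-backed tensor raises AttributeError right here.
        dtype = storage_type.dtype
        ...
```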
For normal dtypes, `storage_type` is a `TypedStorage` subclass exposing `.dtype` as a class attribute. For `uint64`, it falls through to bare `torch.UntypedStorage`, which has no class-level `.dtype`, producing:

`AttributeError: type object 'torch.storage.UntypedStorage' has no attribute 'dtype'`

One rank crashes mid-collective → the remaining ranks hang on the next NCCL op → watchdog timeout + segfault in teardown.
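The class-level difference is easy to check in an interpreter (assuming a recent PyTorch; the legacy typed-storage classes are deprecated but still importable):

```python
import torch

print(torch.LongStorage.dtype)    # torch.int64: class-level dtype on the typed subclass
try:
    torch.UntypedStorage.dtype    # no class-level dtype on the untyped class
except AttributeError as e:
    print(e)  # type object 'torch.storage.UntypedStorage' has no attribute 'dtype'
```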
This matches the regressions in `tests/functional_tests/test_groups/training/test_local_checkpointing.py::test_local_checkpoint_save_and_resume` and `::test_local_checkpoint_save_resume_with_most_recent_k`.

Full traceback excerpt from the failing functional test:
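The excerpt ends in the same error shown above:

```
AttributeError: type object 'torch.storage.UntypedStorage' has no attribute 'dtype'
```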
Fix
Switch the four counter dtypes to `torch.int64`. This still satisfies the original intent of PR #3312 (max value ~9.2e18 samples, no realistic overflow risk) and goes through the standard `LongStorage` typed-storage pickle path, so dist-checkpoint save works again.
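The headroom claim is easy to verify (a quick sanity check, not part of the PR's diff):

```python
import torch

# int64 max is 2**63 - 1 ≈ 9.22e18, so even trillions of samples per day
# for decades would not overflow the counters.
print(torch.iinfo(torch.int64).max)  # 9223372036854775807
```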
Test plan

- `uv run pre-commit run` passes on touched files.
- `uv run python -m pytest tests/unit_tests/training/test_state.py -v` passes with updated `int64` assertions.
- `tests/functional_tests/test_groups/training/test_local_checkpointing.py::TestLocalCheckpointing::test_local_checkpoint_save_and_resume`
- `tests/functional_tests/test_groups/training/test_local_checkpointing.py::TestLocalCheckpointing::test_local_checkpoint_save_resume_with_most_recent_k`