Skip to content

Commit bcf9c6e

Browse files
yaoyu-33claude
andauthored
[training] fix: use int64 for TrainState counters to prevent overflow (#3312)
Signed-off-by: Yu Yao <yaoyu.094@gmail.com> Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 234a770 commit bcf9c6e

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

src/megatron/bridge/training/state.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ def state_dict(self) -> dict[str, torch.Tensor]:
7373
their corresponding tensor representations.
7474
"""
7575
return {
76-
"step": torch.tensor(self.step, dtype=torch.int32),
77-
"consumed_train_samples": torch.tensor(self.consumed_train_samples, dtype=torch.int32),
78-
"skipped_train_samples": torch.tensor(self.skipped_train_samples, dtype=torch.int32),
79-
"consumed_valid_samples": torch.tensor(self.consumed_valid_samples, dtype=torch.int32),
76+
"step": torch.tensor(self.step, dtype=torch.uint64),
77+
"consumed_train_samples": torch.tensor(self.consumed_train_samples, dtype=torch.uint64),
78+
"skipped_train_samples": torch.tensor(self.skipped_train_samples, dtype=torch.uint64),
79+
"consumed_valid_samples": torch.tensor(self.consumed_valid_samples, dtype=torch.uint64),
8080
"floating_point_operations_so_far": torch.tensor(
8181
self.floating_point_operations_so_far, dtype=torch.float64
8282
),

tests/unit_tests/training/test_state.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ def test_state_dict_structure_and_types(self):
8686
}
8787
assert set(state_dict.keys()) == expected_keys
8888

89-
# Check tensor types
90-
assert state_dict["step"].dtype == torch.int32
91-
assert state_dict["consumed_train_samples"].dtype == torch.int32
92-
assert state_dict["skipped_train_samples"].dtype == torch.int32
93-
assert state_dict["consumed_valid_samples"].dtype == torch.int32
89+
# Check tensor types (uint64 to avoid overflow on long training runs)
90+
assert state_dict["step"].dtype == torch.uint64
91+
assert state_dict["consumed_train_samples"].dtype == torch.uint64
92+
assert state_dict["skipped_train_samples"].dtype == torch.uint64
93+
assert state_dict["consumed_valid_samples"].dtype == torch.uint64
9494
assert state_dict["floating_point_operations_so_far"].dtype == torch.float64
9595
assert state_dict["do_train"].dtype == torch.bool
9696
assert state_dict["do_valid"].dtype == torch.bool

0 commit comments

Comments
 (0)