Skip to content

Commit d4f7fe3

Browse files
yaoyu-33claude
andcommitted
[training] fix: use uint64 instead of int64 for TrainState counters
Unsigned 64-bit doubles the positive range to ~1.8e19 and better reflects the semantics of non-negative counters (step, consumed/skipped samples). Backward compatible via .item(). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
1 parent 45f7814 commit d4f7fe3

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

src/megatron/bridge/training/state.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@ def state_dict(self) -> dict[str, torch.Tensor]:
7070
their corresponding tensor representations.
7171
"""
7272
return {
73-
"step": torch.tensor(self.step, dtype=torch.int64),
74-
"consumed_train_samples": torch.tensor(self.consumed_train_samples, dtype=torch.int64),
75-
"skipped_train_samples": torch.tensor(self.skipped_train_samples, dtype=torch.int64),
76-
"consumed_valid_samples": torch.tensor(self.consumed_valid_samples, dtype=torch.int64),
73+
"step": torch.tensor(self.step, dtype=torch.uint64),
74+
"consumed_train_samples": torch.tensor(self.consumed_train_samples, dtype=torch.uint64),
75+
"skipped_train_samples": torch.tensor(self.skipped_train_samples, dtype=torch.uint64),
76+
"consumed_valid_samples": torch.tensor(self.consumed_valid_samples, dtype=torch.uint64),
7777
"floating_point_operations_so_far": torch.tensor(
7878
self.floating_point_operations_so_far, dtype=torch.float64
7979
),

tests/unit_tests/training/test_state.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ def test_state_dict_structure_and_types(self):
8686
}
8787
assert set(state_dict.keys()) == expected_keys
8888

89-
# Check tensor types (int64 to avoid overflow on long training runs)
90-
assert state_dict["step"].dtype == torch.int64
91-
assert state_dict["consumed_train_samples"].dtype == torch.int64
92-
assert state_dict["skipped_train_samples"].dtype == torch.int64
93-
assert state_dict["consumed_valid_samples"].dtype == torch.int64
89+
# Check tensor types (uint64 to avoid overflow on long training runs)
90+
assert state_dict["step"].dtype == torch.uint64
91+
assert state_dict["consumed_train_samples"].dtype == torch.uint64
92+
assert state_dict["skipped_train_samples"].dtype == torch.uint64
93+
assert state_dict["consumed_valid_samples"].dtype == torch.uint64
9494
assert state_dict["floating_point_operations_so_far"].dtype == torch.float64
9595
assert state_dict["do_train"].dtype == torch.bool
9696
assert state_dict["do_valid"].dtype == torch.bool

0 commit comments

Comments
 (0)