diff --git a/src/megatron/bridge/training/state.py b/src/megatron/bridge/training/state.py index 14113e4d7b..ea839a4ec8 100644 --- a/src/megatron/bridge/training/state.py +++ b/src/megatron/bridge/training/state.py @@ -70,10 +70,10 @@ def state_dict(self) -> dict[str, torch.Tensor]: their corresponding tensor representations. """ return { - "step": torch.tensor(self.step, dtype=torch.int32), - "consumed_train_samples": torch.tensor(self.consumed_train_samples, dtype=torch.int32), - "skipped_train_samples": torch.tensor(self.skipped_train_samples, dtype=torch.int32), - "consumed_valid_samples": torch.tensor(self.consumed_valid_samples, dtype=torch.int32), + "step": torch.tensor(self.step, dtype=torch.uint64), + "consumed_train_samples": torch.tensor(self.consumed_train_samples, dtype=torch.uint64), + "skipped_train_samples": torch.tensor(self.skipped_train_samples, dtype=torch.uint64), + "consumed_valid_samples": torch.tensor(self.consumed_valid_samples, dtype=torch.uint64), "floating_point_operations_so_far": torch.tensor( self.floating_point_operations_so_far, dtype=torch.float64 ), diff --git a/tests/unit_tests/training/test_state.py b/tests/unit_tests/training/test_state.py index fc4dcfd21e..1583b52d30 100644 --- a/tests/unit_tests/training/test_state.py +++ b/tests/unit_tests/training/test_state.py @@ -86,11 +86,11 @@ def test_state_dict_structure_and_types(self): } assert set(state_dict.keys()) == expected_keys - # Check tensor types - assert state_dict["step"].dtype == torch.int32 - assert state_dict["consumed_train_samples"].dtype == torch.int32 - assert state_dict["skipped_train_samples"].dtype == torch.int32 - assert state_dict["consumed_valid_samples"].dtype == torch.int32 + # Check tensor types (uint64 to avoid overflow on long training runs) + assert state_dict["step"].dtype == torch.uint64 + assert state_dict["consumed_train_samples"].dtype == torch.uint64 + assert state_dict["skipped_train_samples"].dtype == torch.uint64 + assert state_dict["consumed_valid_samples"].dtype == torch.uint64 assert state_dict["floating_point_operations_so_far"].dtype == torch.float64 assert state_dict["do_train"].dtype == torch.bool assert state_dict["do_valid"].dtype == torch.bool