File tree Expand file tree Collapse file tree
src/megatron/bridge/training
tests/unit_tests/training
@@ -73,10 +73,10 @@ def state_dict(self) -> dict[str, torch.Tensor]:
7373 their corresponding tensor representations.
7474 """
7575 return {
76- "step" : torch .tensor (self .step , dtype = torch .int32 ),
77- "consumed_train_samples" : torch .tensor (self .consumed_train_samples , dtype = torch .int32 ),
78- "skipped_train_samples" : torch .tensor (self .skipped_train_samples , dtype = torch .int32 ),
79- "consumed_valid_samples" : torch .tensor (self .consumed_valid_samples , dtype = torch .int32 ),
76+ "step" : torch .tensor (self .step , dtype = torch .uint64 ),
77+ "consumed_train_samples" : torch .tensor (self .consumed_train_samples , dtype = torch .uint64 ),
78+ "skipped_train_samples" : torch .tensor (self .skipped_train_samples , dtype = torch .uint64 ),
79+ "consumed_valid_samples" : torch .tensor (self .consumed_valid_samples , dtype = torch .uint64 ),
8080 "floating_point_operations_so_far" : torch .tensor (
8181 self .floating_point_operations_so_far , dtype = torch .float64
8282 ),
@@ -86,11 +86,11 @@ def test_state_dict_structure_and_types(self):
8686 }
8787 assert set (state_dict .keys ()) == expected_keys
8888
89- # Check tensor types
90- assert state_dict ["step" ].dtype == torch .int32
91- assert state_dict ["consumed_train_samples" ].dtype == torch .int32
92- assert state_dict ["skipped_train_samples" ].dtype == torch .int32
93- assert state_dict ["consumed_valid_samples" ].dtype == torch .int32
89+ # Check tensor types (uint64 to avoid overflow on long training runs)
90+ assert state_dict ["step" ].dtype == torch .uint64
91+ assert state_dict ["consumed_train_samples" ].dtype == torch .uint64
92+ assert state_dict ["skipped_train_samples" ].dtype == torch .uint64
93+ assert state_dict ["consumed_valid_samples" ].dtype == torch .uint64
9494 assert state_dict ["floating_point_operations_so_far" ].dtype == torch .float64
9595 assert state_dict ["do_train" ].dtype == torch .bool
9696 assert state_dict ["do_valid" ].dtype == torch .bool
You can’t perform that action at this time.
0 commit comments