@@ -21,6 +21,11 @@ TEST_P(TrainerStateTest, DefaultValues) {
2121 TrainerState state;
2222 EXPECT_EQ (state.global_step , 0 );
2323 EXPECT_EQ (state.consumed_batches , 0 );
24+ EXPECT_EQ (state.n_layer , 0 );
25+ EXPECT_EQ (state.n_head , 0 );
26+ EXPECT_EQ (state.n_kv_head , 0 );
27+ EXPECT_EQ (state.n_embd , 0 );
28+ EXPECT_EQ (state.vocab_size , 0 );
2429 EXPECT_EQ (state.ddp_size , 1 );
2530 EXPECT_EQ (state.tp_size , 1 );
2631 EXPECT_EQ (state.sp_size , 1 );
@@ -48,7 +53,6 @@ TEST_P(TrainerStateTest, TrainerStateFileCreated) {
4853 std::string content ((std::istreambuf_iterator<char >(ifs)), std::istreambuf_iterator<char >());
4954 EXPECT_NE (content.find (" \" global_step\" " ), std::string::npos);
5055 EXPECT_NE (content.find (" \" consumed_batches \" " ), std::string::npos);
51- EXPECT_NE (content.find (" \" Adam\" " ), std::string::npos);
5256
5357 std::filesystem::remove_all (dir);
5458}
@@ -61,6 +65,11 @@ TEST_P(TrainerStateTest, RoundTrip) {
6165 .global_step = 99 ,
6266 .consumed_batches = 5000 ,
6367 .last_lr = 3e-4 ,
68+ .n_layer = 24 ,
69+ .n_head = 16 ,
70+ .n_kv_head = 8 ,
71+ .n_embd = 1024 ,
72+ .vocab_size = 128256 ,
6473 .ddp_size = 2 ,
6574 .tp_size = 1 ,
6675 .sp_size = 1 ,
@@ -87,6 +96,11 @@ TEST_P(TrainerStateTest, RoundTrip) {
8796 EXPECT_EQ (loaded.global_step , 99 );
8897 EXPECT_EQ (loaded.consumed_batches , 5000 );
8998 EXPECT_NEAR (loaded.last_lr , 3e-4 , 1e-10 );
99+ EXPECT_EQ (loaded.n_layer , 24 );
100+ EXPECT_EQ (loaded.n_head , 16 );
101+ EXPECT_EQ (loaded.n_kv_head , 8 );
102+ EXPECT_EQ (loaded.n_embd , 1024 );
103+ EXPECT_EQ (loaded.vocab_size , 128256 );
90104 EXPECT_EQ (loaded.ddp_size , 2 );
91105 EXPECT_EQ (loaded.pp_size , 2 );
92106
0 commit comments