99#include " glog/logging.h"
1010
1111#include " infini_train/include/autocast.h"
12- #include " infini_train/include/checkpoint.h"
12+ #include " infini_train/include/checkpoint/checkpoint.h"
13+ #include " infini_train/include/checkpoint/checkpoint_manager.h"
1314#include " infini_train/include/core/runtime/device_guard.h"
1415#include " infini_train/include/dataloader.h"
1516#include " infini_train/include/device.h"
3536#include " infini_train/include/profiler.h"
3637#endif
3738
38- #include " example/common/checkpoint_loader.h"
3939#include " example/common/tiny_shakespeare_dataset.h"
4040#include " example/common/tokenizer.h"
4141#include " example/llama3/checkpoint_loader.h"
@@ -83,7 +83,7 @@ DEFINE_uint32(save_interval, 0, "save checkpoint every N steps; 0 disables savin
8383DEFINE_string (load, " " , " checkpoint directory to resume from" );
8484DEFINE_string (save, " ./checkpoints" , " root directory used to store checkpoints" );
8585DEFINE_uint32 (max_checkpoint_keep, 3 , " max number of checkpoint steps to keep" );
86- DEFINE_bool (no_save_optim, false , " whether optimizer state is persisted in checkpoints" );
86+ DEFINE_bool (save_optimizer_state, true , " whether optimizer state is persisted in checkpoints" );
8787
8888// precision check
8989DEFINE_string (
@@ -305,14 +305,13 @@ void Train(const nn::parallel::Rank &rank) {
305305
306306 int start_step = 0 ;
307307 TrainerState state;
308- const auto resume_result = ResumeFromCheckpoint ({
309- .resume_root = FLAGS_load,
310- .rank = rank,
311- .model = model,
312- .optimizer = optimizer,
313- .model_config = model_config,
314- .state = state,
315- });
308+ const auto resume_result = ResumeFromCheckpoint ({.resume_root = FLAGS_load,
309+ .rank = rank,
310+ .model = model,
311+ .optimizer = optimizer,
312+ .model_config = model_config,
313+ .state = state,
314+ .load_optimizer_state = true });
316315
317316 start_step = resume_result.global_step ;
318317 size_t consumed_batches = resume_result.consumed_batches ;
@@ -327,31 +326,29 @@ void Train(const nn::parallel::Rank &rank) {
327326 for (size_t i = 0 ; i < num_skips; ++i) { ++train_iter; }
328327 }
329328
330- auto save_checkpoint
331- = [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
332- SaveCheckpoint ({
333- .save_dir = save_dir,
334- .global_step = global_step,
335- .consumed_batches = consumed_batches,
336- .last_lr = FLAGS_learning_rate,
337- .n_layer = model_config.n_layer ,
338- .n_head = model_config.n_head ,
339- .n_kv_head = model_config.n_kv_head ,
340- .n_embd = model_config.n_embd ,
341- .vocab_size = model_config.vocab_size ,
342- .ddp_size = ddp_world_size,
343- .tp_size = tp_world_size,
344- .sp_size = sp_world_size,
345- .pp_size = pp_world_size,
346- .no_save_optim = FLAGS_no_save_optim,
347- .prune_step_checkpoints = prune_step_checkpoints,
348- .checkpoint_root_dir = FLAGS_save,
349- .max_checkpoint_keep = FLAGS_max_checkpoint_keep,
350- .rank = rank,
351- .model = *model,
352- .optimizer = *optimizer,
353- });
354- };
329+ auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step) {
330+ SaveCheckpoint ({
331+ .save_dir = save_dir,
332+ .global_step = global_step,
333+ .consumed_batches = consumed_batches,
334+ .last_lr = FLAGS_learning_rate,
335+ .n_layer = model_config.n_layer ,
336+ .n_head = model_config.n_head ,
337+ .n_kv_head = model_config.n_kv_head ,
338+ .n_embd = model_config.n_embd ,
339+ .vocab_size = model_config.vocab_size ,
340+ .ddp_size = ddp_world_size,
341+ .tp_size = tp_world_size,
342+ .sp_size = sp_world_size,
343+ .pp_size = pp_world_size,
344+ .save_optimizer_state = FLAGS_save_optimizer_state,
345+ .checkpoint_root_dir = FLAGS_save,
346+ .max_checkpoint_keep = FLAGS_max_checkpoint_keep,
347+ .rank = rank,
348+ .model = *model,
349+ .optimizer = *optimizer,
350+ });
351+ };
355352
356353 for (int step = start_step; step < FLAGS_num_iteration + 1 ; ++step) {
357354 // Reset precision check counters at start of each iteration for file overwrite
@@ -475,7 +472,7 @@ void Train(const nn::parallel::Rank &rank) {
475472 if (rank.IsParallel ()) {
476473 step_dir /= std::format (" rank_{:06d}" , rank.GlobalRank ());
477474 }
478- save_checkpoint (step_dir, step + 1 , true );
475+ save_checkpoint (step_dir, step + 1 );
479476 }
480477 }
481478
@@ -489,7 +486,7 @@ void Train(const nn::parallel::Rank &rank) {
489486 if (rank.IsParallel ()) {
490487 final_dir /= std::format (" rank_{:06d}" , rank.GlobalRank ());
491488 }
492- save_checkpoint (final_dir, FLAGS_num_iteration, false );
489+ save_checkpoint (final_dir, FLAGS_num_iteration);
493490
494491#ifdef PROFILE_MODE
495492 Profiler::Instance ().Report (" llama3.report" , Profiler::SortBy::DeviceTimePercentage);
0 commit comments