99#include < unordered_map>
1010#include < unordered_set>
1111
12+ #include " example/common/utils.h"
1213#include " gflags/gflags.h"
1314#include " glog/logging.h"
1415
@@ -85,8 +86,6 @@ DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store che
8586DEFINE_uint32 (max_checkpoint_keep, 3 , " max number of checkpoint steps to keep" );
8687DEFINE_bool (save_optimizer_state, true , " whether optimizer state is persisted in checkpoints" );
8788DEFINE_string (checkpoint_format, " bin" , " checkpoint format: bin|pth" );
88- DEFINE_bool (use_llmc_checkpoint_io, false ,
89- " whether to use GPT2 LLMC model.bin callback for checkpoint save/load when format=bin" );
9089// precision check
9190DEFINE_string (
9291 precision_check, " " ,
@@ -337,44 +336,15 @@ void Train(const nn::parallel::Rank &rank) {
337336
338337 int start_step = 0 ;
339338 float best_loss = std::numeric_limits<float >::infinity ();
340- if (!FLAGS_resume_from.empty ()) {
341- std::filesystem::path resume_dir = FLAGS_resume_from;
342- if (rank.IsParallel ()) {
343- const auto rank_dir = resume_dir / std::format (" rank_{:06d}" , rank.GlobalRank ());
344- if (std::filesystem::exists (rank_dir)) {
345- resume_dir = rank_dir;
346- }
347- }
348-
349- TrainerState state;
350- CheckpointLoadOptions load_options;
351- load_options.load_optimizer_state = true ;
352- if (FLAGS_use_llmc_checkpoint_io) {
353- load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
354- auto loaded_model = GPT2::FromLLMC (model_path.string ());
355- target_model->LoadStateDict (loaded_model->StateDict ());
356- };
357- }
358- Checkpoint::Load (resume_dir, model.get (), optimizer.get (), &state, load_options);
359- start_step = static_cast <int >(state.global_step );
360- best_loss = state.best_loss ;
361- if (state.data_batch_stride != static_cast <int64_t >(ddp_world_size) && rank.IsMainRank ()) {
362- LOG (WARNING) << std::format (" Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
363- " Proceeding with recorded data_batch_idx {}." ,
364- state.data_batch_stride , ddp_world_size, state.data_batch_idx );
365- }
366- saved_data_batch_idx = static_cast <size_t >(std::max<int64_t >(state.data_batch_idx , 0 ));
367- train_iter = train_loader.IteratorAtBatchIndex (saved_data_batch_idx);
368- if (rank.IsMainRank ()) {
369- LOG (INFO) << std::format (
370- " Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}" ,
371- state.global_step , state.best_loss , state.last_lr , state.data_batch_idx );
372- LOG (INFO) << std::format (" Checkpoint model I/O mode during resume: {}" ,
373- FLAGS_use_llmc_checkpoint_io ? " llmc-callback" : " native-state-dict" );
374- }
375- }
376-
377- LOG (INFO) << " start training" ;
339+ TrainerState state;
340+ CheckpointLoadOptions load_options;
341+ load_options.load_optimizer_state = true ;
342+ load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
343+ auto loaded_model = GPT2::FromLLMC (model_path.string ());
344+ target_model->LoadStateDict (loaded_model->StateDict ());
345+ };
346+ std::tie (start_step, best_loss, saved_data_batch_idx) = infini_train::ResumeFromCheckpoint (
347+ FLAGS_resume_from, rank, model, optimizer, train_loader, state, train_iter, load_options);
378348
379349 auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step,
380350 bool prune_step_checkpoints) {
@@ -396,11 +366,9 @@ void Train(const nn::parallel::Rank &rank) {
396366 CheckpointOptions options;
397367 options.format = FLAGS_checkpoint_format;
398368 options.save_optimizer_state = FLAGS_save_optimizer_state;
399- if (FLAGS_use_llmc_checkpoint_io) {
400- options.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) {
401- llmc_model->SaveAsLLMC (model_path.string ());
402- };
403- }
369+ options.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) {
370+ llmc_model->SaveAsLLMC (model_path.string ());
371+ };
404372 Checkpoint::Save (save_dir, *model, *optimizer, state, options);
405373
406374 const auto ckpt_end = std::chrono::high_resolution_clock::now ();
0 commit comments