 #include <unordered_map>
 #include <unordered_set>

+#include "example/common/utils.h"
 #include "gflags/gflags.h"
 #include "glog/logging.h"

@@ -85,8 +86,6 @@ DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store che |
 DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
 DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
 DEFINE_string(checkpoint_format, "bin", "checkpoint format: bin|pth");
-DEFINE_bool(use_llmc_checkpoint_io, false,
-            "whether to use GPT2 LLMC model.bin callback for checkpoint save/load when format=bin");
 // precision check
 DEFINE_string(
     precision_check, "",
@@ -337,44 +336,15 @@ void Train(const nn::parallel::Rank &rank) { |
 
     int start_step = 0;
     float best_loss = std::numeric_limits<float>::infinity();
-    if (!FLAGS_resume_from.empty()) {
-        std::filesystem::path resume_dir = FLAGS_resume_from;
-        if (rank.IsParallel()) {
-            const auto rank_dir = resume_dir / std::format("rank_{:06d}", rank.GlobalRank());
-            if (std::filesystem::exists(rank_dir)) {
-                resume_dir = rank_dir;
-            }
-        }
-
-        TrainerState state;
-        CheckpointLoadOptions load_options;
-        load_options.load_optimizer_state = true;
-        if (FLAGS_use_llmc_checkpoint_io) {
-            load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
-                auto loaded_model = GPT2::FromLLMC(model_path.string());
-                target_model->LoadStateDict(loaded_model->StateDict());
-            };
-        }
-        Checkpoint::Load(resume_dir, model.get(), optimizer.get(), &state, load_options);
-        start_step = static_cast<int>(state.global_step);
-        best_loss = state.best_loss;
-        if (state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && rank.IsMainRank()) {
-            LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
-                                        "Proceeding with recorded data_batch_idx {}.",
-                                        state.data_batch_stride, ddp_world_size, state.data_batch_idx);
-        }
-        saved_data_batch_idx = static_cast<size_t>(std::max<int64_t>(state.data_batch_idx, 0));
-        train_iter = train_loader.IteratorAtBatchIndex(saved_data_batch_idx);
-        if (rank.IsMainRank()) {
-            LOG(INFO) << std::format(
-                "Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}",
-                state.global_step, state.best_loss, state.last_lr, state.data_batch_idx);
-            LOG(INFO) << std::format("Checkpoint model I/O mode during resume: {}",
-                                     FLAGS_use_llmc_checkpoint_io ? "llmc-callback" : "native-state-dict");
-        }
-    }
-
-    LOG(INFO) << "start training";
+    TrainerState state;
+    CheckpointLoadOptions load_options;
+    load_options.load_optimizer_state = true;
+    load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
+        auto loaded_model = GPT2::FromLLMC(model_path.string());
+        target_model->LoadStateDict(loaded_model->StateDict());
+    };
+    std::tie(start_step, best_loss, saved_data_batch_idx) = infini_train::ResumeFromCheckpoint(
+        FLAGS_resume_from, rank, model, optimizer, train_loader, state, train_iter, load_options);
 
     auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step,
                                bool prune_step_checkpoints) {
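
Note on the new helper: infini_train::ResumeFromCheckpoint presumably bundles the resume logic that used to live inline in the removed lines. The following is only a sketch reconstructed from that removed code to show the intended shape; the exact signature, namespace placement, and the model/optimizer/loader/iterator parameter types are assumptions, not the library's actual declaration, and it relies on headers this file already pulls in (<algorithm>, <filesystem>, <format>, <limits>, <tuple>).

// Sketch only (not the actual library code): a helper of this shape would
// reproduce the removed inline logic and match the std::tie call site above.
template <typename ModelT, typename OptimizerT, typename LoaderT, typename IterT>
std::tuple<int, float, size_t> ResumeFromCheckpoint(const std::string &resume_from, const nn::parallel::Rank &rank,
                                                    ModelT &model, OptimizerT &optimizer, LoaderT &train_loader,
                                                    TrainerState &state, IterT &train_iter,
                                                    const CheckpointLoadOptions &load_options) {
    int start_step = 0;
    float best_loss = std::numeric_limits<float>::infinity();
    size_t saved_data_batch_idx = 0;
    if (resume_from.empty()) {
        // Nothing to resume from: keep the freshly initialized model/optimizer.
        return {start_step, best_loss, saved_data_batch_idx};
    }
    // In parallel runs each rank saves into its own rank_XXXXXX subdirectory.
    std::filesystem::path resume_dir = resume_from;
    if (rank.IsParallel()) {
        const auto rank_dir = resume_dir / std::format("rank_{:06d}", rank.GlobalRank());
        if (std::filesystem::exists(rank_dir)) {
            resume_dir = rank_dir;
        }
    }
    Checkpoint::Load(resume_dir, model.get(), optimizer.get(), &state, load_options);
    start_step = static_cast<int>(state.global_step);
    best_loss = state.best_loss;
    // Fast-forward the data loader to the batch recorded in the checkpoint.
    saved_data_batch_idx = static_cast<size_t>(std::max<int64_t>(state.data_batch_idx, 0));
    train_iter = train_loader.IteratorAtBatchIndex(saved_data_batch_idx);
    if (rank.IsMainRank()) {
        LOG(INFO) << std::format("Resume training from step {} with best_loss {:.6f}, data_batch_idx {}",
                                 state.global_step, state.best_loss, state.data_batch_idx);
    }
    return {start_step, best_loss, saved_data_batch_idx};
}

Keeping an early return for an empty resume_from would preserve the previous behavior, where training simply starts from step 0 when no checkpoint directory is given.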