|
1 | 1 | #include "example/common/utils.h" |
2 | 2 |
|
| 3 | +#include <algorithm> |
| 4 | +#include <chrono> |
| 5 | + |
3 | 6 | #include "gflags/gflags.h" |
4 | 7 | #include "gflags/gflags_declare.h" |
5 | 8 | #include "glog/logging.h" |
@@ -66,53 +69,91 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s |
66 | 69 | ifs.seekg(base + std::streamoff(len * sizeof(float))); |
67 | 70 | } |
68 | 71 |
|
// Restores model/optimizer/trainer state from the checkpoint directory named
// in args.resume_root and positions the dataloader iterator at the recorded
// batch. Returns a default-initialized result when no resume root is set.
// NOTE(review): args is const&, but args.state / args.train_iter are written
// through it — presumably reference members of ResumeFromCheckpointArgs.
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) {
    ResumeFromCheckpointResult result;
    const int data_parallel_size = nn::parallel::global::GetDataParallelSize();

    // Nothing to resume from: start training fresh.
    if (args.resume_root.empty()) {
        LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch.";
        return result;
    }

    // Prefer the per-rank subdirectory when running distributed and it exists.
    std::filesystem::path ckpt_dir{args.resume_root};
    if (args.rank.IsParallel()) {
        if (const auto per_rank = ckpt_dir / std::format("rank_{:06d}", args.rank.GlobalRank());
            std::filesystem::exists(per_rank)) {
            ckpt_dir = per_rank;
        }
    }

    Checkpoint::Load(ckpt_dir, args.model.get(), args.optimizer.get(), &args.state, args.load_options);

    result.global_step = static_cast<int>(args.state.global_step);
    result.best_loss = args.state.best_loss;

    // A stride mismatch means the checkpoint was written under a different
    // data-parallel world size; warn (main rank only) but keep the recorded index.
    if (args.state.data_batch_stride != static_cast<int64_t>(data_parallel_size) && args.rank.IsMainRank()) {
        LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
                                    "Proceeding with recorded data_batch_idx {}.",
                                    args.state.data_batch_stride, data_parallel_size, args.state.data_batch_idx);
    }

    // Clamp a possibly-negative recorded index to zero before seeking the loader.
    result.data_batch_idx = static_cast<size_t>(std::max<int64_t>(args.state.data_batch_idx, 0));
    args.train_iter = args.train_loader.IteratorAtBatchIndex(result.data_batch_idx);

    if (args.rank.IsMainRank()) {
        LOG(INFO) << std::format(
            "Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}",
            args.state.global_step, args.state.best_loss, args.state.last_lr, args.state.data_batch_idx);
    }

    return result;
}
| 108 | + |
// Serializes the current training state (model, optimizer, bookkeeping) to
// args.save_dir, then — on the main rank only — logs the save time and, when
// args.prune_step_checkpoints is set, deletes the oldest "checkpoint_step_*"
// directories under args.checkpoint_root_dir so at most
// args.max_checkpoint_keep remain.
void SaveCheckpoint(const SaveCheckpointArgs &args) {
    // steady_clock is monotonic — the right clock for measuring durations
    // (high_resolution_clock may alias a non-monotonic system clock).
    const auto ckpt_start = std::chrono::steady_clock::now();

    // Snapshot trainer bookkeeping so it round-trips through Checkpoint::Load.
    TrainerState state;
    state.global_step = args.global_step;
    state.data_batch_idx = static_cast<int64_t>(args.data_batch_idx);
    state.data_batch_stride = args.ddp_size; // stride the data was sharded with
    state.best_loss = args.best_loss;
    state.last_lr = args.last_lr;
    state.optimizer_type = args.optimizer_type;
    state.checkpoint_format = args.checkpoint_format;
    state.ddp_size = args.ddp_size;
    state.tp_size = args.tp_size;
    state.sp_size = args.sp_size;
    state.pp_size = args.pp_size;

    CheckpointOptions options;
    options.format = args.checkpoint_format;
    options.save_optimizer_state = args.save_optimizer_state;
    options.model_bin_writer = args.model_bin_writer;
    Checkpoint::Save(args.save_dir, args.model, args.optimizer, state, options);

    const auto ckpt_end = std::chrono::steady_clock::now();
    const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();

    // Logging and pruning are main-rank-only duties.
    if (!args.rank.IsMainRank()) {
        return;
    }

    LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", args.save_dir.string(), ckpt_ms);

    if (!args.prune_step_checkpoints) {
        return;
    }

    // Collect step-checkpoint directories and delete the oldest ones.
    std::vector<std::filesystem::path> ckpts;
    if (std::filesystem::exists(args.checkpoint_root_dir)) {
        for (const auto &entry : std::filesystem::directory_iterator(args.checkpoint_root_dir)) {
            if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) {
                ckpts.push_back(entry.path());
            }
        }
        // Order by step number, not plain lexicographic order: a bare lex sort
        // puts "checkpoint_step_9" after "checkpoint_step_10" and would prune
        // the newest checkpoint. Comparing by name length first fixes unpadded
        // step numbers; for zero-padded names the lengths tie and this reduces
        // to the previous lexicographic behavior.
        std::sort(ckpts.begin(), ckpts.end(), [](const auto &a, const auto &b) {
            const auto lhs = a.filename().string();
            const auto rhs = b.filename().string();
            return lhs.size() != rhs.size() ? lhs.size() < rhs.size() : lhs < rhs;
        });
        // Delete the excess oldest entries in one forward pass rather than
        // repeatedly erasing from the front of the vector (accidental O(n^2)).
        size_t oldest = 0;
        while (ckpts.size() - oldest > args.max_checkpoint_keep) {
            std::filesystem::remove_all(ckpts[oldest]);
            ++oldest;
        }
    }
}
117 | 158 |
|
118 | 159 | } // namespace infini_train |
0 commit comments