Skip to content

Commit f7e5cd7

Browse files
committed
feat: extract similar logic in ckpt_save
1 parent 08ed56b commit f7e5cd7

File tree

4 files changed

+187
-141
lines changed

4 files changed

+187
-141
lines changed

example/common/utils.cc

Lines changed: 71 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#include "example/common/utils.h"
22

3+
#include <algorithm>
4+
#include <chrono>
5+
36
#include "gflags/gflags.h"
47
#include "gflags/gflags_declare.h"
58
#include "glog/logging.h"
@@ -66,53 +69,91 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s
6669
ifs.seekg(base + std::streamoff(len * sizeof(float)));
6770
}
6871

69-
std::tuple<int, float, size_t> ResumeFromCheckpoint(
70-
const fLS::clstring &flag_resume_root, // resume from this checkpoint directory
71-
const nn::parallel::Rank &rank, // rank info for distributed training
72-
std::shared_ptr<nn::Module> model, // model to be loaded with checkpoint state
73-
std::shared_ptr<Optimizer> optimizer, // some optimizer may not have state, but others may have
74-
DistributedDataLoader &train_loader, // distributed dataloader to be resumed
75-
TrainerState &state, // trainer state to be loaded from checkpoint
76-
DataLoaderIterator
77-
&train_iter, // dataloader iterator to be set to the correct position according to checkpoint state
78-
CheckpointLoadOptions model_bin_loader) {
79-
int global_step = 0;
80-
float best_loss = std::numeric_limits<float>::infinity();
81-
size_t data_batch_idx = 0;
82-
72+
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) {
73+
ResumeFromCheckpointResult result;
8374
int ddp_world_size = nn::parallel::global::GetDataParallelSize();
8475

85-
if (flag_resume_root.empty()) {
76+
if (args.resume_root.empty()) {
8677
LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch.";
87-
return {global_step, best_loss, data_batch_idx};
78+
return result;
8879
}
8980

90-
std::filesystem::path resume_dir = flag_resume_root;
91-
if (rank.IsParallel()) {
92-
const auto rank_dir = resume_dir / std::format("rank_{:06d}", rank.GlobalRank());
81+
std::filesystem::path resume_dir = args.resume_root;
82+
if (args.rank.IsParallel()) {
83+
const auto rank_dir = resume_dir / std::format("rank_{:06d}", args.rank.GlobalRank());
9384
if (std::filesystem::exists(rank_dir)) {
9485
resume_dir = rank_dir;
9586
}
9687
}
9788

98-
Checkpoint::Load(resume_dir, model.get(), optimizer.get(), &state, model_bin_loader);
89+
Checkpoint::Load(resume_dir, args.model.get(), args.optimizer.get(), &args.state, args.load_options);
9990

100-
global_step = static_cast<int>(state.global_step);
101-
best_loss = state.best_loss;
102-
if (state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && rank.IsMainRank()) {
91+
result.global_step = static_cast<int>(args.state.global_step);
92+
result.best_loss = args.state.best_loss;
93+
if (args.state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && args.rank.IsMainRank()) {
10394
LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
10495
"Proceeding with recorded data_batch_idx {}.",
105-
state.data_batch_stride, ddp_world_size, state.data_batch_idx);
96+
args.state.data_batch_stride, ddp_world_size, args.state.data_batch_idx);
10697
}
107-
data_batch_idx = static_cast<size_t>(std::max<int64_t>(state.data_batch_idx, 0));
108-
train_iter = train_loader.IteratorAtBatchIndex(data_batch_idx);
109-
if (rank.IsMainRank()) {
98+
result.data_batch_idx = static_cast<size_t>(std::max<int64_t>(args.state.data_batch_idx, 0));
99+
args.train_iter = args.train_loader.IteratorAtBatchIndex(result.data_batch_idx);
100+
if (args.rank.IsMainRank()) {
110101
LOG(INFO) << std::format(
111-
"Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}", state.global_step,
112-
state.best_loss, state.last_lr, state.data_batch_idx);
102+
"Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}",
103+
args.state.global_step, args.state.best_loss, args.state.last_lr, args.state.data_batch_idx);
104+
}
105+
106+
return result;
107+
}
108+
109+
void SaveCheckpoint(const SaveCheckpointArgs &args) {
110+
const auto ckpt_start = std::chrono::high_resolution_clock::now();
111+
112+
TrainerState state;
113+
state.global_step = args.global_step;
114+
state.data_batch_idx = static_cast<int64_t>(args.data_batch_idx);
115+
state.data_batch_stride = args.ddp_size;
116+
state.best_loss = args.best_loss;
117+
state.last_lr = args.last_lr;
118+
state.optimizer_type = args.optimizer_type;
119+
state.checkpoint_format = args.checkpoint_format;
120+
state.ddp_size = args.ddp_size;
121+
state.tp_size = args.tp_size;
122+
state.sp_size = args.sp_size;
123+
state.pp_size = args.pp_size;
124+
125+
CheckpointOptions options;
126+
options.format = args.checkpoint_format;
127+
options.save_optimizer_state = args.save_optimizer_state;
128+
options.model_bin_writer = args.model_bin_writer;
129+
Checkpoint::Save(args.save_dir, args.model, args.optimizer, state, options);
130+
131+
const auto ckpt_end = std::chrono::high_resolution_clock::now();
132+
const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();
133+
134+
if (!args.rank.IsMainRank()) {
135+
return;
113136
}
114137

115-
return {global_step, best_loss, data_batch_idx};
138+
LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", args.save_dir.string(), ckpt_ms);
139+
140+
if (!args.prune_step_checkpoints) {
141+
return;
142+
}
143+
144+
std::vector<std::filesystem::path> ckpts;
145+
if (std::filesystem::exists(args.checkpoint_root_dir)) {
146+
for (const auto &entry : std::filesystem::directory_iterator(args.checkpoint_root_dir)) {
147+
if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) {
148+
ckpts.push_back(entry.path());
149+
}
150+
}
151+
std::sort(ckpts.begin(), ckpts.end());
152+
while (ckpts.size() > args.max_checkpoint_keep) {
153+
std::filesystem::remove_all(ckpts.front());
154+
ckpts.erase(ckpts.begin());
155+
}
156+
}
116157
}
117158

118159
} // namespace infini_train

example/common/utils.h

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
#include <filesystem>
1414
#include <fstream>
1515
#include <functional>
16-
#include <tuple>
16+
#include <limits>
17+
#include <string>
1718
#include <vector>
1819

1920
namespace infini_train {
@@ -41,19 +42,47 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len);
4142

4243
void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt);
4344

44-
/**
45-
* @returns a tuple of (global_step, best_loss, data_batch_idx) loaded from the checkpoint, which can be used to resume
46-
* training.
47-
*/
48-
std::tuple<int, float, size_t> ResumeFromCheckpoint(
49-
const fLS::clstring &flag_resume_root, // resume from this checkpoint directory
50-
const nn::parallel::Rank &rank, // rank info for distributed training
51-
std::shared_ptr<nn::Module> model, // model to be loaded with checkpoint state
52-
std::shared_ptr<Optimizer> optimizer, // some optimizer may not have state, but others may have
53-
DistributedDataLoader &train_loader, // distributed dataloader to be resumed
54-
TrainerState &state, // trainer state to be loaded from checkpoint
55-
DataLoaderIterator
56-
&train_iter, // dataloader iterator to be set to the correct position according to checkpoint state
57-
CheckpointLoadOptions model_bin_loader);
45+
struct ResumeFromCheckpointArgs {
46+
fLS::clstring resume_root;
47+
const nn::parallel::Rank &rank;
48+
std::shared_ptr<nn::Module> model;
49+
std::shared_ptr<Optimizer> optimizer;
50+
DistributedDataLoader &train_loader;
51+
TrainerState &state;
52+
DataLoaderIterator &train_iter;
53+
CheckpointLoadOptions load_options;
54+
};
55+
56+
struct ResumeFromCheckpointResult {
57+
int global_step = 0;
58+
float best_loss = std::numeric_limits<float>::infinity();
59+
size_t data_batch_idx = 0;
60+
};
61+
62+
struct SaveCheckpointArgs {
63+
std::filesystem::path save_dir;
64+
int64_t global_step = 0;
65+
size_t data_batch_idx = 0;
66+
float best_loss = std::numeric_limits<float>::infinity();
67+
double last_lr = 0.0;
68+
std::string optimizer_type;
69+
std::string checkpoint_format = "bin";
70+
int ddp_size = 1;
71+
int tp_size = 1;
72+
int sp_size = 1;
73+
int pp_size = 1;
74+
bool save_optimizer_state = true;
75+
bool prune_step_checkpoints = false;
76+
std::filesystem::path checkpoint_root_dir;
77+
size_t max_checkpoint_keep = 0;
78+
const nn::parallel::Rank &rank;
79+
const nn::Module &model;
80+
const Optimizer &optimizer;
81+
std::function<void(const nn::Module &, const std::filesystem::path &)> model_bin_writer;
82+
};
83+
84+
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);
85+
86+
void SaveCheckpoint(const SaveCheckpointArgs &args);
5887

5988
} // namespace infini_train

example/gpt2/main.cc

Lines changed: 36 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -343,57 +343,45 @@ void Train(const nn::parallel::Rank &rank) {
343343
auto loaded_model = GPT2::FromLLMC(model_path.string());
344344
target_model->LoadStateDict(loaded_model->StateDict());
345345
};
346-
std::tie(start_step, best_loss, saved_data_batch_idx) = infini_train::ResumeFromCheckpoint(
347-
FLAGS_resume_from, rank, model, optimizer, train_loader, state, train_iter, load_options);
346+
const auto resume_result = infini_train::ResumeFromCheckpoint({
347+
.resume_root = FLAGS_resume_from,
348+
.rank = rank,
349+
.model = model,
350+
.optimizer = optimizer,
351+
.train_loader = train_loader,
352+
.state = state,
353+
.train_iter = train_iter,
354+
.load_options = load_options,
355+
});
356+
start_step = resume_result.global_step;
357+
best_loss = resume_result.best_loss;
358+
saved_data_batch_idx = resume_result.data_batch_idx;
348359

349360
auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step,
350361
bool prune_step_checkpoints) {
351-
const auto ckpt_start = std::chrono::high_resolution_clock::now();
352-
353-
TrainerState state;
354-
state.global_step = global_step;
355-
state.data_batch_idx = saved_data_batch_idx;
356-
state.data_batch_stride = ddp_world_size;
357-
state.best_loss = best_loss;
358-
state.last_lr = FLAGS_learning_rate;
359-
state.optimizer_type = "SGD";
360-
state.checkpoint_format = FLAGS_checkpoint_format;
361-
state.ddp_size = ddp_world_size;
362-
state.tp_size = tp_world_size;
363-
state.sp_size = sp_world_size;
364-
state.pp_size = pp_world_size;
365-
366-
CheckpointOptions options;
367-
options.format = FLAGS_checkpoint_format;
368-
options.save_optimizer_state = FLAGS_save_optimizer_state;
369-
options.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) {
370-
llmc_model->SaveAsLLMC(model_path.string());
371-
};
372-
Checkpoint::Save(save_dir, *model, *optimizer, state, options);
373-
374-
const auto ckpt_end = std::chrono::high_resolution_clock::now();
375-
const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();
376-
377-
if (rank.IsMainRank()) {
378-
LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", save_dir.string(), ckpt_ms);
379-
380-
if (prune_step_checkpoints) {
381-
std::vector<std::filesystem::path> ckpts;
382-
const auto root = std::filesystem::path(FLAGS_checkpoint_dir);
383-
if (std::filesystem::exists(root)) {
384-
for (const auto &entry : std::filesystem::directory_iterator(root)) {
385-
if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) {
386-
ckpts.push_back(entry.path());
387-
}
388-
}
389-
std::sort(ckpts.begin(), ckpts.end());
390-
while (ckpts.size() > FLAGS_max_checkpoint_keep) {
391-
std::filesystem::remove_all(ckpts.front());
392-
ckpts.erase(ckpts.begin());
393-
}
394-
}
395-
}
396-
}
362+
infini_train::SaveCheckpoint({
363+
.save_dir = save_dir,
364+
.global_step = global_step,
365+
.data_batch_idx = saved_data_batch_idx,
366+
.best_loss = best_loss,
367+
.last_lr = FLAGS_learning_rate,
368+
.optimizer_type = "SGD",
369+
.checkpoint_format = FLAGS_checkpoint_format,
370+
.ddp_size = ddp_world_size,
371+
.tp_size = tp_world_size,
372+
.sp_size = sp_world_size,
373+
.pp_size = pp_world_size,
374+
.save_optimizer_state = FLAGS_save_optimizer_state,
375+
.prune_step_checkpoints = prune_step_checkpoints,
376+
.checkpoint_root_dir = FLAGS_checkpoint_dir,
377+
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
378+
.rank = rank,
379+
.model = *model,
380+
.optimizer = *optimizer,
381+
.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) {
382+
llmc_model->SaveAsLLMC(model_path.string());
383+
},
384+
});
397385
};
398386

399387
for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {

example/llama3/main.cc

Lines changed: 36 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -314,57 +314,45 @@ void Train(const nn::parallel::Rank &rank) {
314314
auto loaded_model = LLaMA3::FromLLMC(model_path.string());
315315
target_model->LoadStateDict(loaded_model->StateDict());
316316
};
317-
std::tie(start_step, best_loss, saved_data_batch_idx) = infini_train::ResumeFromCheckpoint(
318-
FLAGS_resume_from, rank, model, optimizer, train_loader, state, train_iter, load_options);
317+
const auto resume_result = infini_train::ResumeFromCheckpoint({
318+
.resume_root = FLAGS_resume_from,
319+
.rank = rank,
320+
.model = model,
321+
.optimizer = optimizer,
322+
.train_loader = train_loader,
323+
.state = state,
324+
.train_iter = train_iter,
325+
.load_options = load_options,
326+
});
327+
start_step = resume_result.global_step;
328+
best_loss = resume_result.best_loss;
329+
saved_data_batch_idx = resume_result.data_batch_idx;
319330

320331
auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step,
321332
bool prune_step_checkpoints) {
322-
const auto ckpt_start = std::chrono::high_resolution_clock::now();
323-
324-
TrainerState state;
325-
state.global_step = global_step;
326-
state.data_batch_idx = saved_data_batch_idx;
327-
state.data_batch_stride = ddp_world_size;
328-
state.best_loss = best_loss;
329-
state.last_lr = FLAGS_learning_rate;
330-
state.optimizer_type = "Adam";
331-
state.checkpoint_format = FLAGS_checkpoint_format;
332-
state.ddp_size = ddp_world_size;
333-
state.tp_size = tp_world_size;
334-
state.sp_size = sp_world_size;
335-
state.pp_size = pp_world_size;
336-
337-
CheckpointOptions options;
338-
options.format = FLAGS_checkpoint_format;
339-
options.save_optimizer_state = FLAGS_save_optimizer_state;
340-
options.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) {
341-
llmc_model->SaveAsLLMC(model_path.string());
342-
};
343-
Checkpoint::Save(save_dir, *model, *optimizer, state, options);
344-
345-
const auto ckpt_end = std::chrono::high_resolution_clock::now();
346-
const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();
347-
348-
if (rank.IsMainRank()) {
349-
LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", save_dir.string(), ckpt_ms);
350-
351-
if (prune_step_checkpoints) {
352-
std::vector<std::filesystem::path> ckpts;
353-
const auto root = std::filesystem::path(FLAGS_checkpoint_dir);
354-
if (std::filesystem::exists(root)) {
355-
for (const auto &entry : std::filesystem::directory_iterator(root)) {
356-
if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) {
357-
ckpts.push_back(entry.path());
358-
}
359-
}
360-
std::sort(ckpts.begin(), ckpts.end());
361-
while (ckpts.size() > FLAGS_max_checkpoint_keep) {
362-
std::filesystem::remove_all(ckpts.front());
363-
ckpts.erase(ckpts.begin());
364-
}
365-
}
366-
}
367-
}
333+
infini_train::SaveCheckpoint({
334+
.save_dir = save_dir,
335+
.global_step = global_step,
336+
.data_batch_idx = saved_data_batch_idx,
337+
.best_loss = best_loss,
338+
.last_lr = FLAGS_learning_rate,
339+
.optimizer_type = "Adam",
340+
.checkpoint_format = FLAGS_checkpoint_format,
341+
.ddp_size = ddp_world_size,
342+
.tp_size = tp_world_size,
343+
.sp_size = sp_world_size,
344+
.pp_size = pp_world_size,
345+
.save_optimizer_state = FLAGS_save_optimizer_state,
346+
.prune_step_checkpoints = prune_step_checkpoints,
347+
.checkpoint_root_dir = FLAGS_checkpoint_dir,
348+
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
349+
.rank = rank,
350+
.model = *model,
351+
.optimizer = *optimizer,
352+
.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) {
353+
llmc_model->SaveAsLLMC(model_path.string());
354+
},
355+
});
368356
};
369357

370358
for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {

0 commit comments

Comments
 (0)