Skip to content

Commit 08ed56b

Browse files
committed
feat: extract resuming to utils
remove redundant arguments
1 parent 4b248b6 commit 08ed56b

File tree

4 files changed

+103
-77
lines changed

4 files changed

+103
-77
lines changed

example/common/utils.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
#include "example/common/utils.h"
22

3+
#include "gflags/gflags.h"
4+
#include "gflags/gflags_declare.h"
5+
#include "glog/logging.h"
6+
#include "infini_train/include/nn/parallel/global.h"
7+
38
namespace infini_train {
49

510
float ConvertBF16ToFloat(void *ptr) {
@@ -61,4 +66,53 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s
6166
ifs.seekg(base + std::streamoff(len * sizeof(float)));
6267
}
6368

69+
/**
 * @brief Restores training state from a checkpoint directory, if one is given.
 *
 * When @p flag_resume_root is empty this is a no-op that returns the
 * "start from scratch" defaults. Otherwise the model, optimizer and trainer
 * state are loaded via Checkpoint::Load, and the dataloader iterator is
 * repositioned at the recorded batch index.
 *
 * @param flag_resume_root  resume from this checkpoint directory ("" = fresh start)
 * @param rank              rank info for distributed training
 * @param model             model to be loaded with checkpoint state
 * @param optimizer         some optimizers may not have state, but others may
 * @param train_loader      distributed dataloader to be resumed
 * @param state             trainer state to be loaded from checkpoint (out)
 * @param train_iter        dataloader iterator set to the position recorded in the checkpoint (out)
 * @param load_options      full set of checkpoint load options (renamed from the
 *                          misleading `model_bin_loader`, which is just one of its members)
 * @returns a tuple of (global_step, best_loss, data_batch_idx) loaded from the
 *          checkpoint, which can be used to resume training.
 */
std::tuple<int, float, size_t> ResumeFromCheckpoint(
    const fLS::clstring &flag_resume_root, const nn::parallel::Rank &rank, std::shared_ptr<nn::Module> model,
    std::shared_ptr<Optimizer> optimizer, DistributedDataLoader &train_loader, TrainerState &state,
    DataLoaderIterator &train_iter, CheckpointLoadOptions load_options) {
    int global_step = 0;
    float best_loss = std::numeric_limits<float>::infinity();
    size_t data_batch_idx = 0;

    if (flag_resume_root.empty()) {
        LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch.";
        return {global_step, best_loss, data_batch_idx};
    }

    // Only needed on the resume path; hoisted below the early return so a fresh
    // start never touches the parallel runtime.
    const int ddp_world_size = nn::parallel::global::GetDataParallelSize();

    // Per-rank subdirectories (rank_000000, ...) take precedence when present.
    std::filesystem::path resume_dir = flag_resume_root;
    if (rank.IsParallel()) {
        const auto rank_dir = resume_dir / std::format("rank_{:06d}", rank.GlobalRank());
        if (std::filesystem::exists(rank_dir)) {
            resume_dir = rank_dir;
        }
    }

    Checkpoint::Load(resume_dir, model.get(), optimizer.get(), &state, load_options);

    global_step = static_cast<int>(state.global_step);
    best_loss = state.best_loss;
    // A stride mismatch means the checkpoint was written under a different DP
    // world size; we warn (main rank only) but still honor the recorded index.
    if (state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && rank.IsMainRank()) {
        LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
                                    "Proceeding with recorded data_batch_idx {}.",
                                    state.data_batch_stride, ddp_world_size, state.data_batch_idx);
    }
    // Clamp negative/uninitialized indices to 0 before converting to size_t.
    data_batch_idx = static_cast<size_t>(std::max<int64_t>(state.data_batch_idx, 0));
    train_iter = train_loader.IteratorAtBatchIndex(data_batch_idx);
    if (rank.IsMainRank()) {
        LOG(INFO) << std::format(
            "Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}", state.global_step,
            state.best_loss, state.last_lr, state.data_batch_idx);
    }

    return {global_step, best_loss, data_batch_idx};
}
117+
64118
} // namespace infini_train

example/common/utils.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
11
#pragma once
22

3+
#include "infini_train/include/checkpoint.h"
4+
#include "infini_train/include/dataloader.h"
5+
#include "infini_train/include/nn/modules/module.h"
6+
#include "infini_train/include/nn/parallel/rank.h"
7+
#include "infini_train/include/optimizer.h"
8+
9+
#include "gflags/gflags.h"
10+
311
#include <cstdint>
412
#include <cstring>
13+
#include <filesystem>
514
#include <fstream>
15+
#include <functional>
16+
#include <tuple>
617
#include <vector>
718

819
namespace infini_train {
@@ -30,4 +41,19 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len);
3041

3142
void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt);
3243

44+
/**
45+
* @returns a tuple of (global_step, best_loss, data_batch_idx) loaded from the checkpoint, which can be used to resume
46+
* training.
47+
*/
48+
std::tuple<int, float, size_t> ResumeFromCheckpoint(
49+
const fLS::clstring &flag_resume_root, // resume from this checkpoint directory
50+
const nn::parallel::Rank &rank, // rank info for distributed training
51+
std::shared_ptr<nn::Module> model, // model to be loaded with checkpoint state
52+
std::shared_ptr<Optimizer> optimizer, // some optimizer may not have state, but others may have
53+
DistributedDataLoader &train_loader, // distributed dataloader to be resumed
54+
TrainerState &state, // trainer state to be loaded from checkpoint
55+
DataLoaderIterator
56+
&train_iter, // dataloader iterator to be set to the correct position according to checkpoint state
57+
CheckpointLoadOptions model_bin_loader);
58+
3359
} // namespace infini_train

example/gpt2/main.cc

Lines changed: 13 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <unordered_map>
1010
#include <unordered_set>
1111

12+
#include "example/common/utils.h"
1213
#include "gflags/gflags.h"
1314
#include "glog/logging.h"
1415

@@ -85,8 +86,6 @@ DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store che
8586
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
8687
DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
8788
DEFINE_string(checkpoint_format, "bin", "checkpoint format: bin|pth");
88-
DEFINE_bool(use_llmc_checkpoint_io, false,
89-
"whether to use GPT2 LLMC model.bin callback for checkpoint save/load when format=bin");
9089
// precision check
9190
DEFINE_string(
9291
precision_check, "",
@@ -337,44 +336,15 @@ void Train(const nn::parallel::Rank &rank) {
337336

338337
int start_step = 0;
339338
float best_loss = std::numeric_limits<float>::infinity();
340-
if (!FLAGS_resume_from.empty()) {
341-
std::filesystem::path resume_dir = FLAGS_resume_from;
342-
if (rank.IsParallel()) {
343-
const auto rank_dir = resume_dir / std::format("rank_{:06d}", rank.GlobalRank());
344-
if (std::filesystem::exists(rank_dir)) {
345-
resume_dir = rank_dir;
346-
}
347-
}
348-
349-
TrainerState state;
350-
CheckpointLoadOptions load_options;
351-
load_options.load_optimizer_state = true;
352-
if (FLAGS_use_llmc_checkpoint_io) {
353-
load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
354-
auto loaded_model = GPT2::FromLLMC(model_path.string());
355-
target_model->LoadStateDict(loaded_model->StateDict());
356-
};
357-
}
358-
Checkpoint::Load(resume_dir, model.get(), optimizer.get(), &state, load_options);
359-
start_step = static_cast<int>(state.global_step);
360-
best_loss = state.best_loss;
361-
if (state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && rank.IsMainRank()) {
362-
LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
363-
"Proceeding with recorded data_batch_idx {}.",
364-
state.data_batch_stride, ddp_world_size, state.data_batch_idx);
365-
}
366-
saved_data_batch_idx = static_cast<size_t>(std::max<int64_t>(state.data_batch_idx, 0));
367-
train_iter = train_loader.IteratorAtBatchIndex(saved_data_batch_idx);
368-
if (rank.IsMainRank()) {
369-
LOG(INFO) << std::format(
370-
"Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}",
371-
state.global_step, state.best_loss, state.last_lr, state.data_batch_idx);
372-
LOG(INFO) << std::format("Checkpoint model I/O mode during resume: {}",
373-
FLAGS_use_llmc_checkpoint_io ? "llmc-callback" : "native-state-dict");
374-
}
375-
}
376-
377-
LOG(INFO) << "start training";
339+
TrainerState state;
340+
CheckpointLoadOptions load_options;
341+
load_options.load_optimizer_state = true;
342+
load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
343+
auto loaded_model = GPT2::FromLLMC(model_path.string());
344+
target_model->LoadStateDict(loaded_model->StateDict());
345+
};
346+
std::tie(start_step, best_loss, saved_data_batch_idx) = infini_train::ResumeFromCheckpoint(
347+
FLAGS_resume_from, rank, model, optimizer, train_loader, state, train_iter, load_options);
378348

379349
auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step,
380350
bool prune_step_checkpoints) {
@@ -396,11 +366,9 @@ void Train(const nn::parallel::Rank &rank) {
396366
CheckpointOptions options;
397367
options.format = FLAGS_checkpoint_format;
398368
options.save_optimizer_state = FLAGS_save_optimizer_state;
399-
if (FLAGS_use_llmc_checkpoint_io) {
400-
options.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) {
401-
llmc_model->SaveAsLLMC(model_path.string());
402-
};
403-
}
369+
options.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) {
370+
llmc_model->SaveAsLLMC(model_path.string());
371+
};
404372
Checkpoint::Save(save_dir, *model, *optimizer, state, options);
405373

406374
const auto ckpt_end = std::chrono::high_resolution_clock::now();

example/llama3/main.cc

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <optional>
88
#include <unordered_set>
99

10+
#include "example/common/utils.h"
1011
#include "gflags/gflags.h"
1112
#include "glog/logging.h"
1213

@@ -306,38 +307,15 @@ void Train(const nn::parallel::Rank &rank) {
306307

307308
int start_step = 0;
308309
float best_loss = std::numeric_limits<float>::infinity();
309-
if (!FLAGS_resume_from.empty()) {
310-
std::filesystem::path resume_dir = FLAGS_resume_from;
311-
if (rank.IsParallel()) {
312-
const auto rank_dir = resume_dir / std::format("rank_{:06d}", rank.GlobalRank());
313-
if (std::filesystem::exists(rank_dir)) {
314-
resume_dir = rank_dir;
315-
}
316-
}
317-
318-
TrainerState state;
319-
CheckpointLoadOptions load_options;
320-
load_options.load_optimizer_state = true;
321-
load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
322-
auto loaded_model = LLaMA3::FromLLMC(model_path.string());
323-
target_model->LoadStateDict(loaded_model->StateDict());
324-
};
325-
Checkpoint::Load(resume_dir, model.get(), optimizer.get(), &state, load_options);
326-
start_step = static_cast<int>(state.global_step);
327-
best_loss = state.best_loss;
328-
if (state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && rank.IsMainRank()) {
329-
LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
330-
"Proceeding with recorded data_batch_idx {}.",
331-
state.data_batch_stride, ddp_world_size, state.data_batch_idx);
332-
}
333-
saved_data_batch_idx = static_cast<size_t>(std::max<int64_t>(state.data_batch_idx, 0));
334-
train_iter = train_loader.IteratorAtBatchIndex(saved_data_batch_idx);
335-
if (rank.IsMainRank()) {
336-
LOG(INFO) << std::format(
337-
"Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}",
338-
state.global_step, state.best_loss, state.last_lr, state.data_batch_idx);
339-
}
340-
}
310+
TrainerState state;
311+
CheckpointLoadOptions load_options;
312+
load_options.load_optimizer_state = true;
313+
load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
314+
auto loaded_model = LLaMA3::FromLLMC(model_path.string());
315+
target_model->LoadStateDict(loaded_model->StateDict());
316+
};
317+
std::tie(start_step, best_loss, saved_data_batch_idx) = infini_train::ResumeFromCheckpoint(
318+
FLAGS_resume_from, rank, model, optimizer, train_loader, state, train_iter, load_options);
341319

342320
auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step,
343321
bool prune_step_checkpoints) {

0 commit comments

Comments (0)