Skip to content

Commit e8c5dd5

Browse files
committed
feat(checkpoint): reorganize checkpoint code and improve robustness
1 parent b363779 commit e8c5dd5

12 files changed

Lines changed: 609 additions & 628 deletions

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ add_executable(gpt2
176176
example/gpt2/main.cc
177177
example/common/tiny_shakespeare_dataset.cc
178178
example/common/utils.cc
179-
example/gpt2/checkpoint_loader.cc
179+
example/common/checkpoint_loader.cc
180180
example/common/tokenizer.cc
181181
)
182182
link_infini_train_exe(gpt2)
@@ -185,7 +185,7 @@ add_executable(llama3
185185
example/llama3/main.cc
186186
example/common/tiny_shakespeare_dataset.cc
187187
example/common/utils.cc
188-
example/llama3/checkpoint_loader.cc
188+
example/common/checkpoint_loader.cc
189189
example/common/tokenizer.cc
190190
)
191191
link_infini_train_exe(llama3)
Lines changed: 517 additions & 11 deletions
Large diffs are not rendered by default.

example/common/checkpoint_loader.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#pragma once
2+
3+
#include "infini_train/include/checkpoint.h"
4+
#include "infini_train/include/dataloader.h"
5+
#include "infini_train/include/nn/modules/module.h"
6+
#include "infini_train/include/nn/parallel/rank.h"
7+
#include "infini_train/include/optimizer.h"
8+
9+
#include "gflags/gflags.h"
10+
11+
#include <cstdint>
12+
#include <cstring>
13+
#include <filesystem>
14+
15+
#include <functional>
16+
#include <limits>
17+
#include <string>
18+
19+
namespace infini_train {
20+
namespace nn {
21+
class TransformerModel;
22+
}
23+
24+
namespace gpt2 {
25+
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
26+
void SaveAsLLMC(const std::shared_ptr<nn::TransformerModel> &model, const std::string &filepath);
27+
} // namespace gpt2
28+
namespace llama3 {
29+
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
30+
void SaveAsLLMC(const std::shared_ptr<nn::TransformerModel> &model, const std::string &filepath);
31+
} // namespace llama3
32+
33+
struct ResumeFromCheckpointArgs {
34+
fLS::clstring resume_root;
35+
const nn::parallel::Rank &rank;
36+
std::shared_ptr<nn::Module> model;
37+
std::shared_ptr<Optimizer> optimizer;
38+
DistributedDataLoader &train_loader;
39+
TrainerState &state;
40+
DataLoaderIterator &train_iter;
41+
CheckpointLoadOptions load_options;
42+
};
43+
44+
struct ResumeFromCheckpointResult {
45+
int global_step = 0;
46+
float best_loss = std::numeric_limits<float>::infinity();
47+
size_t data_batch_idx = 0;
48+
};
49+
50+
struct SaveCheckpointArgs {
51+
std::filesystem::path save_dir;
52+
int64_t global_step = 0;
53+
size_t data_batch_idx = 0;
54+
float best_loss = std::numeric_limits<float>::infinity();
55+
double last_lr = 0.0;
56+
std::string optimizer_type;
57+
std::string checkpoint_format = "bin";
58+
int ddp_size = 1;
59+
int tp_size = 1;
60+
int sp_size = 1;
61+
int pp_size = 1;
62+
bool save_optimizer_state = true;
63+
bool prune_step_checkpoints = false;
64+
std::filesystem::path checkpoint_root_dir;
65+
size_t max_checkpoint_keep = 0;
66+
const nn::parallel::Rank &rank;
67+
const nn::Module &model;
68+
const Optimizer &optimizer;
69+
std::function<void(const nn::Module &, const std::filesystem::path &)> model_bin_writer;
70+
};
71+
72+
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);
73+
74+
void SaveCheckpoint(const SaveCheckpointArgs &args);
75+
76+
} // namespace infini_train

example/common/utils.cc

Lines changed: 0 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -68,92 +68,4 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s
6868
ifs.read(reinterpret_cast<char *>(dst), static_cast<std::streamsize>(cnt * sizeof(float)));
6969
ifs.seekg(base + std::streamoff(len * sizeof(float)));
7070
}
71-
72-
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) {
73-
ResumeFromCheckpointResult result;
74-
int ddp_world_size = nn::parallel::global::GetDataParallelSize();
75-
76-
if (args.resume_root.empty()) {
77-
LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch.";
78-
return result;
79-
}
80-
81-
std::filesystem::path resume_dir = args.resume_root;
82-
if (args.rank.IsParallel()) {
83-
const auto rank_dir = resume_dir / std::format("rank_{:06d}", args.rank.GlobalRank());
84-
if (std::filesystem::exists(rank_dir)) {
85-
resume_dir = rank_dir;
86-
}
87-
}
88-
89-
Checkpoint::Load(resume_dir, args.model.get(), args.optimizer.get(), &args.state, args.load_options);
90-
91-
result.global_step = static_cast<int>(args.state.global_step);
92-
result.best_loss = args.state.best_loss;
93-
if (args.state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && args.rank.IsMainRank()) {
94-
LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
95-
"Proceeding with recorded data_batch_idx {}.",
96-
args.state.data_batch_stride, ddp_world_size, args.state.data_batch_idx);
97-
}
98-
result.data_batch_idx = static_cast<size_t>(std::max<int64_t>(args.state.data_batch_idx, 0));
99-
args.train_iter = args.train_loader.IteratorAtBatchIndex(result.data_batch_idx);
100-
if (args.rank.IsMainRank()) {
101-
LOG(INFO) << std::format(
102-
"Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}",
103-
args.state.global_step, args.state.best_loss, args.state.last_lr, args.state.data_batch_idx);
104-
}
105-
106-
return result;
107-
}
108-
109-
void SaveCheckpoint(const SaveCheckpointArgs &args) {
110-
const auto ckpt_start = std::chrono::high_resolution_clock::now();
111-
112-
TrainerState state;
113-
state.global_step = args.global_step;
114-
state.data_batch_idx = static_cast<int64_t>(args.data_batch_idx);
115-
state.data_batch_stride = args.ddp_size;
116-
state.best_loss = args.best_loss;
117-
state.last_lr = args.last_lr;
118-
state.optimizer_type = args.optimizer_type;
119-
state.checkpoint_format = args.checkpoint_format;
120-
state.ddp_size = args.ddp_size;
121-
state.tp_size = args.tp_size;
122-
state.sp_size = args.sp_size;
123-
state.pp_size = args.pp_size;
124-
125-
CheckpointOptions options;
126-
options.format = args.checkpoint_format;
127-
options.save_optimizer_state = args.save_optimizer_state;
128-
options.model_bin_writer = args.model_bin_writer;
129-
Checkpoint::Save(args.save_dir, args.model, args.optimizer, state, options);
130-
131-
const auto ckpt_end = std::chrono::high_resolution_clock::now();
132-
const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();
133-
134-
if (!args.rank.IsMainRank()) {
135-
return;
136-
}
137-
138-
LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", args.save_dir.string(), ckpt_ms);
139-
140-
if (!args.prune_step_checkpoints) {
141-
return;
142-
}
143-
144-
std::vector<std::filesystem::path> ckpts;
145-
if (std::filesystem::exists(args.checkpoint_root_dir)) {
146-
for (const auto &entry : std::filesystem::directory_iterator(args.checkpoint_root_dir)) {
147-
if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) {
148-
ckpts.push_back(entry.path());
149-
}
150-
}
151-
std::sort(ckpts.begin(), ckpts.end());
152-
while (ckpts.size() > args.max_checkpoint_keep) {
153-
std::filesystem::remove_all(ckpts.front());
154-
ckpts.erase(ckpts.begin());
155-
}
156-
}
157-
}
158-
15971
} // namespace infini_train

example/common/utils.h

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -42,47 +42,4 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len);
4242

4343
void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt);
4444

45-
struct ResumeFromCheckpointArgs {
46-
fLS::clstring resume_root;
47-
const nn::parallel::Rank &rank;
48-
std::shared_ptr<nn::Module> model;
49-
std::shared_ptr<Optimizer> optimizer;
50-
DistributedDataLoader &train_loader;
51-
TrainerState &state;
52-
DataLoaderIterator &train_iter;
53-
CheckpointLoadOptions load_options;
54-
};
55-
56-
struct ResumeFromCheckpointResult {
57-
int global_step = 0;
58-
float best_loss = std::numeric_limits<float>::infinity();
59-
size_t data_batch_idx = 0;
60-
};
61-
62-
struct SaveCheckpointArgs {
63-
std::filesystem::path save_dir;
64-
int64_t global_step = 0;
65-
size_t data_batch_idx = 0;
66-
float best_loss = std::numeric_limits<float>::infinity();
67-
double last_lr = 0.0;
68-
std::string optimizer_type;
69-
std::string checkpoint_format = "bin";
70-
int ddp_size = 1;
71-
int tp_size = 1;
72-
int sp_size = 1;
73-
int pp_size = 1;
74-
bool save_optimizer_state = true;
75-
bool prune_step_checkpoints = false;
76-
std::filesystem::path checkpoint_root_dir;
77-
size_t max_checkpoint_keep = 0;
78-
const nn::parallel::Rank &rank;
79-
const nn::Module &model;
80-
const Optimizer &optimizer;
81-
std::function<void(const nn::Module &, const std::filesystem::path &)> model_bin_writer;
82-
};
83-
84-
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);
85-
86-
void SaveCheckpoint(const SaveCheckpointArgs &args);
87-
8845
} // namespace infini_train

example/gpt2/checkpoint_loader.h

Lines changed: 0 additions & 13 deletions
This file was deleted.

example/gpt2/config.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
44

5-
namespace nn = infini_train::nn;
5+
namespace infini_train {
66
namespace gpt2 {
77
inline nn::TransformerConfig GPT2Config() {
88
return {.block_size = 1024,
@@ -22,5 +22,5 @@ inline nn::TransformerConfig GPT2Config() {
2222
.ffn_dim_multiplier = std::nullopt,
2323
.multiple_of = 1};
2424
}
25-
2625
} // namespace gpt2
26+
} // namespace infini_train

example/gpt2/main.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include <unordered_map>
1010
#include <unordered_set>
1111

12-
#include "example/common/utils.h"
1312
#include "gflags/gflags.h"
1413
#include "glog/logging.h"
1514

@@ -39,9 +38,9 @@
3938
#include "infini_train/include/utils/precision_check_config.h"
4039
#include "infini_train/include/utils/precision_checker.h"
4140

41+
#include "example/common/checkpoint_loader.h"
4242
#include "example/common/tiny_shakespeare_dataset.h"
4343
#include "example/common/tokenizer.h"
44-
#include "example/gpt2/checkpoint_loader.h"
4544
#include "example/gpt2/config.h"
4645

4746
// I/O
@@ -87,7 +86,10 @@ DEFINE_string(resume_from, "", "checkpoint directory to resume from");
8786
DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store checkpoints");
8887
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
8988
DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
90-
DEFINE_string(checkpoint_format, "bin", "checkpoint format: bin|pth");
89+
DEFINE_string(checkpoint_format, "pth",
90+
"checkpoint format: bin|pth. "
91+
"'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
92+
"'pth' generates model.pth/optimizer.pth (native StateDict binary).");
9193
// precision check
9294
DEFINE_string(
9395
precision_check, "",

0 commit comments

Comments
 (0)