Skip to content

Commit 85cddfd

Browse files
committed
temp2
1 parent 8ebe12f commit 85cddfd

8 files changed

Lines changed: 69 additions & 162 deletions

File tree

example/common/checkpoint_loader.cc

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &
3535
}
3636
}
3737

38-
Checkpoint::Load(resume_dir, args.model.get(), args.optimizer.get(), &args.state, args.load_options);
38+
Checkpoint::Load(resume_dir, *args.model, args.optimizer.get(), args.state);
3939

4040
result.global_step = static_cast<int>(args.state.global_step);
4141
if (args.state.data_batch_stride != static_cast<int64_t>(ddp_world_size)) {
@@ -70,17 +70,12 @@ void SaveCheckpoint(const SaveCheckpointArgs &args) {
7070
state.data_batch_stride = args.ddp_size;
7171
state.last_lr = args.last_lr;
7272
state.optimizer_type = args.optimizer_type;
73-
state.checkpoint_file_format = args.checkpoint_file_format;
7473
state.ddp_size = args.ddp_size;
7574
state.tp_size = args.tp_size;
7675
state.sp_size = args.sp_size;
7776
state.pp_size = args.pp_size;
7877

79-
CheckpointOptions options;
80-
options.format = args.checkpoint_file_format;
81-
options.no_save_optim = args.no_save_optim;
82-
options.model_bin_writer = args.model_bin_writer;
83-
Checkpoint::Save(args.save_dir, args.model, args.optimizer, state, options);
78+
// Forward the caller's no_save_optim choice; otherwise Checkpoint::Save falls back to its
// default (false) and always writes optimizer state regardless of the flag.
Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state, args.no_save_optim);
8479

8580
const auto ckpt_end = std::chrono::high_resolution_clock::now();
8681
const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();

example/common/checkpoint_loader.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ struct ResumeFromCheckpointArgs {
2626
DistributedDataLoader &train_loader;
2727
TrainerState &state;
2828
DataLoaderIterator &train_iter;
29-
CheckpointLoadOptions load_options;
3029
};
3130

3231
struct ResumeFromCheckpointResult {
@@ -40,7 +39,6 @@ struct SaveCheckpointArgs {
4039
size_t data_batch_idx = 0;
4140
double last_lr = 0.0;
4241
std::string optimizer_type;
43-
std::string checkpoint_file_format = "bin";
4442
int ddp_size = 1;
4543
int tp_size = 1;
4644
int sp_size = 1;
@@ -52,7 +50,6 @@ struct SaveCheckpointArgs {
5250
const nn::parallel::Rank &rank;
5351
const nn::Module &model;
5452
const Optimizer &optimizer;
55-
std::function<void(const nn::Module &, const std::filesystem::path &)> model_bin_writer;
5653
};
5754

5855
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);

example/gpt2/main.cc

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,6 @@ DEFINE_string(load, "", "checkpoint directory to resume from");
8888
DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints");
8989
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
9090
// When set, checkpoints contain only model weights and trainer state (no optimizer state).
DEFINE_bool(no_save_optim, false, "if true, optimizer state is NOT persisted in checkpoints");
91-
DEFINE_string(checkpoint_file_format, "ckpt",
92-
"checkpoint format: bin|ckpt. "
93-
"'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
94-
"'ckpt' generates model.ckpt/optimizer.ckpt (native StateDict binary).");
9591
// precision check
9692
DEFINE_string(
9793
precision_check, "",
@@ -344,22 +340,13 @@ void Train(const nn::parallel::Rank &rank) {
344340

345341
int start_step = 0;
346342
TrainerState state;
347-
CheckpointLoadOptions load_options;
348-
load_options.load_optimizer_state = true;
349-
load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
350-
auto loaded_model = gpt2::LoadFromLLMC(model_path.string());
351-
target_model->LoadStateDict(loaded_model->StateDict());
352-
};
353-
const auto resume_result = ResumeFromCheckpoint({
354-
.resume_root = FLAGS_load,
355-
.rank = rank,
356-
.model = model,
357-
.optimizer = optimizer,
358-
.train_loader = train_loader,
359-
.state = state,
360-
.train_iter = train_iter,
361-
.load_options = load_options,
362-
});
343+
const auto resume_result = ResumeFromCheckpoint({.resume_root = FLAGS_load,
344+
.rank = rank,
345+
.model = model,
346+
.optimizer = optimizer,
347+
.train_loader = train_loader,
348+
.state = state,
349+
.train_iter = train_iter});
363350
start_step = resume_result.global_step;
364351
saved_data_batch_idx = resume_result.data_batch_idx;
365352

@@ -371,7 +358,6 @@ void Train(const nn::parallel::Rank &rank) {
371358
.data_batch_idx = saved_data_batch_idx,
372359
.last_lr = FLAGS_learning_rate,
373360
.optimizer_type = "SGD",
374-
.checkpoint_file_format = FLAGS_checkpoint_file_format,
375361
.ddp_size = ddp_world_size,
376362
.tp_size = tp_world_size,
377363
.sp_size = sp_world_size,
@@ -383,9 +369,6 @@ void Train(const nn::parallel::Rank &rank) {
383369
.rank = rank,
384370
.model = *model,
385371
.optimizer = *optimizer,
386-
.model_bin_writer
387-
= [&](const nn::Module &,
388-
const std::filesystem::path &model_path) { gpt2::SaveAsLLMC(llmc_model, model_path.string()); },
389372
});
390373
};
391374

example/llama3/main.cc

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,7 @@ DEFINE_string(load, "", "checkpoint directory to resume from");
8686
DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints");
8787
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
8888
// When set, checkpoints contain only model weights and trainer state (no optimizer state).
DEFINE_bool(no_save_optim, true, "if true, optimizer state is NOT persisted in checkpoints");
89-
DEFINE_string(checkpoint_file_format, "ckpt",
90-
"checkpoint format: bin|ckpt. "
91-
"'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
92-
"'ckpt' generates model.ckpt/optimizer.ckpt (native StateDict binary).");
89+
9390
// precision check
9491
DEFINE_string(
9592
precision_check, "",
@@ -322,12 +319,6 @@ void Train(const nn::parallel::Rank &rank) {
322319

323320
int start_step = 0;
324321
TrainerState state;
325-
CheckpointLoadOptions load_options;
326-
load_options.load_optimizer_state = true;
327-
load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
328-
auto loaded_model = llama3::LoadFromLLMC(model_path.string());
329-
target_model->LoadStateDict(loaded_model->StateDict());
330-
};
331322
const auto resume_result = ResumeFromCheckpoint({
332323
.resume_root = FLAGS_load,
333324
.rank = rank,
@@ -336,36 +327,31 @@ void Train(const nn::parallel::Rank &rank) {
336327
.train_loader = train_loader,
337328
.state = state,
338329
.train_iter = train_iter,
339-
.load_options = load_options,
340330
});
341331
start_step = resume_result.global_step;
342332
saved_data_batch_idx = resume_result.data_batch_idx;
343333

344-
auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step,
345-
bool prune_step_checkpoints) {
346-
SaveCheckpoint({
347-
.save_dir = save_dir,
348-
.global_step = global_step,
349-
.data_batch_idx = saved_data_batch_idx,
350-
.last_lr = FLAGS_learning_rate,
351-
.optimizer_type = "Adam",
352-
.checkpoint_file_format = FLAGS_checkpoint_file_format,
353-
.ddp_size = ddp_world_size,
354-
.tp_size = tp_world_size,
355-
.sp_size = sp_world_size,
356-
.pp_size = pp_world_size,
357-
.no_save_optim = FLAGS_no_save_optim,
358-
.prune_step_checkpoints = prune_step_checkpoints,
359-
.checkpoint_root_dir = FLAGS_save,
360-
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
361-
.rank = rank,
362-
.model = *model,
363-
.optimizer = *optimizer,
364-
.model_bin_writer
365-
= [&](const nn::Module &,
366-
const std::filesystem::path &model_path) { llama3::SaveAsLLMC(llmc_model, model_path.string()); },
367-
});
368-
};
334+
auto save_checkpoint
335+
= [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
336+
SaveCheckpoint({
337+
.save_dir = save_dir,
338+
.global_step = global_step,
339+
.data_batch_idx = saved_data_batch_idx,
340+
.last_lr = FLAGS_learning_rate,
341+
.optimizer_type = "Adam",
342+
.ddp_size = ddp_world_size,
343+
.tp_size = tp_world_size,
344+
.sp_size = sp_world_size,
345+
.pp_size = pp_world_size,
346+
.no_save_optim = FLAGS_no_save_optim,
347+
.prune_step_checkpoints = prune_step_checkpoints,
348+
.checkpoint_root_dir = FLAGS_save,
349+
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
350+
.rank = rank,
351+
.model = *model,
352+
.optimizer = *optimizer,
353+
});
354+
};
369355

370356
for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {
371357
// Reset precision check counters at start of each iteration for file overwrite

infini_train/include/checkpoint.h

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,24 @@ struct TrainerState {
1818
int64_t global_step = 0;
1919
int64_t data_batch_idx = 0;
2020
int64_t data_batch_stride = 1;
21+
// FIXME(zbl): learning_rate should be restored from scheduler state, move `last_lr` from TrainerState to
22+
// SchedulerState later
2123
double last_lr = 0.0;
2224
std::string optimizer_type = "unknown";
23-
std::string checkpoint_file_format = "bin";
2425

2526
int ddp_size = 1;
2627
int tp_size = 1;
2728
int sp_size = 1;
2829
int pp_size = 1;
2930
};
3031

31-
struct CheckpointOptions {
32-
std::string format = "bin";
33-
bool no_save_optim = false;
34-
std::function<void(const nn::Module &, const std::filesystem::path &)> model_bin_writer;
35-
};
36-
37-
struct CheckpointLoadOptions {
38-
bool load_optimizer_state = true;
39-
std::function<void(nn::Module *, const std::filesystem::path &)> model_bin_loader;
40-
};
41-
4232
class Checkpoint {
4333
public:
44-
static void Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer &optimizer,
45-
const TrainerState &state, const CheckpointOptions &options = {});
34+
static void Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer *optimizer,
35+
const TrainerState &state, bool no_save_optim = false);
4636

47-
static void Load(const std::filesystem::path &checkpoint_dir, nn::Module *model, Optimizer *optimizer,
48-
TrainerState *state, const CheckpointLoadOptions &options = {});
37+
static void Load(const std::filesystem::path &checkpoint_dir, nn::Module &model, Optimizer *optimizer,
38+
TrainerState &state, bool load_optimizer_state = true);
4939

5040
private:
5141
static void SaveStateDictBinary(const std::filesystem::path &path,
@@ -56,7 +46,6 @@ class Checkpoint {
5646

5747
static void SaveTrainerState(const std::filesystem::path &path, const TrainerState &state);
5848
static TrainerState LoadTrainerState(const std::filesystem::path &path);
59-
static std::string InferFormat(const std::filesystem::path &checkpoint_dir);
6049
};
6150

6251
} // namespace infini_train

infini_train/src/checkpoint.cc

Lines changed: 24 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -84,24 +84,20 @@ template <typename T> T ExtractNumberField(const std::string &content, const std
8484
}
8585
} // namespace
8686

87-
void Checkpoint::Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer &optimizer,
88-
const TrainerState &state, const CheckpointOptions &options) {
89-
CHECK(options.format == "bin" || options.format == "ckpt") << "Unsupported checkpoint format: " << options.format;
87+
void Checkpoint::Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer *optimizer,
88+
const TrainerState &state, bool no_save_optim) {
9089
std::filesystem::create_directories(checkpoint_dir);
91-
LOG(ERROR) << "[CKPT] Save begin: dir=" << checkpoint_dir << ", format=" << options.format
92-
<< ", global_step=" << state.global_step;
90+
LOG(ERROR) << "[CKPT] Save begin: dir=" << checkpoint_dir << ", global_step=" << state.global_step;
9391

94-
const auto model_path = checkpoint_dir / (options.format == "ckpt" ? "model.ckpt" : "model.bin");
95-
if (options.format == "bin" && options.model_bin_writer) {
96-
options.model_bin_writer(model, model_path);
97-
} else {
98-
SaveStateDictBinary(model_path, model.StateDict());
99-
}
92+
const auto model_path = checkpoint_dir / ("model.ckpt");
10093

101-
if (options.no_save_optim) {
102-
auto opt_state = optimizer.StateDict();
94+
SaveStateDictBinary(model_path, model.StateDict());
95+
96+
if (!no_save_optim) {
97+
CHECK(optimizer != nullptr) << "Optimizer pointer is null, cannot save optimizer state.";
98+
auto opt_state = optimizer->StateDict();
10399
if (!opt_state.empty()) {
104-
const auto opt_path = checkpoint_dir / (options.format == "ckpt" ? "optimizer.ckpt" : "optimizer.bin");
100+
const auto opt_path = checkpoint_dir / "optimizer.ckpt";
105101
SaveStateDictBinary(opt_path, opt_state);
106102
}
107103
}
@@ -110,48 +106,32 @@ void Checkpoint::Save(const std::filesystem::path &checkpoint_dir, const nn::Mod
110106
LOG(ERROR) << "[CKPT] Save done: dir=" << checkpoint_dir;
111107
}
112108

113-
void Checkpoint::Load(const std::filesystem::path &checkpoint_dir, nn::Module *model, Optimizer *optimizer,
114-
TrainerState *state, const CheckpointLoadOptions &options) {
115-
CHECK(model != nullptr);
116-
CHECK(state != nullptr);
117-
118-
const std::string format = InferFormat(checkpoint_dir);
119-
const auto model_path = checkpoint_dir / (format == "ckpt" ? "model.ckpt" : "model.bin");
120-
LOG(ERROR) << "[CKPT] Load begin: dir=" << checkpoint_dir << ", format=" << format;
109+
void Checkpoint::Load(const std::filesystem::path &checkpoint_dir, nn::Module &model, Optimizer *optimizer,
110+
TrainerState &state, bool load_optimizer_state) {
111+
const auto model_path = checkpoint_dir / "model.ckpt";
121112
LOG(ERROR) << "[CKPT] Loading model: " << model_path;
122-
if (format == "bin" && options.model_bin_loader) {
123-
const uint32_t magic = PeekMagic(model_path);
124-
if (magic == kCkptMagic) {
125-
LOG(ERROR) << "[CKPT] Model format detected: native checkpoint binary.";
126-
model->LoadStateDict(LoadStateDictBinary(model_path));
127-
} else {
128-
LOG(ERROR) << "[CKPT] Model format detected: external model.bin (magic=" << magic
129-
<< "), use model_bin_loader callback.";
130-
options.model_bin_loader(model, model_path);
131-
}
132-
} else {
133-
model->LoadStateDict(LoadStateDictBinary(model_path));
134-
}
135113

136-
if (optimizer != nullptr && options.load_optimizer_state) {
137-
const auto opt_path = checkpoint_dir / (format == "ckpt" ? "optimizer.ckpt" : "optimizer.bin");
114+
model.LoadStateDict(LoadStateDictBinary(model_path));
115+
116+
if (optimizer == nullptr) {
117+
LOG(ERROR) << "[CKPT] No optimizer instance, skip optimizer state loading.";
118+
} else if (load_optimizer_state) {
119+
const auto opt_path = checkpoint_dir / "optimizer.ckpt";
138120
if (std::filesystem::exists(opt_path)) {
139121
LOG(ERROR) << "[CKPT] Loading optimizer: " << opt_path;
140122
optimizer->LoadStateDict(LoadStateDictBinary(opt_path));
141123
} else {
142124
LOG(ERROR) << "[CKPT] Optimizer state not found, skip: " << opt_path;
143125
}
144-
} else if (optimizer == nullptr) {
145-
LOG(ERROR) << "[CKPT] No optimizer instance, skip optimizer state loading.";
146126
} else {
147127
LOG(ERROR) << "[CKPT] load_optimizer_state=false, skip optimizer state loading.";
148128
}
149129

150-
*state = LoadTrainerState(checkpoint_dir / "trainer_state.json");
151-
LOG(ERROR) << "[CKPT] Load done: global_step=" << state->global_step << ", data_batch_idx=" << state->data_batch_idx
152-
<< ", data_batch_stride=" << state->data_batch_stride << ", last_lr=" << state->last_lr
153-
<< ", optimizer_type=" << state->optimizer_type << ", topology(ddp,tp,sp,pp)=(" << state->ddp_size << ","
154-
<< state->tp_size << "," << state->sp_size << "," << state->pp_size << ")";
130+
state = LoadTrainerState(checkpoint_dir / "trainer_state.json");
131+
LOG(ERROR) << "[CKPT] Load done: global_step=" << state.global_step << ", data_batch_idx=" << state.data_batch_idx
132+
<< ", data_batch_stride=" << state.data_batch_stride << ", last_lr=" << state.last_lr
133+
<< ", optimizer_type=" << state.optimizer_type << ", topology(ddp,tp,sp,pp)=(" << state.ddp_size << ","
134+
<< state.tp_size << "," << state.sp_size << "," << state.pp_size << ")";
155135
}
156136

157137
void Checkpoint::SaveStateDictBinary(const std::filesystem::path &path,
@@ -233,7 +213,6 @@ void Checkpoint::SaveTrainerState(const std::filesystem::path &path, const Train
233213
ofs << " \"data_batch_stride\": " << state.data_batch_stride << ",\n";
234214
ofs << " \"last_lr\": " << state.last_lr << ",\n";
235215
ofs << " \"optimizer_type\": \"" << state.optimizer_type << "\",\n";
236-
ofs << " \"checkpoint_file_format\": \"" << state.checkpoint_file_format << "\",\n";
237216
ofs << " \"ddp_size\": " << state.ddp_size << ",\n";
238217
ofs << " \"tp_size\": " << state.tp_size << ",\n";
239218
ofs << " \"sp_size\": " << state.sp_size << ",\n";
@@ -252,23 +231,10 @@ TrainerState Checkpoint::LoadTrainerState(const std::filesystem::path &path) {
252231
state.data_batch_stride = ExtractNumberField<int64_t>(content, "data_batch_stride", 1);
253232
state.last_lr = ExtractNumberField<double>(content, "last_lr", 0.0);
254233
state.optimizer_type = ExtractStringField(content, "optimizer_type", "unknown");
255-
state.checkpoint_file_format = ExtractStringField(content, "checkpoint_file_format", "bin");
256234
state.ddp_size = ExtractNumberField<int>(content, "ddp_size", 1);
257235
state.tp_size = ExtractNumberField<int>(content, "tp_size", 1);
258236
state.sp_size = ExtractNumberField<int>(content, "sp_size", 1);
259237
state.pp_size = ExtractNumberField<int>(content, "pp_size", 1);
260238
return state;
261239
}
262-
263-
std::string Checkpoint::InferFormat(const std::filesystem::path &checkpoint_dir) {
264-
if (std::filesystem::exists(checkpoint_dir / "model.ckpt")) {
265-
return "ckpt";
266-
}
267-
if (std::filesystem::exists(checkpoint_dir / "model.bin")) {
268-
return "bin";
269-
}
270-
LOG(FATAL) << "Failed to infer checkpoint format from path: " << checkpoint_dir;
271-
return "bin";
272-
}
273-
274240
} // namespace infini_train

0 commit comments

Comments
 (0)