Skip to content

Commit de0167e

Browse files
JYMiracle305kilinchange
authored andcommitted
refactor(checkpoint): centralize config, simplify prune, fix optimizer test
1 parent 0bdef97 commit de0167e

17 files changed

Lines changed: 154 additions & 318 deletions

File tree

CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,6 @@ add_executable(gpt2
194194
example/gpt2/main.cc
195195
example/common/tiny_shakespeare_dataset.cc
196196
example/common/utils.cc
197-
example/common/checkpoint_loader.cc
198197
example/common/tokenizer.cc
199198
example/gpt2/checkpoint_loader.cc
200199
)
@@ -204,7 +203,6 @@ add_executable(llama3
204203
example/llama3/main.cc
205204
example/common/tiny_shakespeare_dataset.cc
206205
example/common/utils.cc
207-
example/common/checkpoint_loader.cc
208206
example/common/tokenizer.cc
209207
example/llama3/checkpoint_loader.cc
210208
)

example/gpt2/main.cc

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include "glog/logging.h"
1212

1313
#include "infini_train/include/autocast.h"
14-
#include "infini_train/include/checkpoint.h"
14+
#include "infini_train/include/checkpoint/checkpoint.h"
1515
#include "infini_train/include/core/runtime/device_guard.h"
1616
#include "infini_train/include/dataloader.h"
1717
#include "infini_train/include/device.h"
@@ -31,12 +31,12 @@
3131
#ifdef PROFILE_MODE
3232
#include "infini_train/include/profiler.h"
3333
#endif
34+
#include "infini_train/include/checkpoint/checkpoint_manager.h"
3435
#include "infini_train/include/nn/parallel/utils.h"
3536
#include "infini_train/include/utils/global_module_hook_registry.h"
3637
#include "infini_train/include/utils/precision_check_config.h"
3738
#include "infini_train/include/utils/precision_checker.h"
3839

39-
#include "example/common/checkpoint_loader.h"
4040
#include "example/common/tiny_shakespeare_dataset.h"
4141
#include "example/common/tokenizer.h"
4242
#include "example/gpt2/checkpoint_loader.h"
@@ -85,7 +85,7 @@ DEFINE_uint32(save_interval, 0, "save checkpoint every N steps; 0 disables savin
8585
DEFINE_string(load, "", "checkpoint directory to resume from");
8686
DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints");
8787
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
88-
DEFINE_bool(no_save_optim, false, "whether optimizer state is persisted in checkpoints");
88+
DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
8989
// precision check
9090
DEFINE_string(
9191
precision_check, "",
@@ -331,7 +331,8 @@ void Train(const nn::parallel::Rank &rank) {
331331
.model = model,
332332
.optimizer = optimizer,
333333
.model_config = model_config,
334-
.state = state});
334+
.state = state,
335+
.load_optimizer_state = false});
335336
start_step = resume_result.global_step;
336337
size_t consumed_batches = resume_result.consumed_batches;
337338

@@ -345,31 +346,29 @@ void Train(const nn::parallel::Rank &rank) {
345346
for (size_t i = 0; i < num_skips; ++i) { ++train_iter; }
346347
}
347348

348-
auto save_checkpoint
349-
= [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
350-
SaveCheckpoint({
351-
.save_dir = save_dir,
352-
.global_step = global_step,
353-
.consumed_batches = consumed_batches,
354-
.last_lr = FLAGS_learning_rate,
355-
.n_layer = model_config.n_layer,
356-
.n_head = model_config.n_head,
357-
.n_kv_head = model_config.n_kv_head,
358-
.n_embd = model_config.n_embd,
359-
.vocab_size = model_config.vocab_size,
360-
.ddp_size = ddp_world_size,
361-
.tp_size = tp_world_size,
362-
.sp_size = sp_world_size,
363-
.pp_size = pp_world_size,
364-
.no_save_optim = FLAGS_no_save_optim,
365-
.prune_step_checkpoints = prune_step_checkpoints,
366-
.checkpoint_root_dir = FLAGS_save,
367-
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
368-
.rank = rank,
369-
.model = *model,
370-
.optimizer = *optimizer,
371-
});
372-
};
349+
auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step) {
350+
SaveCheckpoint({
351+
.save_dir = save_dir,
352+
.global_step = global_step,
353+
.consumed_batches = consumed_batches,
354+
.last_lr = FLAGS_learning_rate,
355+
.n_layer = model_config.n_layer,
356+
.n_head = model_config.n_head,
357+
.n_kv_head = model_config.n_kv_head,
358+
.n_embd = model_config.n_embd,
359+
.vocab_size = model_config.vocab_size,
360+
.ddp_size = ddp_world_size,
361+
.tp_size = tp_world_size,
362+
.sp_size = sp_world_size,
363+
.pp_size = pp_world_size,
364+
.save_optimizer_state = FLAGS_save_optimizer_state,
365+
.checkpoint_root_dir = FLAGS_save,
366+
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
367+
.rank = rank,
368+
.model = *model,
369+
.optimizer = *optimizer,
370+
});
371+
};
373372

374373
LOG(INFO) << "start training";
375374

@@ -496,7 +495,7 @@ void Train(const nn::parallel::Rank &rank) {
496495
if (rank.IsParallel()) {
497496
step_dir /= std::format("rank_{:06d}", rank.GlobalRank());
498497
}
499-
save_checkpoint(step_dir, step + 1, true);
498+
save_checkpoint(step_dir, step + 1);
500499
}
501500
}
502501

@@ -510,7 +509,7 @@ void Train(const nn::parallel::Rank &rank) {
510509
if (rank.IsParallel()) {
511510
final_dir /= std::format("rank_{:06d}", rank.GlobalRank());
512511
}
513-
save_checkpoint(final_dir, FLAGS_num_iteration, false);
512+
save_checkpoint(final_dir, FLAGS_num_iteration);
514513

515514
#ifdef PROFILE_MODE
516515
Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage);

example/llama3/main.cc

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
#include "glog/logging.h"
1010

1111
#include "infini_train/include/autocast.h"
12-
#include "infini_train/include/checkpoint.h"
12+
#include "infini_train/include/checkpoint/checkpoint.h"
13+
#include "infini_train/include/checkpoint/checkpoint_manager.h"
1314
#include "infini_train/include/core/runtime/device_guard.h"
1415
#include "infini_train/include/dataloader.h"
1516
#include "infini_train/include/device.h"
@@ -35,7 +36,6 @@
3536
#include "infini_train/include/profiler.h"
3637
#endif
3738

38-
#include "example/common/checkpoint_loader.h"
3939
#include "example/common/tiny_shakespeare_dataset.h"
4040
#include "example/common/tokenizer.h"
4141
#include "example/llama3/checkpoint_loader.h"
@@ -83,7 +83,7 @@ DEFINE_uint32(save_interval, 0, "save checkpoint every N steps; 0 disables savin
8383
DEFINE_string(load, "", "checkpoint directory to resume from");
8484
DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints");
8585
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
86-
DEFINE_bool(no_save_optim, false, "whether optimizer state is persisted in checkpoints");
86+
DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
8787

8888
// precision check
8989
DEFINE_string(
@@ -305,14 +305,13 @@ void Train(const nn::parallel::Rank &rank) {
305305

306306
int start_step = 0;
307307
TrainerState state;
308-
const auto resume_result = ResumeFromCheckpoint({
309-
.resume_root = FLAGS_load,
310-
.rank = rank,
311-
.model = model,
312-
.optimizer = optimizer,
313-
.model_config = model_config,
314-
.state = state,
315-
});
308+
const auto resume_result = ResumeFromCheckpoint({.resume_root = FLAGS_load,
309+
.rank = rank,
310+
.model = model,
311+
.optimizer = optimizer,
312+
.model_config = model_config,
313+
.state = state,
314+
.load_optimizer_state = true});
316315

317316
start_step = resume_result.global_step;
318317
size_t consumed_batches = resume_result.consumed_batches;
@@ -327,31 +326,29 @@ void Train(const nn::parallel::Rank &rank) {
327326
for (size_t i = 0; i < num_skips; ++i) { ++train_iter; }
328327
}
329328

330-
auto save_checkpoint
331-
= [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
332-
SaveCheckpoint({
333-
.save_dir = save_dir,
334-
.global_step = global_step,
335-
.consumed_batches = consumed_batches,
336-
.last_lr = FLAGS_learning_rate,
337-
.n_layer = model_config.n_layer,
338-
.n_head = model_config.n_head,
339-
.n_kv_head = model_config.n_kv_head,
340-
.n_embd = model_config.n_embd,
341-
.vocab_size = model_config.vocab_size,
342-
.ddp_size = ddp_world_size,
343-
.tp_size = tp_world_size,
344-
.sp_size = sp_world_size,
345-
.pp_size = pp_world_size,
346-
.no_save_optim = FLAGS_no_save_optim,
347-
.prune_step_checkpoints = prune_step_checkpoints,
348-
.checkpoint_root_dir = FLAGS_save,
349-
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
350-
.rank = rank,
351-
.model = *model,
352-
.optimizer = *optimizer,
353-
});
354-
};
329+
auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step) {
330+
SaveCheckpoint({
331+
.save_dir = save_dir,
332+
.global_step = global_step,
333+
.consumed_batches = consumed_batches,
334+
.last_lr = FLAGS_learning_rate,
335+
.n_layer = model_config.n_layer,
336+
.n_head = model_config.n_head,
337+
.n_kv_head = model_config.n_kv_head,
338+
.n_embd = model_config.n_embd,
339+
.vocab_size = model_config.vocab_size,
340+
.ddp_size = ddp_world_size,
341+
.tp_size = tp_world_size,
342+
.sp_size = sp_world_size,
343+
.pp_size = pp_world_size,
344+
.save_optimizer_state = FLAGS_save_optimizer_state,
345+
.checkpoint_root_dir = FLAGS_save,
346+
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
347+
.rank = rank,
348+
.model = *model,
349+
.optimizer = *optimizer,
350+
});
351+
};
355352

356353
for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {
357354
// Reset precision check counters at start of each iteration for file overwrite
@@ -475,7 +472,7 @@ void Train(const nn::parallel::Rank &rank) {
475472
if (rank.IsParallel()) {
476473
step_dir /= std::format("rank_{:06d}", rank.GlobalRank());
477474
}
478-
save_checkpoint(step_dir, step + 1, true);
475+
save_checkpoint(step_dir, step + 1);
479476
}
480477
}
481478

@@ -489,7 +486,7 @@ void Train(const nn::parallel::Rank &rank) {
489486
if (rank.IsParallel()) {
490487
final_dir /= std::format("rank_{:06d}", rank.GlobalRank());
491488
}
492-
save_checkpoint(final_dir, FLAGS_num_iteration, false);
489+
save_checkpoint(final_dir, FLAGS_num_iteration);
493490

494491
#ifdef PROFILE_MODE
495492
Profiler::Instance().Report("llama3.report", Profiler::SortBy::DeviceTimePercentage);

infini_train/include/checkpoint.h renamed to infini_train/include/checkpoint/checkpoint.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,16 @@ struct TrainerState {
3434
class Checkpoint {
3535
public:
3636
static void Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer *optimizer,
37-
const TrainerState &state, bool no_save_optim);
37+
const TrainerState &state, bool save_optimizer_state);
3838

3939
static void Load(const std::filesystem::path &checkpoint_dir, nn::Module &model, Optimizer *optimizer,
4040
TrainerState &state, bool load_optimizer_state);
4141

4242
private:
43-
static void SaveStateDictBinary(const std::filesystem::path &path,
44-
const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict);
43+
static void SaveStateDict(const std::filesystem::path &path,
44+
const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict);
4545

46-
static std::unordered_map<std::string, std::shared_ptr<Tensor>>
47-
LoadStateDictBinary(const std::filesystem::path &path);
46+
static std::unordered_map<std::string, std::shared_ptr<Tensor>> LoadStateDict(const std::filesystem::path &path);
4847

4948
static void SaveTrainerState(const std::filesystem::path &path, const TrainerState &state);
5049
static TrainerState LoadTrainerState(const std::filesystem::path &path);

example/common/checkpoint_loader.h renamed to infini_train/include/checkpoint/checkpoint_manager.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <cstring>
55
#include <filesystem>
66

7-
#include "infini_train/include/checkpoint.h"
7+
#include "infini_train/include/checkpoint/checkpoint.h"
88
#include "infini_train/include/dataloader.h"
99
#include "infini_train/include/nn/modules/module.h"
1010
#include "infini_train/include/nn/parallel/rank.h"
@@ -24,6 +24,7 @@ struct ResumeFromCheckpointArgs {
2424
std::shared_ptr<Optimizer> optimizer;
2525
const nn::TransformerConfig &model_config;
2626
TrainerState &state;
27+
bool load_optimizer_state;
2728
};
2829

2930
struct ResumeFromCheckpointResult {
@@ -45,8 +46,7 @@ struct SaveCheckpointArgs {
4546
int tp_size = 1;
4647
int sp_size = 1;
4748
int pp_size = 1;
48-
bool no_save_optim = false;
49-
bool prune_step_checkpoints = false;
49+
bool save_optimizer_state = true;
5050
std::filesystem::path checkpoint_root_dir;
5151
size_t max_checkpoint_keep = 0;
5252
const nn::parallel::Rank &rank;

infini_train/include/nn/modules/module.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@ class Module : public std::enable_shared_from_this<Module> {
4747

4848
const std::string &type() const;
4949

50+
// TODO: Change return type to filterable iterator (like PyTorch's named_parameters with prefix matching)
5051
virtual std::vector<std::shared_ptr<Tensor>> Parameters() const;
51-
std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> NamedParameters(const std::string &prefix = "",
52-
bool remove_duplicate = true) const;
5352
bool has_parameter(const std::string &name) const;
5453
std::shared_ptr<Tensor> *mutable_parameter(const std::string &name);
5554
const std::shared_ptr<Tensor> &parameter(const std::string &name) const;

infini_train/include/optimizer.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,10 @@ namespace infini_train {
1414
class Optimizer;
1515

1616
using OptimizerCreator = std::function<std::shared_ptr<Optimizer>(const std::vector<std::shared_ptr<Tensor>> &params)>;
17-
using OptimizerCreatorNamed = std::function<std::shared_ptr<Optimizer>(
18-
const std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> &named_params)>;
1917

2018
class Optimizer {
2119
public:
2220
explicit Optimizer(const std::vector<std::shared_ptr<Tensor>> &params);
23-
Optimizer(const std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> &named_params);
2421

2522
virtual void ZeroGrad(bool set_to_none = true);
2623

@@ -32,19 +29,16 @@ class Optimizer {
3229

3330
protected:
3431
std::vector<std::shared_ptr<Tensor>> params_;
35-
std::vector<std::string> param_names_;
3632
};
3733

3834
namespace optimizers {
3935
class SGD : public Optimizer {
4036
public:
4137
SGD(const std::vector<std::shared_ptr<Tensor>> &params, float learning_rate);
42-
SGD(const std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> &named_params, float learning_rate);
4338

4439
void Step() override;
4540

4641
static OptimizerCreator Create(float learning_rate);
47-
static OptimizerCreatorNamed CreateNamed(float learning_rate);
4842

4943
private:
5044
const float learning_rate_ = 0.0;
@@ -54,8 +48,6 @@ class Adam : public Optimizer {
5448
public:
5549
Adam(const std::vector<std::shared_ptr<Tensor>> &params, float learning_rate = 1e-3, float beta1 = 0.9,
5650
float beta2 = 0.999, float eps = 1e-8);
57-
Adam(const std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> &named_params, float learning_rate = 1e-3,
58-
float beta1 = 0.9, float beta2 = 0.999, float eps = 1e-8);
5951

6052
void Step() override;
6153

@@ -64,8 +56,6 @@ class Adam : public Optimizer {
6456
void LoadStateDict(const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict) override;
6557
static OptimizerCreator Create(float learning_rate = 1e-3, float beta1 = 0.9, float beta2 = 0.999,
6658
float eps = 1e-8);
67-
static OptimizerCreatorNamed CreateNamed(float learning_rate = 1e-3, float beta1 = 0.9, float beta2 = 0.999,
68-
float eps = 1e-8);
6959

7060
private:
7161
int64_t t_;

0 commit comments

Comments
 (0)