Skip to content

Commit 355d1ef

Browse files
fix: adapt to megatron-style arguments
1 parent 327d263 commit 355d1ef

14 files changed

Lines changed: 350 additions & 554 deletions

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,5 +225,8 @@ link_infini_train_exe(test_sequential_lr)
225225
add_executable(test_chained_lr test/lr_scheduler/test_chained_lr.cc)
226226
link_infini_train_exe(test_chained_lr)
227227

228+
add_executable(test_training_lr_scheduler test/lr_scheduler/test_training_lr_scheduler.cc)
229+
link_infini_train_exe(test_training_lr_scheduler)
230+
228231
add_executable(test_lr_scheduler_validation test/lr_scheduler/test_lr_scheduler_validation.cc)
229232
link_infini_train_exe(test_lr_scheduler_validation)

example/gpt2/main.cc

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
#include "infini_train/include/core/runtime/device_guard.h"
1414
#include "infini_train/include/dataloader.h"
1515
#include "infini_train/include/device.h"
16-
#include "infini_train/include/nn/lora/lora_utils.h"
1716
#include "infini_train/include/lr_scheduler.h"
17+
#include "infini_train/include/nn/lora/lora_utils.h"
1818
#include "infini_train/include/nn/modules/loss.h"
1919
#include "infini_train/include/nn/modules/module.h"
2020
#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
@@ -55,18 +55,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run");
5555
DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
5656
DEFINE_uint32(text_length, 64, "the length of the generated text");
5757
// optimization
58-
DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations");
58+
DEFINE_double(learning_rate, 1e-4, "Peak learning rate.");
5959
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
6060
// lr scheduler
61-
DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear");
62-
DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)");
63-
DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)");
64-
DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)");
65-
DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay");
66-
DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay");
67-
DEFINE_double(start_factor, 0.333333, "LinearLR: starting multiplicative factor");
68-
DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor");
69-
DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler");
61+
DEFINE_double(min_lr, 0.0, "Minimum learning rate.");
62+
DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root");
63+
DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations.");
64+
DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup.");
65+
DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration).");
7066
// evaluation
7167
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
7268
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -109,6 +105,8 @@ constexpr char kDeviceCPU[] = "cpu";
109105
constexpr char kDeviceCUDA[] = "cuda";
110106
constexpr char kDtypeFP32[] = "float32";
111107
constexpr char kDtypeBF16[] = "bfloat16";
108+
const std::unordered_set<std::string> kSupportedLRDecayStyles
109+
= {"none", "constant", "linear", "cosine", "inverse-square-root"};
112110

113111
//
114112
const std::unordered_map<std::string, GPT2Config> kModelToConfigs = {
@@ -129,6 +127,8 @@ const std::unordered_map<std::string, GPT2::ModelType> kStrToModelType = {
129127
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
130128
DEFINE_validator(device,
131129
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
130+
DEFINE_validator(lr_decay_style,
131+
[](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); });
132132

133133
void Train(const nn::parallel::Rank &rank) {
134134
using namespace nn::parallel;
@@ -321,18 +321,14 @@ void Train(const nn::parallel::Rank &rank) {
321321
optimizer = optimizer_creator(params_to_optimize);
322322
}
323323

324-
LRSchedulerConfig sched_config;
325-
sched_config.type = FLAGS_lr_scheduler;
326-
sched_config.warmup_steps = FLAGS_warmup_steps;
327-
sched_config.warmup_start_factor = static_cast<float>(FLAGS_warmup_start_factor);
328-
sched_config.warmup_end_factor = static_cast<float>(FLAGS_warmup_end_factor);
329-
sched_config.step_size = FLAGS_step_size;
330-
sched_config.step_gamma = static_cast<float>(FLAGS_gamma);
331-
sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor);
332-
sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor);
333-
sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // 复用
334-
sched_config.constant_total_iters = FLAGS_lr_total_iters;
335-
sched_config.linear_total_iters = FLAGS_lr_total_iters;
324+
const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration;
325+
TrainingLRSchedulerConfig sched_config;
326+
sched_config.lr = static_cast<float>(FLAGS_learning_rate);
327+
sched_config.min_lr = static_cast<float>(FLAGS_min_lr);
328+
sched_config.lr_decay_style = FLAGS_lr_decay_style;
329+
sched_config.lr_decay_iters = lr_decay_iters;
330+
sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters;
331+
sched_config.lr_warmup_init = static_cast<float>(FLAGS_lr_warmup_init);
336332
auto scheduler = CreateLRScheduler(optimizer, sched_config);
337333

338334
auto train_iter = train_loader.begin();

example/llama3/main.cc

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
#include "infini_train/include/core/runtime/device_guard.h"
1212
#include "infini_train/include/dataloader.h"
1313
#include "infini_train/include/device.h"
14-
#include "infini_train/include/nn/lora/lora_utils.h"
1514
#include "infini_train/include/lr_scheduler.h"
15+
#include "infini_train/include/nn/lora/lora_utils.h"
1616
#include "infini_train/include/nn/modules/loss.h"
1717
#include "infini_train/include/nn/modules/module.h"
1818
#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
@@ -54,18 +54,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run");
5454
DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
5555
DEFINE_uint32(text_length, 64, "the length of the generated text");
5656
// optimization
57-
DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
57+
DEFINE_double(learning_rate, 1e-5, "Peak learning rate.");
5858
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
5959
// lr scheduler
60-
DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear");
61-
DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)");
62-
DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)");
63-
DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)");
64-
DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay");
65-
DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay");
66-
DEFINE_double(start_factor, 0.333333, "LinearLR: starting multiplicative factor");
67-
DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor");
68-
DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler");
60+
DEFINE_double(min_lr, 0.0, "Minimum learning rate.");
61+
DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root");
62+
DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations.");
63+
DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup.");
64+
DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration).");
6965
// evaluation
7066
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
7167
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -104,11 +100,15 @@ constexpr char kDeviceCPU[] = "cpu";
104100
constexpr char kDeviceCUDA[] = "cuda";
105101
constexpr char kDtypeFP32[] = "float32";
106102
constexpr char kDtypeBF16[] = "bfloat16";
103+
const std::unordered_set<std::string> kSupportedLRDecayStyles
104+
= {"none", "constant", "linear", "cosine", "inverse-square-root"};
107105
} // namespace
108106

109107
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
110108
DEFINE_validator(device,
111109
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
110+
DEFINE_validator(lr_decay_style,
111+
[](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); });
112112

113113
void Train(const nn::parallel::Rank &rank) {
114114
using namespace nn::parallel;
@@ -293,18 +293,14 @@ void Train(const nn::parallel::Rank &rank) {
293293
optimizer = optimizer_creator(params_to_optimize);
294294
}
295295

296-
LRSchedulerConfig sched_config;
297-
sched_config.type = FLAGS_lr_scheduler;
298-
sched_config.warmup_steps = FLAGS_warmup_steps;
299-
sched_config.warmup_start_factor = static_cast<float>(FLAGS_warmup_start_factor);
300-
sched_config.warmup_end_factor = static_cast<float>(FLAGS_warmup_end_factor);
301-
sched_config.step_size = FLAGS_step_size;
302-
sched_config.step_gamma = static_cast<float>(FLAGS_gamma);
303-
sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor);
304-
sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor);
305-
sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // 复用
306-
sched_config.constant_total_iters = FLAGS_lr_total_iters;
307-
sched_config.linear_total_iters = FLAGS_lr_total_iters;
296+
const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration;
297+
TrainingLRSchedulerConfig sched_config;
298+
sched_config.lr = static_cast<float>(FLAGS_learning_rate);
299+
sched_config.min_lr = static_cast<float>(FLAGS_min_lr);
300+
sched_config.lr_decay_style = FLAGS_lr_decay_style;
301+
sched_config.lr_decay_iters = lr_decay_iters;
302+
sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters;
303+
sched_config.lr_warmup_init = static_cast<float>(FLAGS_lr_warmup_init);
308304
auto scheduler = CreateLRScheduler(optimizer, sched_config);
309305

310306
auto train_iter = train_loader.begin();

infini_train/include/lr_scheduler.h

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,13 @@ class Optimizer;
1616
using StateValue = std::variant<int64_t, float, double, std::string, std::vector<float>>;
1717
using StateDict = std::unordered_map<std::string, StateValue>;
1818

19-
struct LRSchedulerConfig {
20-
std::string type = "none";
21-
// ConstantLR
22-
float constant_factor = 1.0f / 3.0f;
23-
int constant_total_iters = 5;
24-
// StepLR
25-
int64_t step_size = 10;
26-
float step_gamma = 0.1f;
27-
// LinearLR
28-
float linear_start_factor = 1.0f / 3.0f;
29-
float linear_end_factor = 1.0f;
30-
int linear_total_iters = 5;
31-
// LambdaLR
32-
std::function<float(int64_t)> lambda_fn = nullptr;
33-
// SequentialLR
34-
std::vector<LRSchedulerConfig> sequential_configs;
35-
std::vector<int64_t> sequential_milestones;
36-
// ChainedScheduler
37-
std::vector<LRSchedulerConfig> chained_configs;
38-
// warmup
39-
int64_t warmup_steps = 0;
40-
float warmup_start_factor = 1.0f / 3.0f;
41-
float warmup_end_factor = 1.0f;
19+
struct TrainingLRSchedulerConfig {
20+
std::string lr_decay_style = "constant";
21+
float lr = 0.0f;
22+
float min_lr = 0.0f;
23+
int64_t lr_decay_iters = 1;
24+
int64_t lr_warmup_iters = 0;
25+
float lr_warmup_init = 0.0f;
4226
};
4327

4428
class LRScheduler {
@@ -81,7 +65,8 @@ class LRScheduler {
8165
bool is_initial_ = false;
8266
};
8367

84-
std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer, const LRSchedulerConfig &config);
68+
std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer,
69+
const TrainingLRSchedulerConfig &config);
8570

8671
namespace lr_schedulers {
8772

infini_train/src/lr_scheduler.cc

Lines changed: 68 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,87 @@
11
#include "infini_train/include/lr_scheduler.h"
22

3+
#include <algorithm>
4+
#include <cmath>
5+
#include <numbers>
6+
#include <utility>
7+
38
#include "glog/logging.h"
49

510
#include "infini_train/include/optimizer.h"
611

712
namespace infini_train {
813

9-
std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer, const LRSchedulerConfig &config) {
10-
if (config.type == "none") {
14+
std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer,
15+
const TrainingLRSchedulerConfig &config) {
16+
if (config.lr_decay_style == "none") {
1117
return nullptr;
1218
}
1319

14-
auto create_main = [&](std::shared_ptr<Optimizer> opt) -> std::shared_ptr<LRScheduler> {
15-
if (config.type == "constant") {
16-
return LRScheduler::Create<lr_schedulers::ConstantLR>(opt, config.constant_factor,
17-
config.constant_total_iters);
18-
}
19-
if (config.type == "step") {
20-
return LRScheduler::Create<lr_schedulers::StepLR>(opt, config.step_size, config.step_gamma);
21-
}
22-
if (config.type == "linear") {
23-
return LRScheduler::Create<lr_schedulers::LinearLR>(opt, config.linear_start_factor,
24-
config.linear_end_factor, config.linear_total_iters);
25-
}
26-
if (config.type == "lambda") {
27-
return LRScheduler::Create<lr_schedulers::LambdaLR>(opt, config.lambda_fn);
28-
}
29-
if (config.type == "sequential") {
30-
std::vector<std::shared_ptr<LRScheduler>> schedulers;
31-
std::vector<int64_t> milestones = config.sequential_milestones;
32-
for (const auto &sub_config : config.sequential_configs) {
33-
auto sub_sched = CreateLRScheduler(opt, sub_config);
34-
if (sub_sched) {
35-
schedulers.push_back(sub_sched);
20+
CHECK(optimizer) << "CreateLRScheduler: optimizer must not be null.";
21+
const float max_lr = config.lr != 0.0f ? config.lr : optimizer->GetLearningRate();
22+
CHECK_GT(max_lr, 0.0f) << "CreateLRScheduler: max_lr must be > 0.";
23+
CHECK_GE(config.lr_warmup_init, 0.0f) << "CreateLRScheduler: lr_warmup_init must be >= 0.";
24+
CHECK_GE(config.min_lr, 0.0f) << "CreateLRScheduler: min_lr must be >= 0.";
25+
CHECK_GE(max_lr, config.min_lr) << "CreateLRScheduler: max_lr must be >= min_lr.";
26+
CHECK_LE(config.lr_warmup_init, max_lr) << "CreateLRScheduler: lr_warmup_init must be <= max_lr.";
27+
CHECK_GE(config.lr_warmup_iters, 0) << "CreateLRScheduler: lr_warmup_iters must be >= 0.";
28+
CHECK_GT(config.lr_decay_iters, 0) << "CreateLRScheduler: lr_decay_iters must be > 0.";
29+
CHECK_LT(config.lr_warmup_iters, config.lr_decay_iters)
30+
<< "CreateLRScheduler: lr_warmup_iters must be < lr_decay_iters.";
31+
CHECK(config.lr_decay_style == "constant" || config.lr_decay_style == "linear" || config.lr_decay_style == "cosine"
32+
|| config.lr_decay_style == "inverse-square-root")
33+
<< "CreateLRScheduler: unsupported lr_decay_style: " << config.lr_decay_style;
34+
35+
std::shared_ptr<LRScheduler> main_scheduler;
36+
const int64_t decay_iters_after_warmup = config.lr_decay_iters - config.lr_warmup_iters;
37+
if (config.lr_decay_style == "constant") {
38+
main_scheduler = LRScheduler::Create<lr_schedulers::LambdaLR>(optimizer, [](int64_t) { return 1.0f; });
39+
} else if (config.lr_decay_style == "linear") {
40+
main_scheduler = LRScheduler::Create<lr_schedulers::LinearLR>(optimizer, 1.0f, config.min_lr / max_lr,
41+
decay_iters_after_warmup);
42+
} else if (config.lr_decay_style == "cosine") {
43+
main_scheduler = LRScheduler::Create<lr_schedulers::LambdaLR>(
44+
optimizer, [max_lr, min_lr = config.min_lr, decay_iters_after_warmup](int64_t step) {
45+
if (step > decay_iters_after_warmup) {
46+
return min_lr / max_lr;
3647
}
37-
}
38-
return LRScheduler::Create<lr_schedulers::SequentialLR>(opt, schedulers, milestones);
39-
}
40-
if (config.type == "chained") {
41-
std::vector<std::shared_ptr<LRScheduler>> schedulers;
42-
for (const auto &sub_config : config.chained_configs) {
43-
auto sub_sched = CreateLRScheduler(opt, sub_config);
44-
if (sub_sched) {
45-
schedulers.push_back(sub_sched);
48+
const float decay_ratio = static_cast<float>(step) / static_cast<float>(decay_iters_after_warmup);
49+
CHECK_GE(decay_ratio, 0.0f) << "CreateLRScheduler: decay "
50+
"ratio must be >= 0.";
51+
CHECK_LE(decay_ratio, 1.0f) << "CreateLRScheduler: decay "
52+
"ratio must be <= 1.";
53+
const float coeff = 0.5f * (std::cos(std::numbers::pi_v<float> * decay_ratio) + 1.0f);
54+
return (min_lr + coeff * (max_lr - min_lr)) / max_lr;
55+
});
56+
} else if (config.lr_decay_style == "inverse-square-root") {
57+
main_scheduler = LRScheduler::Create<lr_schedulers::LambdaLR>(
58+
optimizer, [max_lr, min_lr = config.min_lr, lr_warmup_iters = config.lr_warmup_iters,
59+
lr_decay_iters = config.lr_decay_iters](int64_t step) {
60+
const int64_t global_step = step + lr_warmup_iters;
61+
if (global_step > lr_decay_iters) {
62+
return min_lr / max_lr;
4663
}
47-
}
48-
return LRScheduler::Create<lr_schedulers::ChainedScheduler>(opt, schedulers);
49-
}
50-
LOG(FATAL) << "Unsupported LR scheduler type: " << config.type;
51-
return nullptr;
52-
};
53-
54-
if (config.warmup_steps <= 0) {
55-
return create_main(optimizer);
64+
const auto warmup = static_cast<float>(std::max<int64_t>(lr_warmup_iters, 1));
65+
const auto current = static_cast<float>(std::max<int64_t>(global_step, 1));
66+
return std::max(min_lr, max_lr * std::sqrt(warmup) / std::sqrt(current)) / max_lr;
67+
});
5668
}
5769

58-
auto warmup_scheduler = LRScheduler::Create<lr_schedulers::LinearLR>(optimizer,
59-
/*start_factor=*/config.warmup_start_factor,
60-
/*end_factor=*/config.warmup_end_factor,
61-
/*total_iters=*/config.warmup_steps);
62-
63-
auto main_scheduler = create_main(optimizer);
70+
CHECK(main_scheduler) << "CreateLRScheduler: failed to create scheduler.";
71+
if (config.lr_warmup_iters == 0) {
72+
return main_scheduler;
73+
}
6474

75+
auto warmup_scheduler = LRScheduler::Create<lr_schedulers::LambdaLR>(
76+
optimizer,
77+
[lr_warmup_init = config.lr_warmup_init, max_lr, lr_warmup_iters = config.lr_warmup_iters](int64_t step) {
78+
const float warmup_ratio = static_cast<float>(step) / static_cast<float>(lr_warmup_iters);
79+
return (lr_warmup_init + (max_lr - lr_warmup_init) * warmup_ratio) / max_lr;
80+
});
6581
return LRScheduler::Create<lr_schedulers::SequentialLR>(
66-
optimizer, std::vector<std::shared_ptr<LRScheduler>>{warmup_scheduler, main_scheduler},
67-
std::vector<int64_t>{config.warmup_steps});
68-
};
82+
std::move(optimizer), std::vector<std::shared_ptr<LRScheduler>>{warmup_scheduler, main_scheduler},
83+
std::vector<int64_t>{config.lr_warmup_iters});
84+
}
6985

7086
LRScheduler::LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step)
7187
: optimizer_(std::move(optimizer)), last_step_(last_step), base_lr_(0.0f) {
@@ -310,9 +326,7 @@ ChainedScheduler::ChainedScheduler(std::shared_ptr<Optimizer> optimizer,
310326
}
311327
}
312328

313-
void ChainedScheduler::InitialStep() {
314-
last_step_ = 0;
315-
}
329+
void ChainedScheduler::InitialStep() { last_step_ = 0; }
316330

317331
void ChainedScheduler::Step() {
318332
++last_step_;

0 commit comments

Comments
 (0)