|
1 | 1 | #include "infini_train/include/lr_scheduler.h" |
2 | 2 |
|
| 3 | +#include <algorithm> |
| 4 | +#include <cmath> |
| 5 | +#include <numbers> |
| 6 | +#include <utility> |
| 7 | + |
3 | 8 | #include "glog/logging.h" |
4 | 9 |
|
5 | 10 | #include "infini_train/include/optimizer.h" |
6 | 11 |
|
7 | 12 | namespace infini_train { |
8 | 13 |
|
9 | | -std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer, const LRSchedulerConfig &config) { |
10 | | - if (config.type == "none") { |
| 14 | +std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer, |
| 15 | + const TrainingLRSchedulerConfig &config) { |
| 16 | + if (config.lr_decay_style == "none") { |
11 | 17 | return nullptr; |
12 | 18 | } |
13 | 19 |
|
14 | | - auto create_main = [&](std::shared_ptr<Optimizer> opt) -> std::shared_ptr<LRScheduler> { |
15 | | - if (config.type == "constant") { |
16 | | - return LRScheduler::Create<lr_schedulers::ConstantLR>(opt, config.constant_factor, |
17 | | - config.constant_total_iters); |
18 | | - } |
19 | | - if (config.type == "step") { |
20 | | - return LRScheduler::Create<lr_schedulers::StepLR>(opt, config.step_size, config.step_gamma); |
21 | | - } |
22 | | - if (config.type == "linear") { |
23 | | - return LRScheduler::Create<lr_schedulers::LinearLR>(opt, config.linear_start_factor, |
24 | | - config.linear_end_factor, config.linear_total_iters); |
25 | | - } |
26 | | - if (config.type == "lambda") { |
27 | | - return LRScheduler::Create<lr_schedulers::LambdaLR>(opt, config.lambda_fn); |
28 | | - } |
29 | | - if (config.type == "sequential") { |
30 | | - std::vector<std::shared_ptr<LRScheduler>> schedulers; |
31 | | - std::vector<int64_t> milestones = config.sequential_milestones; |
32 | | - for (const auto &sub_config : config.sequential_configs) { |
33 | | - auto sub_sched = CreateLRScheduler(opt, sub_config); |
34 | | - if (sub_sched) { |
35 | | - schedulers.push_back(sub_sched); |
| 20 | + CHECK(optimizer) << "CreateLRScheduler: optimizer must not be null."; |
| 21 | + const float max_lr = config.lr != 0.0f ? config.lr : optimizer->GetLearningRate(); |
| 22 | + CHECK_GT(max_lr, 0.0f) << "CreateLRScheduler: max_lr must be > 0."; |
| 23 | + CHECK_GE(config.lr_warmup_init, 0.0f) << "CreateLRScheduler: lr_warmup_init must be >= 0."; |
| 24 | + CHECK_GE(config.min_lr, 0.0f) << "CreateLRScheduler: min_lr must be >= 0."; |
| 25 | + CHECK_GE(max_lr, config.min_lr) << "CreateLRScheduler: max_lr must be >= min_lr."; |
| 26 | + CHECK_LE(config.lr_warmup_init, max_lr) << "CreateLRScheduler: lr_warmup_init must be <= max_lr."; |
| 27 | + CHECK_GE(config.lr_warmup_iters, 0) << "CreateLRScheduler: lr_warmup_iters must be >= 0."; |
| 28 | + CHECK_GT(config.lr_decay_iters, 0) << "CreateLRScheduler: lr_decay_iters must be > 0."; |
| 29 | + CHECK_LT(config.lr_warmup_iters, config.lr_decay_iters) |
| 30 | + << "CreateLRScheduler: lr_warmup_iters must be < lr_decay_iters."; |
| 31 | + CHECK(config.lr_decay_style == "constant" || config.lr_decay_style == "linear" || config.lr_decay_style == "cosine" |
| 32 | + || config.lr_decay_style == "inverse-square-root") |
| 33 | + << "CreateLRScheduler: unsupported lr_decay_style: " << config.lr_decay_style; |
| 34 | + |
| 35 | + std::shared_ptr<LRScheduler> main_scheduler; |
| 36 | + const int64_t decay_iters_after_warmup = config.lr_decay_iters - config.lr_warmup_iters; |
| 37 | + if (config.lr_decay_style == "constant") { |
| 38 | + main_scheduler = LRScheduler::Create<lr_schedulers::LambdaLR>(optimizer, [](int64_t) { return 1.0f; }); |
| 39 | + } else if (config.lr_decay_style == "linear") { |
| 40 | + main_scheduler = LRScheduler::Create<lr_schedulers::LinearLR>(optimizer, 1.0f, config.min_lr / max_lr, |
| 41 | + decay_iters_after_warmup); |
| 42 | + } else if (config.lr_decay_style == "cosine") { |
| 43 | + main_scheduler = LRScheduler::Create<lr_schedulers::LambdaLR>( |
| 44 | + optimizer, [max_lr, min_lr = config.min_lr, decay_iters_after_warmup](int64_t step) { |
| 45 | + if (step > decay_iters_after_warmup) { |
| 46 | + return min_lr / max_lr; |
36 | 47 | } |
37 | | - } |
38 | | - return LRScheduler::Create<lr_schedulers::SequentialLR>(opt, schedulers, milestones); |
39 | | - } |
40 | | - if (config.type == "chained") { |
41 | | - std::vector<std::shared_ptr<LRScheduler>> schedulers; |
42 | | - for (const auto &sub_config : config.chained_configs) { |
43 | | - auto sub_sched = CreateLRScheduler(opt, sub_config); |
44 | | - if (sub_sched) { |
45 | | - schedulers.push_back(sub_sched); |
| 48 | + const float decay_ratio = static_cast<float>(step) / static_cast<float>(decay_iters_after_warmup); |
| 49 | + CHECK_GE(decay_ratio, 0.0f) << "CreateLRScheduler: decay " |
| 50 | + "ratio must be >= 0."; |
| 51 | + CHECK_LE(decay_ratio, 1.0f) << "CreateLRScheduler: decay " |
| 52 | + "ratio must be <= 1."; |
| 53 | + const float coeff = 0.5f * (std::cos(std::numbers::pi_v<float> * decay_ratio) + 1.0f); |
| 54 | + return (min_lr + coeff * (max_lr - min_lr)) / max_lr; |
| 55 | + }); |
| 56 | + } else if (config.lr_decay_style == "inverse-square-root") { |
| 57 | + main_scheduler = LRScheduler::Create<lr_schedulers::LambdaLR>( |
| 58 | + optimizer, [max_lr, min_lr = config.min_lr, lr_warmup_iters = config.lr_warmup_iters, |
| 59 | + lr_decay_iters = config.lr_decay_iters](int64_t step) { |
| 60 | + const int64_t global_step = step + lr_warmup_iters; |
| 61 | + if (global_step > lr_decay_iters) { |
| 62 | + return min_lr / max_lr; |
46 | 63 | } |
47 | | - } |
48 | | - return LRScheduler::Create<lr_schedulers::ChainedScheduler>(opt, schedulers); |
49 | | - } |
50 | | - LOG(FATAL) << "Unsupported LR scheduler type: " << config.type; |
51 | | - return nullptr; |
52 | | - }; |
53 | | - |
54 | | - if (config.warmup_steps <= 0) { |
55 | | - return create_main(optimizer); |
| 64 | + const auto warmup = static_cast<float>(std::max<int64_t>(lr_warmup_iters, 1)); |
| 65 | + const auto current = static_cast<float>(std::max<int64_t>(global_step, 1)); |
| 66 | + return std::max(min_lr, max_lr * std::sqrt(warmup) / std::sqrt(current)) / max_lr; |
| 67 | + }); |
56 | 68 | } |
57 | 69 |
|
58 | | - auto warmup_scheduler = LRScheduler::Create<lr_schedulers::LinearLR>(optimizer, |
59 | | - /*start_factor=*/config.warmup_start_factor, |
60 | | - /*end_factor=*/config.warmup_end_factor, |
61 | | - /*total_iters=*/config.warmup_steps); |
62 | | - |
63 | | - auto main_scheduler = create_main(optimizer); |
| 70 | + CHECK(main_scheduler) << "CreateLRScheduler: failed to create scheduler."; |
| 71 | + if (config.lr_warmup_iters == 0) { |
| 72 | + return main_scheduler; |
| 73 | + } |
64 | 74 |
|
| 75 | + auto warmup_scheduler = LRScheduler::Create<lr_schedulers::LambdaLR>( |
| 76 | + optimizer, |
| 77 | + [lr_warmup_init = config.lr_warmup_init, max_lr, lr_warmup_iters = config.lr_warmup_iters](int64_t step) { |
| 78 | + const float warmup_ratio = static_cast<float>(step) / static_cast<float>(lr_warmup_iters); |
| 79 | + return (lr_warmup_init + (max_lr - lr_warmup_init) * warmup_ratio) / max_lr; |
| 80 | + }); |
65 | 81 | return LRScheduler::Create<lr_schedulers::SequentialLR>( |
66 | | - optimizer, std::vector<std::shared_ptr<LRScheduler>>{warmup_scheduler, main_scheduler}, |
67 | | - std::vector<int64_t>{config.warmup_steps}); |
68 | | -}; |
| 82 | + std::move(optimizer), std::vector<std::shared_ptr<LRScheduler>>{warmup_scheduler, main_scheduler}, |
| 83 | + std::vector<int64_t>{config.lr_warmup_iters}); |
| 84 | +} |
69 | 85 |
|
70 | 86 | LRScheduler::LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step) |
71 | 87 | : optimizer_(std::move(optimizer)), last_step_(last_step), base_lr_(0.0f) { |
@@ -310,9 +326,7 @@ ChainedScheduler::ChainedScheduler(std::shared_ptr<Optimizer> optimizer, |
310 | 326 | } |
311 | 327 | } |
312 | 328 |
|
313 | | -void ChainedScheduler::InitialStep() { |
314 | | - last_step_ = 0; |
315 | | -} |
| 329 | +void ChainedScheduler::InitialStep() { last_step_ = 0; } |
316 | 330 |
|
317 | 331 | void ChainedScheduler::Step() { |
318 | 332 | ++last_step_; |
|
0 commit comments