11#pragma once
22
3- #include < cstdint>
43#include < cmath>
4+ #include < cstdint>
55#include < functional>
66#include < memory>
77#include < string>
@@ -13,12 +13,11 @@ namespace infini_train {
1313
1414class Optimizer ;
1515
16- using StateValue = std::variant<int64_t , float , double , std::string,
17- std::vector<float >>;
16+ using StateValue = std::variant<int64_t , float , double , std::string, std::vector<float >>;
1817using StateDict = std::unordered_map<std::string, StateValue>;
1918
2019struct LRSchedulerConfig {
21- std::string type = " none" ;
20+ std::string type = " none" ;
2221 // ConstantLR
2322 float constant_factor = 1 .0f / 3 .0f ;
2423 int constant_total_iters = 5 ;
@@ -44,15 +43,13 @@ struct LRSchedulerConfig {
4443
4544class LRScheduler {
4645public:
47- template <typename T, typename ... Args>
48- static std::shared_ptr<T> Create (Args&&... args) {
46+ template <typename T, typename ... Args> static std::shared_ptr<T> Create (Args &&...args) {
4947 auto scheduler = std::make_shared<T>(std::forward<Args>(args)...);
5048 scheduler->InitialStep ();
5149 return scheduler;
5250 }
5351
54- explicit LRScheduler (std::shared_ptr<Optimizer> optimizer,
55- int64_t last_step = -1 );
52+ explicit LRScheduler (std::shared_ptr<Optimizer> optimizer, int64_t last_step = -1 );
5653 virtual ~LRScheduler () = default ;
5754
5855 LRScheduler (const LRScheduler &) = delete ;
@@ -82,17 +79,13 @@ class LRScheduler {
8279 bool is_initial_ = false ;
8380};
8481
85- std::shared_ptr<LRScheduler> CreateLRScheduler (
86- std::shared_ptr<Optimizer> optimizer,
87- const LRSchedulerConfig& config);
82+ std::shared_ptr<LRScheduler> CreateLRScheduler (std::shared_ptr<Optimizer> optimizer, const LRSchedulerConfig &config);
8883
8984namespace lr_schedulers {
9085
9186class ConstantLR : public LRScheduler {
9287public:
93- ConstantLR (std::shared_ptr<Optimizer> optimizer,
94- float factor = 1 .0f / 3 .0f ,
95- int total_iters = 5 ,
88+ ConstantLR (std::shared_ptr<Optimizer> optimizer, float factor = 1 .0f / 3 .0f , int total_iters = 5 ,
9689 int64_t last_step = -1 );
9790 ~ConstantLR () override = default ;
9891
@@ -107,10 +100,7 @@ class ConstantLR : public LRScheduler {
107100
108101class StepLR : public LRScheduler {
109102public:
110- StepLR (std::shared_ptr<Optimizer> optimizer,
111- int64_t step_size,
112- float gamma = 0 .1f ,
113- int64_t last_step = -1 );
103+ StepLR (std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma = 0 .1f , int64_t last_step = -1 );
114104 ~StepLR () override = default ;
115105
116106protected:
@@ -124,11 +114,8 @@ class StepLR : public LRScheduler {
124114
125115class LinearLR : public LRScheduler {
126116public:
127- LinearLR (std::shared_ptr<Optimizer> optimizer,
128- float start_factor = 1 .0f / 3 .0f ,
129- float end_factor = 1 .0f ,
130- int64_t total_iters = 5 ,
131- int64_t last_step = -1 );
117+ LinearLR (std::shared_ptr<Optimizer> optimizer, float start_factor = 1 .0f / 3 .0f , float end_factor = 1 .0f ,
118+ int64_t total_iters = 5 , int64_t last_step = -1 );
132119 ~LinearLR () override = default ;
133120
134121protected:
@@ -145,9 +132,7 @@ class LambdaLR : public LRScheduler {
145132public:
146133 using LambdaFunc = std::function<float (int64_t )>;
147134
148- LambdaLR (std::shared_ptr<Optimizer> optimizer,
149- LambdaFunc lr_lambda,
150- int64_t last_step = -1 );
135+ LambdaLR (std::shared_ptr<Optimizer> optimizer, LambdaFunc lr_lambda, int64_t last_step = -1 );
151136 ~LambdaLR () override = default ;
152137
153138protected:
@@ -157,13 +142,10 @@ class LambdaLR : public LRScheduler {
157142 const LambdaFunc lr_lambda_;
158143};
159144
160-
161145class SequentialLR : public LRScheduler {
162146public:
163- SequentialLR (std::shared_ptr<Optimizer> optimizer,
164- std::vector<std::shared_ptr<LRScheduler>> schedulers,
165- std::vector<int64_t > milestones,
166- int64_t last_step = -1 );
147+ SequentialLR (std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
148+ std::vector<int64_t > milestones, int64_t last_step = -1 );
167149 ~SequentialLR () override = default ;
168150
169151 void Step () override ;
@@ -183,16 +165,15 @@ class SequentialLR : public LRScheduler {
183165
184166class ChainedScheduler : public LRScheduler {
185167public:
186- ChainedScheduler (std::shared_ptr<Optimizer> optimizer,
187- std::vector<std::shared_ptr<LRScheduler>> schedulers,
168+ ChainedScheduler (std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
188169 int64_t last_step = -1 );
189170 ~ChainedScheduler () override = default ;
190171
191172 void Step () override ;
192173 void InitialStep () override ;
193174
194175 StateDict State () const override ;
195- void LoadState (const StateDict& state) override ;
176+ void LoadState (const StateDict & state) override ;
196177
197178protected:
198179 float GetClosedFormLR () const override { return current_lr_; }
@@ -201,6 +182,5 @@ class ChainedScheduler : public LRScheduler {
201182 std::vector<std::shared_ptr<LRScheduler>> schedulers_;
202183};
203184
204-
205- } // namespace lr_schedulers
206- } // namespace infini_train
185+ } // namespace lr_schedulers
186+ } // namespace infini_train
0 commit comments