@@ -68,11 +68,10 @@ std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimi
 };
 
 LRScheduler::LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step)
-    : optimizer_(std::move(optimizer)), last_step_(last_step), current_lr_(0.0f), base_lr_(0.0f) {
+    : optimizer_(std::move(optimizer)), last_step_(last_step), base_lr_(0.0f) {
     CHECK(optimizer_) << "LRScheduler: optimizer must not be null.";
     optimizer_->SetInitialLearningRate(optimizer_->GetLearningRate());
     base_lr_ = optimizer_->GetInitialLearningRate();
-    current_lr_ = base_lr_;
 }
 
 void LRScheduler::Step() {
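With the cached `current_lr_` gone from the base class, the optimizer's learning rate becomes the single source of truth, as the `GetLR()` change in the next hunk makes explicit. A minimal sketch of the hazard the cache invited (illustrative, not from this PR):

```cpp
// Before this change, a direct write to the optimizer could desynchronize
// the scheduler's cached copy:
optimizer->SetLearningRate(1e-4f);  // external change the scheduler never sees
// old:  scheduler->GetLR() returned the stale cached current_lr_
// new:  scheduler->GetLR() forwards to optimizer->GetLearningRate(),
//       so the two values can no longer disagree
float lr = scheduler->GetLR();
```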
@@ -91,34 +90,33 @@ void LRScheduler::InitialStep() {
     is_initial_ = false;
 }
 
-void LRScheduler::ApplyLR(float lr) {
-    current_lr_ = lr;
-    optimizer_->SetLearningRate(current_lr_);
-}
+void LRScheduler::ApplyLR(float lr) { optimizer_->SetLearningRate(lr); }
 
 float LRScheduler::GetChainedFormLR() const { return GetClosedFormLR(); }
 
-float LRScheduler::GetLR() const { return current_lr_; }
+float LRScheduler::GetLR() const { return optimizer_->GetLearningRate(); }
 
 float LRScheduler::BaseLR() const { return base_lr_; }
 
 int64_t LRScheduler::LastStep() const { return last_step_; }
 
+bool LRScheduler::SharesOptimizerWith(const std::shared_ptr<Optimizer> &opt) const { return optimizer_ == opt; }
+
 void LRScheduler::ResetStep(int64_t step) { last_step_ = step; }
 
 StateDict LRScheduler::State() const {
     return {
         {"last_step", last_step_},
-        {"current_lr", current_lr_},
+        {"recover_lr", optimizer_->GetLearningRate()},
         {"base_lr", base_lr_},
     };
 }
 
 void LRScheduler::LoadState(const StateDict &state) {
     last_step_ = std::get<int64_t>(state.at("last_step"));
-    current_lr_ = std::get<float>(state.at("current_lr"));
+    recover_lr_ = std::get<float>(state.at("recover_lr"));
     base_lr_ = std::get<float>(state.at("base_lr"));
-    optimizer_->SetLearningRate(current_lr_);
+    optimizer_->SetLearningRate(recover_lr_);
 }
 
 // Concrete LR Schedulers
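`State()` now snapshots the live optimizer LR under the key `"recover_lr"`, and `LoadState()` restores it through the new `recover_lr_` member. A round-trip sketch, assuming the `StateDict` and scheduler API visible in this diff (scheduler construction elided):

```cpp
// Hypothetical checkpoint round-trip using only the API shown above.
StateDict snapshot = scheduler->State();  // records optimizer->GetLearningRate() as "recover_lr"
// ... more training steps; the optimizer's LR moves on ...
scheduler->LoadState(snapshot);           // writes recover_lr_ back into the optimizer
// post-condition:
// optimizer->GetLearningRate() == std::get<float>(snapshot.at("recover_lr"))
```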
@@ -128,7 +126,10 @@ namespace lr_schedulers {
 // --- ConstantLR ---
 
 ConstantLR::ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor, int total_iters, int64_t last_step)
-    : LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) {}
+    : LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) {
+    CHECK_GE(factor_, 0.0f) << "ConstantLR: factor must be >= 0.";
+    CHECK_LE(factor_, 1.0f) << "ConstantLR: factor must be <= 1.";
+}
 
 float ConstantLR::GetClosedFormLR() const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; }
 
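A worked pass through the `ConstantLR` closed form above (the numbers are illustrative):

```cpp
// ConstantLR(optimizer, /*factor=*/0.5f, /*total_iters=*/4) with base_lr_ = 0.1f:
//   last_step_ = 0..3  ->  base_lr_ * factor_ = 0.05f   (last_step_ < total_iters_)
//   last_step_ >= 4    ->  base_lr_            = 0.1f   (factor no longer applied)
// The new CHECKs reject factor_ outside [0, 1] at construction instead of
// silently producing an inflated or negative LR later.
```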
@@ -147,7 +148,10 @@ float ConstantLR::GetChainedFormLR() const {
 // --- StepLR ---
 
 StepLR::StepLR(std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma, int64_t last_step)
-    : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) {}
+    : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) {
+    CHECK_GT(step_size_, 0) << "StepLR: step_size must be > 0.";
+    CHECK_GT(gamma_, 0.0f) << "StepLR: gamma must be > 0.";
+}
 
 float StepLR::GetClosedFormLR() const {
     return base_lr_
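The rest of `GetClosedFormLR()` lies outside this hunk; assuming the conventional StepLR closed form `base_lr * gamma^(last_step / step_size)` with integer division, the new constructor CHECKs line up with it:

```cpp
// Illustrative values: base_lr_ = 0.1f, gamma_ = 0.5f, step_size_ = 10
//   steps  0..9  -> 0.1f * pow(0.5f, 0) = 0.1f
//   steps 10..19 -> 0.1f * pow(0.5f, 1) = 0.05f
//   steps 20..29 -> 0.1f * pow(0.5f, 2) = 0.025f
// step_size_ > 0 guards the division; gamma_ > 0 keeps the LR positive.
```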
@@ -165,7 +169,13 @@ float StepLR::GetChainedFormLR() const {
 LinearLR::LinearLR(std::shared_ptr<Optimizer> optimizer, float start_factor, float end_factor, int64_t total_iters,
                    int64_t last_step)
     : LRScheduler(std::move(optimizer), last_step), start_factor_(start_factor), end_factor_(end_factor),
-      total_iters_(total_iters) {}
+      total_iters_(total_iters) {
+    CHECK_GT(start_factor_, 0.0f) << "LinearLR: start_factor must be > 0.";
+    CHECK_LE(start_factor_, 1.0f) << "LinearLR: start_factor must be <= 1.";
+    CHECK_GE(end_factor_, 0.0f) << "LinearLR: end_factor must be >= 0.";
+    CHECK_LE(end_factor_, 1.0f) << "LinearLR: end_factor must be <= 1.";
+    CHECK_GT(total_iters_, 0) << "LinearLR: total_iters must be > 0.";
+}
 
 float LinearLR::GetClosedFormLR() const {
     if (last_step_ >= total_iters_) {
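Only the clamp branch of `LinearLR::GetClosedFormLR()` is visible in this hunk; assuming the usual linear-warmup closed form, the new bounds checks map onto it directly:

```cpp
// Assumed closed form (the body is not shown in this hunk):
//   lr(t) = base_lr_ * (start_factor_
//           + (end_factor_ - start_factor_) * std::min(t, total_iters_) / total_iters_)
// Illustrative values: base_lr_ = 0.1f, start_factor_ = 0.5f,
// end_factor_ = 1.0f, total_iters_ = 4:
//   t = 0 -> 0.05f,  t = 2 -> 0.075f,  t >= 4 -> 0.1f
// CHECK_GT(total_iters_, 0) also guards the division above.
```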
@@ -198,31 +208,40 @@ float LinearLR::GetChainedFormLR() const {
 }
 
 LambdaLR::LambdaLR(std::shared_ptr<Optimizer> optimizer, std::function<float(int64_t)> lr_lambda, int64_t last_step)
-    : LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) {}
+    : LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) {
+    CHECK(lr_lambda_) << "LambdaLR: lr_lambda must not be null.";
+}
 
 float LambdaLR::GetClosedFormLR() const { return base_lr_ * lr_lambda_(last_step_); }
 
 SequentialLR::SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
                            std::vector<int64_t> milestones, int64_t last_step)
     : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)),
-      milestones_(std::move(milestones)) {}
-
-void SequentialLR::InitialStep() {
+      milestones_(std::move(milestones)) {
     CHECK(!schedulers_.empty()) << "SequentialLR requires at least one scheduler.";
+
+    for (size_t i = 0; i < schedulers_.size(); ++i) {
+        CHECK(schedulers_[i]) << "SequentialLR: scheduler at index " << i << " must not be null.";
+        CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_))
+            << "SequentialLR: scheduler at index " << i << " must share the same optimizer.";
+    }
+
     CHECK_EQ(milestones_.size(), schedulers_.size() - 1)
         << "SequentialLR: milestones count must be schedulers count - 1.";
 
     for (size_t i = 1; i < milestones_.size(); ++i) {
         CHECK_GT(milestones_[i], milestones_[i - 1]) << "Milestones must be strictly increasing.";
     }
+}
+
+void SequentialLR::InitialStep() {
 
     optimizer_->SetLearningRate(schedulers_[0]->BaseLR());
 
     UndoChildInitialSteps();
 
     ++last_step_;
     schedulers_[0]->InitialStep();
-    current_lr_ = schedulers_[0]->GetLR();
 }
 
 void SequentialLR::UndoChildInitialSteps() {
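Hoisting the checks out of `InitialStep()` means a misconfigured `SequentialLR` now fails at construction rather than on the first step. A construction sketch using the types from this diff (the child schedulers, milestone value, and the `last_step = -1` "not started" convention are illustrative assumptions):

```cpp
// Hypothetical wiring: 100-step linear warm-up, then stepwise decay.
auto warmup = std::make_shared<lr_schedulers::LinearLR>(
    optimizer, /*start_factor=*/0.1f, /*end_factor=*/1.0f, /*total_iters=*/100, /*last_step=*/-1);
auto decay = std::make_shared<lr_schedulers::StepLR>(
    optimizer, /*step_size=*/1000, /*gamma=*/0.5f, /*last_step=*/-1);

// Both children share `optimizer`, so SharesOptimizerWith() passes, and one
// milestone for two schedulers satisfies the count check; any violation now
// aborts here, in the constructor.
auto seq = std::make_shared<lr_schedulers::SequentialLR>(
    optimizer, std::vector<std::shared_ptr<LRScheduler>>{warmup, decay},
    std::vector<int64_t>{100}, /*last_step=*/-1);
```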
@@ -245,14 +264,12 @@ void SequentialLR::Step() {
     } else {
         scheduler->Step();
     }
-
-    current_lr_ = optimizer_->GetLearningRate();
 }
 
 StateDict SequentialLR::State() const {
     StateDict state;
     state["last_step"] = last_step_;
-    state["current_lr"] = current_lr_;
+    state["recover_lr"] = optimizer_->GetLearningRate();
     state["base_lr"] = base_lr_;
     for (size_t i = 0; i < schedulers_.size(); ++i) {
         auto sub_state = schedulers_[i]->State();
@@ -263,7 +280,7 @@ StateDict SequentialLR::State() const {
 
 void SequentialLR::LoadState(const StateDict &state) {
     last_step_ = std::get<int64_t>(state.at("last_step"));
-    current_lr_ = std::get<float>(state.at("current_lr"));
+    recover_lr_ = std::get<float>(state.at("recover_lr"));
     base_lr_ = std::get<float>(state.at("base_lr"));
 
     for (size_t i = 0; i < schedulers_.size(); ++i) {
@@ -278,23 +295,28 @@ void SequentialLR::LoadState(const StateDict &state) {
             schedulers_[i]->LoadState(sub_state);
         }
     }
-    optimizer_->SetLearningRate(current_lr_);
+    optimizer_->SetLearningRate(recover_lr_);
 }
 
 ChainedScheduler::ChainedScheduler(std::shared_ptr<Optimizer> optimizer,
                                    std::vector<std::shared_ptr<LRScheduler>> schedulers, int64_t last_step)
-    : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) {}
-
-void ChainedScheduler::InitialStep() {
+    : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) {
     CHECK(!schedulers_.empty()) << "ChainedScheduler requires at least one scheduler.";
 
-    current_lr_ = optimizer_->GetLearningRate();
+    for (size_t i = 0; i < schedulers_.size(); ++i) {
+        CHECK(schedulers_[i]) << "ChainedScheduler: scheduler at index " << i << " must not be null.";
+        CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_))
+            << "ChainedScheduler: scheduler at index " << i << " must share the same optimizer.";
+    }
+}
+
+void ChainedScheduler::InitialStep() {
+    last_step_ = 0;
 }
 
 void ChainedScheduler::Step() {
     ++last_step_;
     for (auto &sched : schedulers_) { sched->Step(); }
-    current_lr_ = optimizer_->GetLearningRate();
 }
 
 StateDict ChainedScheduler::State() const {
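`ChainedScheduler` gets the same constructor-time validation, and its `Step()` no longer needs to refresh a cache: each child's `Step()` rewrites the shared optimizer's LR in turn, so the chained-form updates compose in declaration order. A hedged construction sketch (arguments are illustrative, not from the PR):

```cpp
// Hypothetical chain: both children step the shared optimizer in sequence.
auto a = std::make_shared<lr_schedulers::ConstantLR>(
    optimizer, /*factor=*/0.5f, /*total_iters=*/4, /*last_step=*/-1);
auto b = std::make_shared<lr_schedulers::StepLR>(
    optimizer, /*step_size=*/10, /*gamma=*/0.9f, /*last_step=*/-1);
auto chained = std::make_shared<lr_schedulers::ChainedScheduler>(
    optimizer, std::vector<std::shared_ptr<LRScheduler>>{a, b}, /*last_step=*/-1);

chained->Step();              // a->Step() then b->Step(), each acting on the live LR
float lr = chained->GetLR();  // reads straight through to the optimizer
```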
@@ -323,4 +345,4 @@ void ChainedScheduler::LoadState(const StateDict &state) {
 }
 
 } // namespace lr_schedulers
-} // namespace infini_train
+} // namespace infini_train