@@ -68,11 +68,10 @@ std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimi
6868};
6969
7070LRScheduler::LRScheduler (std::shared_ptr<Optimizer> optimizer, int64_t last_step)
71- : optimizer_(std::move(optimizer)), last_step_(last_step), current_lr_( 0 . 0f ), base_lr_(0 .0f ) {
71+ : optimizer_(std::move(optimizer)), last_step_(last_step), base_lr_(0 .0f ) {
7272 CHECK (optimizer_) << " LRScheduler: optimizer must not be null." ;
7373 optimizer_->SetInitialLearningRate (optimizer_->GetLearningRate ());
7474 base_lr_ = optimizer_->GetInitialLearningRate ();
75- current_lr_ = base_lr_;
7675}
7776
7877void LRScheduler::Step () {
@@ -92,13 +91,12 @@ void LRScheduler::InitialStep() {
9291}
9392
9493void LRScheduler::ApplyLR (float lr) {
95- current_lr_ = lr;
96- optimizer_->SetLearningRate (current_lr_);
94+ optimizer_->SetLearningRate (lr);
9795}
9896
9997float LRScheduler::GetChainedFormLR () const { return GetClosedFormLR (); }
10098
101- float LRScheduler::GetLR () const { return current_lr_ ; }
99+ float LRScheduler::GetLR () const { return optimizer_-> GetLearningRate () ; }
102100
103101float LRScheduler::BaseLR () const { return base_lr_; }
104102
@@ -109,16 +107,16 @@ void LRScheduler::ResetStep(int64_t step) { last_step_ = step; }
109107StateDict LRScheduler::State () const {
110108 return {
111109 {" last_step" , last_step_},
112- {" current_lr " , current_lr_ },
110+ {" recover_lr " , optimizer_-> GetLearningRate () },
113111 {" base_lr" , base_lr_},
114112 };
115113}
116114
117115void LRScheduler::LoadState (const StateDict &state) {
118116 last_step_ = std::get<int64_t >(state.at (" last_step" ));
119- current_lr_ = std::get<float >(state.at (" current_lr " ));
117+ recover_lr_ = std::get<float >(state.at (" recover_lr " ));
120118 base_lr_ = std::get<float >(state.at (" base_lr" ));
121- optimizer_->SetLearningRate (current_lr_ );
119+ optimizer_->SetLearningRate (recover_lr_ );
122120}
123121
124122// Concrete LR Schedulers
@@ -222,7 +220,6 @@ void SequentialLR::InitialStep() {
222220
223221 ++last_step_;
224222 schedulers_[0 ]->InitialStep ();
225- current_lr_ = schedulers_[0 ]->GetLR ();
226223}
227224
228225void SequentialLR::UndoChildInitialSteps () {
@@ -246,13 +243,12 @@ void SequentialLR::Step() {
246243 scheduler->Step ();
247244 }
248245
249- current_lr_ = optimizer_->GetLearningRate ();
250246}
251247
252248StateDict SequentialLR::State () const {
253249 StateDict state;
254250 state[" last_step" ] = last_step_;
255- state[" current_lr " ] = current_lr_ ;
251+ state[" recover_lr " ] = optimizer_-> GetLearningRate () ;
256252 state[" base_lr" ] = base_lr_;
257253 for (size_t i = 0 ; i < schedulers_.size (); ++i) {
258254 auto sub_state = schedulers_[i]->State ();
@@ -263,7 +259,7 @@ StateDict SequentialLR::State() const {
263259
264260void SequentialLR::LoadState (const StateDict &state) {
265261 last_step_ = std::get<int64_t >(state.at (" last_step" ));
266- current_lr_ = std::get<float >(state.at (" current_lr " ));
262+ recover_lr_ = std::get<float >(state.at (" recover_lr " ));
267263 base_lr_ = std::get<float >(state.at (" base_lr" ));
268264
269265 for (size_t i = 0 ; i < schedulers_.size (); ++i) {
@@ -278,7 +274,7 @@ void SequentialLR::LoadState(const StateDict &state) {
278274 schedulers_[i]->LoadState (sub_state);
279275 }
280276 }
281- optimizer_->SetLearningRate (current_lr_ );
277+ optimizer_->SetLearningRate (recover_lr_ );
282278}
283279
284280ChainedScheduler::ChainedScheduler (std::shared_ptr<Optimizer> optimizer,
@@ -288,13 +284,11 @@ ChainedScheduler::ChainedScheduler(std::shared_ptr<Optimizer> optimizer,
288284void ChainedScheduler::InitialStep () {
289285 CHECK (!schedulers_.empty ()) << " ChainedScheduler requires at least one scheduler." ;
290286
291- current_lr_ = optimizer_->GetLearningRate ();
292287}
293288
294289void ChainedScheduler::Step () {
295290 ++last_step_;
296291 for (auto &sched : schedulers_) { sched->Step (); }
297- current_lr_ = optimizer_->GetLearningRate ();
298292}
299293
300294StateDict ChainedScheduler::State () const {
@@ -323,4 +317,4 @@ void ChainedScheduler::LoadState(const StateDict &state) {
323317}
324318
325319} // namespace lr_schedulers
326- } // namespace infini_train
320+ } // namespace infini_train
0 commit comments