@@ -18,7 +18,7 @@ std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimi
1818 }
1919
2020 CHECK (optimizer) << " CreateLRScheduler: optimizer must not be null." ;
21- const float max_lr = config.lr != 0 .0f ? config.lr : optimizer->GetLearningRate ();
21+ const float max_lr = config.lr != 0 .0f ? config.lr : optimizer->learning_rate ();
2222 CHECK_GT (max_lr, 0 .0f ) << " CreateLRScheduler: max_lr must be > 0." ;
2323 CHECK_GE (config.lr_warmup_init , 0 .0f ) << " CreateLRScheduler: lr_warmup_init must be >= 0." ;
2424 CHECK_GE (config.min_lr , 0 .0f ) << " CreateLRScheduler: min_lr must be >= 0." ;
@@ -86,8 +86,8 @@ std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimi
8686LRScheduler::LRScheduler (std::shared_ptr<Optimizer> optimizer, int64_t last_step)
8787 : optimizer_(std::move(optimizer)), last_step_(last_step), base_lr_(0 .0f ) {
8888 CHECK (optimizer_) << " LRScheduler: optimizer must not be null." ;
89- optimizer_->SetInitialLearningRate (optimizer_->GetLearningRate ());
90- base_lr_ = optimizer_->GetInitialLearningRate ();
89+ optimizer_->set_initial_learning_rate (optimizer_->learning_rate ());
90+ base_lr_ = optimizer_->initial_learning_rate ();
9191}
9292
9393void LRScheduler::Step () {
@@ -106,11 +106,11 @@ void LRScheduler::InitialStep() {
106106 is_initial_ = false ;
107107}
108108
109- void LRScheduler::ApplyLR (float lr) { optimizer_->SetLearningRate (lr); }
109+ void LRScheduler::ApplyLR (float lr) { optimizer_->set_learning_rate (lr); }
110110
111111float LRScheduler::GetChainedFormLR () const { return GetClosedFormLR (); }
112112
113- float LRScheduler::GetLR () const { return optimizer_->GetLearningRate (); }
113+ float LRScheduler::GetLR () const { return optimizer_->learning_rate (); }
114114
115115float LRScheduler::BaseLR () const { return base_lr_; }
116116
@@ -123,7 +123,7 @@ void LRScheduler::ResetStep(int64_t step) { last_step_ = step; }
123123StateDict LRScheduler::State () const {
124124 return {
125125 {" last_step" , last_step_},
126- {" recover_lr" , optimizer_->GetLearningRate ()},
126+ {" recover_lr" , optimizer_->learning_rate ()},
127127 {" base_lr" , base_lr_},
128128 };
129129}
@@ -132,7 +132,7 @@ void LRScheduler::LoadState(const StateDict &state) {
132132 last_step_ = std::get<int64_t >(state.at (" last_step" ));
133133 recover_lr_ = std::get<float >(state.at (" recover_lr" ));
134134 base_lr_ = std::get<float >(state.at (" base_lr" ));
135- optimizer_->SetLearningRate (recover_lr_);
135+ optimizer_->set_learning_rate (recover_lr_);
136136}
137137
138138// Concrete LR Schedulers
@@ -150,7 +150,7 @@ ConstantLR::ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor, int t
150150float ConstantLR::GetClosedFormLR () const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; }
151151
152152float ConstantLR::GetChainedFormLR () const {
153- const float lr = optimizer_->GetLearningRate ();
153+ const float lr = optimizer_->learning_rate ();
154154 if (last_step_ == 0 ) {
155155 return lr * factor_;
156156 } else if (last_step_ < total_iters_) {
@@ -175,7 +175,7 @@ float StepLR::GetClosedFormLR() const {
175175}
176176
177177float StepLR::GetChainedFormLR () const {
178- const float lr = optimizer_->GetLearningRate ();
178+ const float lr = optimizer_->learning_rate ();
179179 if (last_step_ == 0 || (last_step_ % step_size_) != 0 ) {
180180 return lr;
181181 }
@@ -203,7 +203,7 @@ float LinearLR::GetClosedFormLR() const {
203203}
204204
205205float LinearLR::GetChainedFormLR () const {
206- const float lr = optimizer_->GetLearningRate ();
206+ const float lr = optimizer_->learning_rate ();
207207 if (last_step_ == 0 ) {
208208 return lr * start_factor_;
209209 }
@@ -252,7 +252,7 @@ SequentialLR::SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std
252252
253253void SequentialLR::InitialStep () {
254254
255- optimizer_->SetLearningRate (schedulers_[0 ]->BaseLR ());
255+ optimizer_->set_learning_rate (schedulers_[0 ]->BaseLR ());
256256
257257 UndoChildInitialSteps ();
258258
@@ -285,7 +285,7 @@ void SequentialLR::Step() {
285285StateDict SequentialLR::State () const {
286286 StateDict state;
287287 state[" last_step" ] = last_step_;
288- state[" recover_lr" ] = optimizer_->GetLearningRate ();
288+ state[" recover_lr" ] = optimizer_->learning_rate ();
289289 state[" base_lr" ] = base_lr_;
290290 for (size_t i = 0 ; i < schedulers_.size (); ++i) {
291291 auto sub_state = schedulers_[i]->State ();
@@ -311,7 +311,7 @@ void SequentialLR::LoadState(const StateDict &state) {
311311 schedulers_[i]->LoadState (sub_state);
312312 }
313313 }
314- optimizer_->SetLearningRate (recover_lr_);
314+ optimizer_->set_learning_rate (recover_lr_);
315315}
316316
317317ChainedScheduler::ChainedScheduler (std::shared_ptr<Optimizer> optimizer,
0 commit comments