Skip to content

Commit 151dda0

Browse files
author
kinorw
committed
refactor: rename current_lr_ to recover_lr_ and update related methods
- it is now only used for learning rate recovery when restoring state via LoadState
1 parent afd98ff commit 151dda0

3 files changed

Lines changed: 17 additions & 23 deletions

File tree

infini_train/include/lr_scheduler.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class LRScheduler {
7474

7575
std::shared_ptr<Optimizer> optimizer_;
7676
int64_t last_step_;
77-
float current_lr_;
77+
float recover_lr_;
7878
float base_lr_;
7979
bool is_initial_ = false;
8080
};
@@ -155,7 +155,7 @@ class SequentialLR : public LRScheduler {
155155
void LoadState(const StateDict &state) override;
156156

157157
protected:
158-
float GetClosedFormLR() const override { return current_lr_; }
158+
float GetClosedFormLR() const override { return base_lr_; } // FIXME: SequentialLR should not have a closed-form LR, but we need to implement this pure virtual function.
159159
void UndoChildInitialSteps();
160160

161161
private:
@@ -176,11 +176,11 @@ class ChainedScheduler : public LRScheduler {
176176
void LoadState(const StateDict &state) override;
177177

178178
protected:
179-
float GetClosedFormLR() const override { return current_lr_; }
179+
float GetClosedFormLR() const override { return base_lr_; } // FIXME: ChainedScheduler should not have a closed-form LR, but we need to implement this pure virtual function.
180180

181181
private:
182182
std::vector<std::shared_ptr<LRScheduler>> schedulers_;
183183
};
184184

185185
} // namespace lr_schedulers
186-
} // namespace infini_train
186+
} // namespace infini_train

infini_train/src/lr_scheduler.cc

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,10 @@ std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimi
6868
};
6969

7070
LRScheduler::LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step)
71-
: optimizer_(std::move(optimizer)), last_step_(last_step), current_lr_(0.0f), base_lr_(0.0f) {
71+
: optimizer_(std::move(optimizer)), last_step_(last_step), base_lr_(0.0f) {
7272
CHECK(optimizer_) << "LRScheduler: optimizer must not be null.";
7373
optimizer_->SetInitialLearningRate(optimizer_->GetLearningRate());
7474
base_lr_ = optimizer_->GetInitialLearningRate();
75-
current_lr_ = base_lr_;
7675
}
7776

7877
void LRScheduler::Step() {
@@ -92,13 +91,12 @@ void LRScheduler::InitialStep() {
9291
}
9392

9493
void LRScheduler::ApplyLR(float lr) {
95-
current_lr_ = lr;
96-
optimizer_->SetLearningRate(current_lr_);
94+
optimizer_->SetLearningRate(lr);
9795
}
9896

9997
float LRScheduler::GetChainedFormLR() const { return GetClosedFormLR(); }
10098

101-
float LRScheduler::GetLR() const { return current_lr_; }
99+
float LRScheduler::GetLR() const { return optimizer_->GetLearningRate(); }
102100

103101
float LRScheduler::BaseLR() const { return base_lr_; }
104102

@@ -109,16 +107,16 @@ void LRScheduler::ResetStep(int64_t step) { last_step_ = step; }
109107
StateDict LRScheduler::State() const {
110108
return {
111109
{"last_step", last_step_},
112-
{"current_lr", current_lr_},
110+
{"recover_lr", optimizer_->GetLearningRate()},
113111
{"base_lr", base_lr_},
114112
};
115113
}
116114

117115
void LRScheduler::LoadState(const StateDict &state) {
118116
last_step_ = std::get<int64_t>(state.at("last_step"));
119-
current_lr_ = std::get<float>(state.at("current_lr"));
117+
recover_lr_ = std::get<float>(state.at("recover_lr"));
120118
base_lr_ = std::get<float>(state.at("base_lr"));
121-
optimizer_->SetLearningRate(current_lr_);
119+
optimizer_->SetLearningRate(recover_lr_);
122120
}
123121

124122
// Concrete LR Schedulers
@@ -222,7 +220,6 @@ void SequentialLR::InitialStep() {
222220

223221
++last_step_;
224222
schedulers_[0]->InitialStep();
225-
current_lr_ = schedulers_[0]->GetLR();
226223
}
227224

228225
void SequentialLR::UndoChildInitialSteps() {
@@ -246,13 +243,12 @@ void SequentialLR::Step() {
246243
scheduler->Step();
247244
}
248245

249-
current_lr_ = optimizer_->GetLearningRate();
250246
}
251247

252248
StateDict SequentialLR::State() const {
253249
StateDict state;
254250
state["last_step"] = last_step_;
255-
state["current_lr"] = current_lr_;
251+
state["recover_lr"] = optimizer_->GetLearningRate();
256252
state["base_lr"] = base_lr_;
257253
for (size_t i = 0; i < schedulers_.size(); ++i) {
258254
auto sub_state = schedulers_[i]->State();
@@ -263,7 +259,7 @@ StateDict SequentialLR::State() const {
263259

264260
void SequentialLR::LoadState(const StateDict &state) {
265261
last_step_ = std::get<int64_t>(state.at("last_step"));
266-
current_lr_ = std::get<float>(state.at("current_lr"));
262+
recover_lr_ = std::get<float>(state.at("recover_lr"));
267263
base_lr_ = std::get<float>(state.at("base_lr"));
268264

269265
for (size_t i = 0; i < schedulers_.size(); ++i) {
@@ -278,7 +274,7 @@ void SequentialLR::LoadState(const StateDict &state) {
278274
schedulers_[i]->LoadState(sub_state);
279275
}
280276
}
281-
optimizer_->SetLearningRate(current_lr_);
277+
optimizer_->SetLearningRate(recover_lr_);
282278
}
283279

284280
ChainedScheduler::ChainedScheduler(std::shared_ptr<Optimizer> optimizer,
@@ -288,13 +284,11 @@ ChainedScheduler::ChainedScheduler(std::shared_ptr<Optimizer> optimizer,
288284
void ChainedScheduler::InitialStep() {
289285
CHECK(!schedulers_.empty()) << "ChainedScheduler requires at least one scheduler.";
290286

291-
current_lr_ = optimizer_->GetLearningRate();
292287
}
293288

294289
void ChainedScheduler::Step() {
295290
++last_step_;
296291
for (auto &sched : schedulers_) { sched->Step(); }
297-
current_lr_ = optimizer_->GetLearningRate();
298292
}
299293

300294
StateDict ChainedScheduler::State() const {
@@ -323,4 +317,4 @@ void ChainedScheduler::LoadState(const StateDict &state) {
323317
}
324318

325319
} // namespace lr_schedulers
326-
} // namespace infini_train
320+
} // namespace infini_train

test/lr_scheduler/test_lr_scheduler.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void TestLinearDecay() {
103103
ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f);
104104
}
105105

106-
// T4: State → LoadState 往返一致性。
106+
// T4: State → LoadState
107107
void TestStateRoundTrip() {
108108
std::cout << "[T4] TestStateRoundTrip" << std::endl;
109109
constexpr int64_t kTotalSteps = 20;
@@ -115,7 +115,7 @@ void TestStateRoundTrip() {
115115
StateDict saved = sched->State();
116116

117117
ASSERT_TRUE(saved.count("last_step") == 1);
118-
ASSERT_TRUE(saved.count("current_lr") == 1);
118+
ASSERT_TRUE(saved.count("recover_lr") == 1);
119119
ASSERT_TRUE(saved.count("base_lr") == 1);
120120

121121
auto opt2 = MakeDummyOptimizer(kBaseLR);
@@ -175,4 +175,4 @@ int main(int argc, char *argv[]) {
175175
std::cout << "========================================" << std::endl;
176176

177177
return g_fail_count > 0 ? 1 : 0;
178-
}
178+
}

0 commit comments

Comments
 (0)