Skip to content

Commit 327d263

Browse files
kinorwlittleotherut
authored and committed
refactor: rename current_lr_ to recover_lr_ and update related methods, add validation tests for learning rate schedulers
- it is now only used for learning-rate recovery when calling LoadState
1 parent f0012be commit 327d263

File tree

12 files changed

+243
-51
lines changed

12 files changed

+243
-51
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,6 @@ link_infini_train_exe(test_sequential_lr)
224224

225225
add_executable(test_chained_lr test/lr_scheduler/test_chained_lr.cc)
226226
link_infini_train_exe(test_chained_lr)
227+
228+
add_executable(test_lr_scheduler_validation test/lr_scheduler/test_lr_scheduler_validation.cc)
229+
link_infini_train_exe(test_lr_scheduler_validation)

infini_train/include/lr_scheduler.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,16 @@ class LRScheduler {
6767
virtual StateDict State() const;
6868
virtual void LoadState(const StateDict &state);
6969

70+
bool SharesOptimizerWith(const std::shared_ptr<Optimizer> &opt) const;
71+
7072
protected:
7173
virtual float GetClosedFormLR() const = 0;
7274
virtual float GetChainedFormLR() const;
7375
void ApplyLR(float lr);
7476

7577
std::shared_ptr<Optimizer> optimizer_;
7678
int64_t last_step_;
77-
float current_lr_;
79+
float recover_lr_;
7880
float base_lr_;
7981
bool is_initial_ = false;
8082
};
@@ -155,7 +157,9 @@ class SequentialLR : public LRScheduler {
155157
void LoadState(const StateDict &state) override;
156158

157159
protected:
158-
float GetClosedFormLR() const override { return current_lr_; }
160+
float GetClosedFormLR() const override {
161+
return base_lr_;
162+
} // FIXME: SequentialLR should not have a closed-form LR, but we need to implement this pure virtual function.
159163
void UndoChildInitialSteps();
160164

161165
private:
@@ -176,11 +180,13 @@ class ChainedScheduler : public LRScheduler {
176180
void LoadState(const StateDict &state) override;
177181

178182
protected:
179-
float GetClosedFormLR() const override { return current_lr_; }
183+
float GetClosedFormLR() const override {
184+
return base_lr_;
185+
} // FIXME: ChainedScheduler should not have a closed-form LR, but we need to implement this pure virtual function.
180186

181187
private:
182188
std::vector<std::shared_ptr<LRScheduler>> schedulers_;
183189
};
184190

185191
} // namespace lr_schedulers
186-
} // namespace infini_train
192+
} // namespace infini_train

infini_train/src/lr_scheduler.cc

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,10 @@ std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimi
6868
};
6969

7070
LRScheduler::LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step)
71-
: optimizer_(std::move(optimizer)), last_step_(last_step), current_lr_(0.0f), base_lr_(0.0f) {
71+
: optimizer_(std::move(optimizer)), last_step_(last_step), base_lr_(0.0f) {
7272
CHECK(optimizer_) << "LRScheduler: optimizer must not be null.";
7373
optimizer_->SetInitialLearningRate(optimizer_->GetLearningRate());
7474
base_lr_ = optimizer_->GetInitialLearningRate();
75-
current_lr_ = base_lr_;
7675
}
7776

7877
void LRScheduler::Step() {
@@ -91,34 +90,33 @@ void LRScheduler::InitialStep() {
9190
is_initial_ = false;
9291
}
9392

94-
void LRScheduler::ApplyLR(float lr) {
95-
current_lr_ = lr;
96-
optimizer_->SetLearningRate(current_lr_);
97-
}
93+
void LRScheduler::ApplyLR(float lr) { optimizer_->SetLearningRate(lr); }
9894

9995
float LRScheduler::GetChainedFormLR() const { return GetClosedFormLR(); }
10096

101-
float LRScheduler::GetLR() const { return current_lr_; }
97+
float LRScheduler::GetLR() const { return optimizer_->GetLearningRate(); }
10298

10399
float LRScheduler::BaseLR() const { return base_lr_; }
104100

105101
int64_t LRScheduler::LastStep() const { return last_step_; }
106102

103+
bool LRScheduler::SharesOptimizerWith(const std::shared_ptr<Optimizer> &opt) const { return optimizer_ == opt; }
104+
107105
void LRScheduler::ResetStep(int64_t step) { last_step_ = step; }
108106

109107
StateDict LRScheduler::State() const {
110108
return {
111109
{"last_step", last_step_},
112-
{"current_lr", current_lr_},
110+
{"recover_lr", optimizer_->GetLearningRate()},
113111
{"base_lr", base_lr_},
114112
};
115113
}
116114

117115
void LRScheduler::LoadState(const StateDict &state) {
118116
last_step_ = std::get<int64_t>(state.at("last_step"));
119-
current_lr_ = std::get<float>(state.at("current_lr"));
117+
recover_lr_ = std::get<float>(state.at("recover_lr"));
120118
base_lr_ = std::get<float>(state.at("base_lr"));
121-
optimizer_->SetLearningRate(current_lr_);
119+
optimizer_->SetLearningRate(recover_lr_);
122120
}
123121

124122
// Concrete LR Schedulers
@@ -128,7 +126,10 @@ namespace lr_schedulers {
128126
// --- ConstantLR ---
129127

130128
ConstantLR::ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor, int total_iters, int64_t last_step)
131-
: LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) {}
129+
: LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) {
130+
CHECK_GE(factor_, 0.0f) << "ConstantLR: factor must be >= 0.";
131+
CHECK_LE(factor_, 1.0f) << "ConstantLR: factor must be <= 1.";
132+
}
132133

133134
float ConstantLR::GetClosedFormLR() const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; }
134135

@@ -147,7 +148,10 @@ float ConstantLR::GetChainedFormLR() const {
147148
// --- StepLR ---
148149

149150
StepLR::StepLR(std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma, int64_t last_step)
150-
: LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) {}
151+
: LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) {
152+
CHECK_GT(step_size_, 0) << "StepLR: step_size must be > 0.";
153+
CHECK_GT(gamma_, 0.0f) << "StepLR: gamma must be > 0.";
154+
}
151155

152156
float StepLR::GetClosedFormLR() const {
153157
return base_lr_
@@ -165,7 +169,13 @@ float StepLR::GetChainedFormLR() const {
165169
LinearLR::LinearLR(std::shared_ptr<Optimizer> optimizer, float start_factor, float end_factor, int64_t total_iters,
166170
int64_t last_step)
167171
: LRScheduler(std::move(optimizer), last_step), start_factor_(start_factor), end_factor_(end_factor),
168-
total_iters_(total_iters) {}
172+
total_iters_(total_iters) {
173+
CHECK_GT(start_factor_, 0.0f) << "LinearLR: start_factor must be > 0.";
174+
CHECK_LE(start_factor_, 1.0f) << "LinearLR: start_factor must be <= 1.";
175+
CHECK_GE(end_factor_, 0.0f) << "LinearLR: end_factor must be >= 0.";
176+
CHECK_LE(end_factor_, 1.0f) << "LinearLR: end_factor must be <= 1.";
177+
CHECK_GT(total_iters_, 0) << "LinearLR: total_iters must be > 0.";
178+
}
169179

170180
float LinearLR::GetClosedFormLR() const {
171181
if (last_step_ >= total_iters_) {
@@ -198,31 +208,40 @@ float LinearLR::GetChainedFormLR() const {
198208
}
199209

200210
LambdaLR::LambdaLR(std::shared_ptr<Optimizer> optimizer, std::function<float(int64_t)> lr_lambda, int64_t last_step)
201-
: LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) {}
211+
: LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) {
212+
CHECK(lr_lambda_) << "LambdaLR: lr_lambda must not be null.";
213+
}
202214

203215
float LambdaLR::GetClosedFormLR() const { return base_lr_ * lr_lambda_(last_step_); }
204216

205217
SequentialLR::SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
206218
std::vector<int64_t> milestones, int64_t last_step)
207219
: LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)),
208-
milestones_(std::move(milestones)) {}
209-
210-
void SequentialLR::InitialStep() {
220+
milestones_(std::move(milestones)) {
211221
CHECK(!schedulers_.empty()) << "SequentialLR requires at least one scheduler.";
222+
223+
for (size_t i = 0; i < schedulers_.size(); ++i) {
224+
CHECK(schedulers_[i]) << "SequentialLR: scheduler at index " << i << " must not be null.";
225+
CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_))
226+
<< "SequentialLR: scheduler at index " << i << " must share the same optimizer.";
227+
}
228+
212229
CHECK_EQ(milestones_.size(), schedulers_.size() - 1)
213230
<< "SequentialLR: milestones count must be schedulers count - 1.";
214231

215232
for (size_t i = 1; i < milestones_.size(); ++i) {
216233
CHECK_GT(milestones_[i], milestones_[i - 1]) << "Milestones must be strictly increasing.";
217234
}
235+
}
236+
237+
void SequentialLR::InitialStep() {
218238

219239
optimizer_->SetLearningRate(schedulers_[0]->BaseLR());
220240

221241
UndoChildInitialSteps();
222242

223243
++last_step_;
224244
schedulers_[0]->InitialStep();
225-
current_lr_ = schedulers_[0]->GetLR();
226245
}
227246

228247
void SequentialLR::UndoChildInitialSteps() {
@@ -245,14 +264,12 @@ void SequentialLR::Step() {
245264
} else {
246265
scheduler->Step();
247266
}
248-
249-
current_lr_ = optimizer_->GetLearningRate();
250267
}
251268

252269
StateDict SequentialLR::State() const {
253270
StateDict state;
254271
state["last_step"] = last_step_;
255-
state["current_lr"] = current_lr_;
272+
state["recover_lr"] = optimizer_->GetLearningRate();
256273
state["base_lr"] = base_lr_;
257274
for (size_t i = 0; i < schedulers_.size(); ++i) {
258275
auto sub_state = schedulers_[i]->State();
@@ -263,7 +280,7 @@ StateDict SequentialLR::State() const {
263280

264281
void SequentialLR::LoadState(const StateDict &state) {
265282
last_step_ = std::get<int64_t>(state.at("last_step"));
266-
current_lr_ = std::get<float>(state.at("current_lr"));
283+
recover_lr_ = std::get<float>(state.at("recover_lr"));
267284
base_lr_ = std::get<float>(state.at("base_lr"));
268285

269286
for (size_t i = 0; i < schedulers_.size(); ++i) {
@@ -278,23 +295,28 @@ void SequentialLR::LoadState(const StateDict &state) {
278295
schedulers_[i]->LoadState(sub_state);
279296
}
280297
}
281-
optimizer_->SetLearningRate(current_lr_);
298+
optimizer_->SetLearningRate(recover_lr_);
282299
}
283300

284301
ChainedScheduler::ChainedScheduler(std::shared_ptr<Optimizer> optimizer,
285302
std::vector<std::shared_ptr<LRScheduler>> schedulers, int64_t last_step)
286-
: LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) {}
287-
288-
void ChainedScheduler::InitialStep() {
303+
: LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) {
289304
CHECK(!schedulers_.empty()) << "ChainedScheduler requires at least one scheduler.";
290305

291-
current_lr_ = optimizer_->GetLearningRate();
306+
for (size_t i = 0; i < schedulers_.size(); ++i) {
307+
CHECK(schedulers_[i]) << "ChainedScheduler: scheduler at index " << i << " must not be null.";
308+
CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_))
309+
<< "ChainedScheduler: scheduler at index " << i << " must share the same optimizer.";
310+
}
311+
}
312+
313+
void ChainedScheduler::InitialStep() {
314+
last_step_ = 0;
292315
}
293316

294317
void ChainedScheduler::Step() {
295318
++last_step_;
296319
for (auto &sched : schedulers_) { sched->Step(); }
297-
current_lr_ = optimizer_->GetLearningRate();
298320
}
299321

300322
StateDict ChainedScheduler::State() const {
@@ -323,4 +345,4 @@ void ChainedScheduler::LoadState(const StateDict &state) {
323345
}
324346

325347
} // namespace lr_schedulers
326-
} // namespace infini_train
348+
} // namespace infini_train

test/lr_scheduler/test_chained_lr.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using namespace infini_train::lr_schedulers;
77
namespace {
88
constexpr float kBaseLR = 0.1f;
99
}
10-
// TC1: 单子调度器退化
10+
1111
void TestSingleScheduler() {
1212
std::cout << "[TC1] TestSingleScheduler" << std::endl;
1313
auto opt = MakeDummyOptimizer(kBaseLR);
@@ -23,7 +23,7 @@ void TestSingleScheduler() {
2323
ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps);
2424
}
2525

26-
// TC2: StepLR + LambdaLR 乘法叠加
26+
// TC2: StepLR + LambdaLR
2727
void TestMultiplicativeChain() {
2828
std::cout << "[TC2] TestMultiplicativeChain" << std::endl;
2929
auto opt = MakeDummyOptimizer(kBaseLR);
@@ -53,7 +53,7 @@ void TestMultiplicativeChain() {
5353
ASSERT_FLOAT_NEAR(sched->GetLR(), 0.07f, kEps);
5454
}
5555

56-
// TC3: ConstantLR + StepLR 叠加 (无穿插声明)
56+
// TC3: ConstantLR + StepLR
5757
void TestConstantPlusStep() {
5858
std::cout << "[TC3] TestConstantPlusStep" << std::endl;
5959
auto opt = MakeDummyOptimizer(kBaseLR);
@@ -86,7 +86,7 @@ void TestConstantPlusStep() {
8686
ASSERT_FLOAT_NEAR(sched->GetLR(), 0.01f, kEps);
8787
}
8888

89-
// TC4: ConstantLR + StepLR 叠加(有穿插声明
89+
// TC4: ConstantLR + StepLR (with extra unused scheduler
9090
void TestConstantPlusStepDLC() {
9191
std::cout << "[TC4] TestConstantPlusStepDLC" << std::endl;
9292
auto opt = MakeDummyOptimizer(kBaseLR);
@@ -129,7 +129,7 @@ void TestConstantPlusStepDLC() {
129129
ASSERT_FLOAT_NEAR(sched->GetLR(), 0.02f, kEps);
130130
}
131131

132-
// TC5: State/LoadState 往返
132+
// TC5: State/LoadState
133133
void TestStateRoundTrip() {
134134
std::cout << "[TC5] TestStateRoundTrip" << std::endl;
135135
auto opt = MakeDummyOptimizer(kBaseLR);
@@ -152,7 +152,7 @@ void TestStateRoundTrip() {
152152
ASSERT_FLOAT_NEAR(sched2->GetLR(), sched->GetLR(), kEps);
153153
}
154154

155-
// TC6: resume 一致性
155+
// TC6: resume consistency (load state at step K, then step N-K, should match directly stepping to N)
156156
void TestResumeConsistency() {
157157
std::cout << "[TC6] TestResumeConsistency" << std::endl;
158158
constexpr int kN = 10, kK = 4;
@@ -199,4 +199,4 @@ int main(int argc, char *argv[]) {
199199
std::cout << g_fail_count << " test(s) FAILED" << std::endl;
200200
}
201201
return g_fail_count > 0 ? 1 : 0;
202-
}
202+
}

test/lr_scheduler/test_constant_lr.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,4 @@ int main(int argc, char *argv[]) {
181181
std::cout << g_fail_count << " test(s) FAILED" << std::endl;
182182
}
183183
return g_fail_count > 0 ? 1 : 0;
184-
}
184+
}

test/lr_scheduler/test_helpers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ void Check(bool cond, const char *expr, int line) {
3232
#define ASSERT_FLOAT_EQ(a, b) Check(FloatNear((a), (b)), #a " == " #b, __LINE__)
3333
#define ASSERT_FLOAT_NEAR(a, b, eps) Check(FloatNear((a), (b), (eps)), #a "" #b, __LINE__)
3434

35-
} // namespace
35+
} // namespace

test/lr_scheduler/test_lambda_lr.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ void TestIdentityLambda() {
1414
.type = "lambda",
1515
.lambda_fn = [](int64_t) { return 1.0f; },
1616
});
17-
// 构造器内 Step() → last_step_=0, lr = 0.1 * 1.0 = 0.1
17+
// Step() → last_step_=0, lr = 0.1 * 1.0 = 0.1
1818
ASSERT_TRUE(sched->LastStep() == 0);
1919
ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps);
2020
ASSERT_FLOAT_NEAR(opt->GetLearningRate(), kBaseLR, kEps);
@@ -124,4 +124,4 @@ int main(int argc, char *argv[]) {
124124
std::cout << g_fail_count << " test(s) FAILED" << std::endl;
125125
}
126126
return g_fail_count > 0 ? 1 : 0;
127-
}
127+
}

test/lr_scheduler/test_linear_lr.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,4 @@ int main(int argc, char *argv[]) {
137137
std::cout << g_fail_count << " test(s) FAILED" << std::endl;
138138
}
139139
return g_fail_count > 0 ? 1 : 0;
140-
}
140+
}

test/lr_scheduler/test_lr_scheduler.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void TestLinearDecay() {
103103
ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f);
104104
}
105105

106-
// T4: State → LoadState 往返一致性。
106+
// T4: State → LoadState
107107
void TestStateRoundTrip() {
108108
std::cout << "[T4] TestStateRoundTrip" << std::endl;
109109
constexpr int64_t kTotalSteps = 20;
@@ -115,7 +115,7 @@ void TestStateRoundTrip() {
115115
StateDict saved = sched->State();
116116

117117
ASSERT_TRUE(saved.count("last_step") == 1);
118-
ASSERT_TRUE(saved.count("current_lr") == 1);
118+
ASSERT_TRUE(saved.count("recover_lr") == 1);
119119
ASSERT_TRUE(saved.count("base_lr") == 1);
120120

121121
auto opt2 = MakeDummyOptimizer(kBaseLR);
@@ -175,4 +175,4 @@ int main(int argc, char *argv[]) {
175175
std::cout << "========================================" << std::endl;
176176

177177
return g_fail_count > 0 ? 1 : 0;
178-
}
178+
}

0 commit comments

Comments
 (0)