Skip to content

Commit 2980f93

Browse files
author
kinorw
committed
feat: add validation tests for learning rate schedulers
1 parent 151dda0 commit 2980f93

4 files changed

Lines changed: 207 additions & 11 deletions

File tree

CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,4 +220,7 @@ add_executable(test_sequential_lr test/lr_scheduler/test_sequential_lr.cc)
220220
target_link_libraries(test_sequential_lr infini_train)
221221

222222
add_executable(test_chained_lr test/lr_scheduler/test_chained_lr.cc)
223-
target_link_libraries(test_chained_lr infini_train)
223+
target_link_libraries(test_chained_lr infini_train)
224+
225+
add_executable(test_lr_scheduler_validation test/lr_scheduler/test_lr_scheduler_validation.cc)
226+
target_link_libraries(test_lr_scheduler_validation infini_train)

infini_train/include/lr_scheduler.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ class LRScheduler {
6767
virtual StateDict State() const;
6868
virtual void LoadState(const StateDict &state);
6969

70+
bool SharesOptimizerWith(const std::shared_ptr<Optimizer> &opt) const ;
71+
7072
protected:
7173
virtual float GetClosedFormLR() const = 0;
7274
virtual float GetChainedFormLR() const;

infini_train/src/lr_scheduler.cc

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ float LRScheduler::BaseLR() const { return base_lr_; }
102102

103103
int64_t LRScheduler::LastStep() const { return last_step_; }
104104

105+
bool LRScheduler::SharesOptimizerWith(const std::shared_ptr<Optimizer> &opt) const { return optimizer_ == opt; }
106+
105107
void LRScheduler::ResetStep(int64_t step) { last_step_ = step; }
106108

107109
StateDict LRScheduler::State() const {
@@ -126,7 +128,10 @@ namespace lr_schedulers {
126128
// --- ConstantLR ---
127129

128130
ConstantLR::ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor, int total_iters, int64_t last_step)
129-
: LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) {}
131+
: LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) {
132+
CHECK_GE(factor_, 0.0f) << "ConstantLR: factor must be >= 0.";
133+
CHECK_LE(factor_, 1.0f) << "ConstantLR: factor must be <= 1.";
134+
}
130135

131136
float ConstantLR::GetClosedFormLR() const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; }
132137

@@ -145,7 +150,9 @@ float ConstantLR::GetChainedFormLR() const {
145150
// --- StepLR ---
146151

147152
StepLR::StepLR(std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma, int64_t last_step)
148-
: LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) {}
153+
: LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) {
154+
CHECK_GT(step_size_, 0) << "StepLR: step_size must be > 0.";
155+
}
149156

150157
float StepLR::GetClosedFormLR() const {
151158
return base_lr_
@@ -163,7 +170,13 @@ float StepLR::GetChainedFormLR() const {
163170
LinearLR::LinearLR(std::shared_ptr<Optimizer> optimizer, float start_factor, float end_factor, int64_t total_iters,
164171
int64_t last_step)
165172
: LRScheduler(std::move(optimizer), last_step), start_factor_(start_factor), end_factor_(end_factor),
166-
total_iters_(total_iters) {}
173+
total_iters_(total_iters) {
174+
CHECK_GT(start_factor_, 0.0f) << "LinearLR: start_factor must be > 0.";
175+
CHECK_LE(start_factor_, 1.0f) << "LinearLR: start_factor must be <= 1.";
176+
CHECK_GE(end_factor_, 0.0f) << "LinearLR: end_factor must be >= 0.";
177+
CHECK_LE(end_factor_, 1.0f) << "LinearLR: end_factor must be <= 1.";
178+
CHECK_GT(total_iters_, 0) << "LinearLR: total_iters must be > 0.";
179+
}
167180

168181
float LinearLR::GetClosedFormLR() const {
169182
if (last_step_ >= total_iters_) {
@@ -196,24 +209,35 @@ float LinearLR::GetChainedFormLR() const {
196209
}
197210

198211
LambdaLR::LambdaLR(std::shared_ptr<Optimizer> optimizer, std::function<float(int64_t)> lr_lambda, int64_t last_step)
199-
: LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) {}
212+
: LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) {
213+
CHECK(lr_lambda_) << "LambdaLR: lr_lambda must not be null.";
214+
}
200215

201216
float LambdaLR::GetClosedFormLR() const { return base_lr_ * lr_lambda_(last_step_); }
202217

203218
SequentialLR::SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
204219
std::vector<int64_t> milestones, int64_t last_step)
205220
: LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)),
206-
milestones_(std::move(milestones)) {}
207-
208-
void SequentialLR::InitialStep() {
221+
milestones_(std::move(milestones)) {
209222
CHECK(!schedulers_.empty()) << "SequentialLR requires at least one scheduler.";
223+
224+
for (size_t i = 0; i < schedulers_.size(); ++i) {
225+
CHECK(schedulers_[i]) << "SequentialLR: scheduler at index " << i << " must not be null.";
226+
CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_))
227+
<< "SequentialLR: scheduler at index " << i << " must share the same optimizer.";
228+
}
229+
210230
CHECK_EQ(milestones_.size(), schedulers_.size() - 1)
211231
<< "SequentialLR: milestones count must be schedulers count - 1.";
212232

213233
for (size_t i = 1; i < milestones_.size(); ++i) {
214234
CHECK_GT(milestones_[i], milestones_[i - 1]) << "Milestones must be strictly increasing.";
215235
}
216236

237+
}
238+
239+
void SequentialLR::InitialStep() {
240+
217241
optimizer_->SetLearningRate(schedulers_[0]->BaseLR());
218242

219243
UndoChildInitialSteps();
@@ -279,11 +303,17 @@ void SequentialLR::LoadState(const StateDict &state) {
279303

280304
ChainedScheduler::ChainedScheduler(std::shared_ptr<Optimizer> optimizer,
281305
std::vector<std::shared_ptr<LRScheduler>> schedulers, int64_t last_step)
282-
: LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) {}
283-
284-
void ChainedScheduler::InitialStep() {
306+
: LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) {
285307
CHECK(!schedulers_.empty()) << "ChainedScheduler requires at least one scheduler.";
286308

309+
for (size_t i = 0; i < schedulers_.size(); ++i) {
310+
CHECK(schedulers_[i]) << "ChainedScheduler: scheduler at index " << i << " must not be null.";
311+
CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_))
312+
<< "ChainedScheduler: scheduler at index " << i << " must share the same optimizer.";
313+
}
314+
}
315+
316+
void ChainedScheduler::InitialStep() {
287317
}
288318

289319
void ChainedScheduler::Step() {
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#include <functional>
2+
#include <iostream>
3+
#include <memory>
4+
#include <sys/wait.h>
5+
#include <unistd.h>
6+
#include <vector>
7+
8+
#include "infini_train/include/lr_scheduler.h"
9+
#include "test/lr_scheduler/test_helpers.h"
10+
11+
using namespace infini_train;
12+
using namespace infini_train::lr_schedulers;
13+
14+
namespace {
15+
16+
// Runs `fn` in a forked child and reports whether the child died abnormally.
// Returns true when the child was killed by a signal or exited with a
// non-zero status (a failed glog CHECK aborts, so both paths count), and
// false when the child ran to a clean _exit(0). fork()/waitpid() failures
// also yield false, so infrastructure errors surface as test failures.
bool ExpectDeath(const std::function<void()> &fn) {
    const pid_t child = fork();
    if (child < 0) {
        // Could not spawn the child — no death was observed.
        return false;
    }

    if (child == 0) {
        // Child process: run the candidate, then leave immediately without
        // running atexit handlers or flushing parent-owned buffers.
        fn();
        _exit(0);
    }

    // Parent: reap the child and inspect how it terminated.
    int status = 0;
    if (waitpid(child, &status, 0) < 0) {
        return false;
    }
    const bool clean_exit = WIFEXITED(status) && WEXITSTATUS(status) == 0;
    return !clean_exit;
}
34+
35+
void TestStepLRRejectsNonPositiveStepSize() {
36+
ASSERT_TRUE(ExpectDeath([] {
37+
auto opt = MakeDummyOptimizer(0.1f);
38+
auto sched = CreateLRScheduler(opt, {
39+
.type = "step",
40+
.step_size = 0,
41+
.step_gamma = 0.1f,
42+
});
43+
(void)sched;
44+
}));
45+
}
46+
47+
void TestLinearLRRejectsNonPositiveTotalIters() {
48+
ASSERT_TRUE(ExpectDeath([] {
49+
auto opt = MakeDummyOptimizer(0.1f);
50+
auto sched = CreateLRScheduler(opt, {
51+
.type = "linear",
52+
.linear_start_factor = 0.5f,
53+
.linear_end_factor = 1.0f,
54+
.linear_total_iters = 0,
55+
});
56+
(void)sched;
57+
}));
58+
}
59+
60+
void TestLambdaLRRejectsNullLambda() {
61+
ASSERT_TRUE(ExpectDeath([] {
62+
auto opt = MakeDummyOptimizer(0.1f);
63+
auto sched = CreateLRScheduler(opt, {
64+
.type = "lambda",
65+
});
66+
(void)sched;
67+
}));
68+
}
69+
70+
void TestSequentialLRRejectsMismatchedOptimizer() {
71+
ASSERT_TRUE(ExpectDeath([] {
72+
auto opt1 = MakeDummyOptimizer(0.1f);
73+
auto opt2 = MakeDummyOptimizer(0.1f);
74+
75+
auto s1 = CreateLRScheduler(opt1, {
76+
.type = "linear",
77+
.linear_start_factor = 0.5f,
78+
.linear_end_factor = 1.0f,
79+
.linear_total_iters = 2,
80+
});
81+
auto s2 = CreateLRScheduler(opt2, {
82+
.type = "step",
83+
.step_size = 2,
84+
.step_gamma = 0.5f,
85+
});
86+
87+
auto sched = LRScheduler::Create<SequentialLR>(
88+
opt1, std::vector<std::shared_ptr<LRScheduler>>{s1, s2}, std::vector<int64_t>{1});
89+
(void)sched;
90+
}));
91+
}
92+
93+
void TestSequentialLRRejectsNullChild() {
94+
ASSERT_TRUE(ExpectDeath([] {
95+
auto opt = MakeDummyOptimizer(0.1f);
96+
auto sched = LRScheduler::Create<SequentialLR>(opt, std::vector<std::shared_ptr<LRScheduler>>{nullptr},
97+
std::vector<int64_t>{});
98+
(void)sched;
99+
}));
100+
}
101+
102+
void TestChainedSchedulerRejectsEmptyChildren() {
103+
ASSERT_TRUE(ExpectDeath([] {
104+
auto opt = MakeDummyOptimizer(0.1f);
105+
auto sched = LRScheduler::Create<ChainedScheduler>(opt, std::vector<std::shared_ptr<LRScheduler>>{});
106+
(void)sched;
107+
}));
108+
}
109+
110+
void TestChainedSchedulerRejectsMismatchedOptimizer() {
111+
ASSERT_TRUE(ExpectDeath([] {
112+
auto opt1 = MakeDummyOptimizer(0.1f);
113+
auto opt2 = MakeDummyOptimizer(0.1f);
114+
115+
auto s1 = CreateLRScheduler(opt1, {
116+
.type = "step",
117+
.step_size = 2,
118+
.step_gamma = 0.5f,
119+
});
120+
auto s2 = CreateLRScheduler(opt2, {
121+
.type = "constant",
122+
.constant_factor = 0.5f,
123+
.constant_total_iters = 2,
124+
});
125+
126+
auto sched = LRScheduler::Create<ChainedScheduler>(opt1, std::vector<std::shared_ptr<LRScheduler>>{s1, s2});
127+
(void)sched;
128+
}));
129+
}
130+
131+
void TestChainedSchedulerRejectsNullChild() {
132+
ASSERT_TRUE(ExpectDeath([] {
133+
auto opt = MakeDummyOptimizer(0.1f);
134+
auto sched = LRScheduler::Create<ChainedScheduler>(opt, std::vector<std::shared_ptr<LRScheduler>>{nullptr});
135+
(void)sched;
136+
}));
137+
}
138+
139+
} // namespace
140+
141+
int main(int argc, char *argv[]) {
142+
google::InitGoogleLogging(argv[0]);
143+
144+
std::cout << "=== LR Scheduler Validation Tests ===" << std::endl;
145+
TestStepLRRejectsNonPositiveStepSize();
146+
TestLinearLRRejectsNonPositiveTotalIters();
147+
TestLambdaLRRejectsNullLambda();
148+
TestSequentialLRRejectsMismatchedOptimizer();
149+
TestSequentialLRRejectsNullChild();
150+
TestChainedSchedulerRejectsEmptyChildren();
151+
TestChainedSchedulerRejectsMismatchedOptimizer();
152+
TestChainedSchedulerRejectsNullChild();
153+
154+
if (g_fail_count == 0) {
155+
std::cout << "All Tests PASSED" << std::endl;
156+
} else {
157+
std::cout << g_fail_count << " test(s) FAILED" << std::endl;
158+
}
159+
160+
return g_fail_count > 0 ? 1 : 0;
161+
}

0 commit comments

Comments
 (0)