Skip to content

Commit dc748bd

Browse files
author
kinorw
committed
style: format code and comments for consistency across lr_scheduler files
1 parent 2980f93 commit dc748bd

File tree

11 files changed

+36
-37
lines changed

11 files changed

+36
-37
lines changed

infini_train/include/lr_scheduler.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class LRScheduler {
6767
virtual StateDict State() const;
6868
virtual void LoadState(const StateDict &state);
6969

70-
bool SharesOptimizerWith(const std::shared_ptr<Optimizer> &opt) const ;
70+
bool SharesOptimizerWith(const std::shared_ptr<Optimizer> &opt) const;
7171

7272
protected:
7373
virtual float GetClosedFormLR() const = 0;
@@ -157,7 +157,9 @@ class SequentialLR : public LRScheduler {
157157
void LoadState(const StateDict &state) override;
158158

159159
protected:
160-
float GetClosedFormLR() const override { return base_lr_; } // FIXME: SequentialLR should not have a closed-form LR, but we need to implement this pure virtual function.
160+
float GetClosedFormLR() const override {
161+
return base_lr_;
162+
} // FIXME: SequentialLR should not have a closed-form LR, but we need to implement this pure virtual function.
161163
void UndoChildInitialSteps();
162164

163165
private:
@@ -178,7 +180,9 @@ class ChainedScheduler : public LRScheduler {
178180
void LoadState(const StateDict &state) override;
179181

180182
protected:
181-
float GetClosedFormLR() const override { return base_lr_; } // FIXME: ChainedScheduler should not have a closed-form LR, but we need to implement this pure virtual function.
183+
float GetClosedFormLR() const override {
184+
return base_lr_;
185+
} // FIXME: ChainedScheduler should not have a closed-form LR, but we need to implement this pure virtual function.
182186

183187
private:
184188
std::vector<std::shared_ptr<LRScheduler>> schedulers_;

infini_train/src/lr_scheduler.cc

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,7 @@ void LRScheduler::InitialStep() {
9090
is_initial_ = false;
9191
}
9292

93-
void LRScheduler::ApplyLR(float lr) {
94-
optimizer_->SetLearningRate(lr);
95-
}
93+
void LRScheduler::ApplyLR(float lr) { optimizer_->SetLearningRate(lr); }
9694

9795
float LRScheduler::GetChainedFormLR() const { return GetClosedFormLR(); }
9896

@@ -129,9 +127,9 @@ namespace lr_schedulers {
129127

130128
ConstantLR::ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor, int total_iters, int64_t last_step)
131129
: LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) {
132-
CHECK_GE(factor_, 0.0f) << "ConstantLR: factor must be >= 0.";
133-
CHECK_LE(factor_, 1.0f) << "ConstantLR: factor must be <= 1.";
134-
}
130+
CHECK_GE(factor_, 0.0f) << "ConstantLR: factor must be >= 0.";
131+
CHECK_LE(factor_, 1.0f) << "ConstantLR: factor must be <= 1.";
132+
}
135133

136134
float ConstantLR::GetClosedFormLR() const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; }
137135

@@ -171,12 +169,12 @@ LinearLR::LinearLR(std::shared_ptr<Optimizer> optimizer, float start_factor, flo
171169
int64_t last_step)
172170
: LRScheduler(std::move(optimizer), last_step), start_factor_(start_factor), end_factor_(end_factor),
173171
total_iters_(total_iters) {
174-
CHECK_GT(start_factor_, 0.0f) << "LinearLR: start_factor must be > 0.";
175-
CHECK_LE(start_factor_, 1.0f) << "LinearLR: start_factor must be <= 1.";
176-
CHECK_GE(end_factor_, 0.0f) << "LinearLR: end_factor must be >= 0.";
177-
CHECK_LE(end_factor_, 1.0f) << "LinearLR: end_factor must be <= 1.";
178-
CHECK_GT(total_iters_, 0) << "LinearLR: total_iters must be > 0.";
179-
}
172+
CHECK_GT(start_factor_, 0.0f) << "LinearLR: start_factor must be > 0.";
173+
CHECK_LE(start_factor_, 1.0f) << "LinearLR: start_factor must be <= 1.";
174+
CHECK_GE(end_factor_, 0.0f) << "LinearLR: end_factor must be >= 0.";
175+
CHECK_LE(end_factor_, 1.0f) << "LinearLR: end_factor must be <= 1.";
176+
CHECK_GT(total_iters_, 0) << "LinearLR: total_iters must be > 0.";
177+
}
180178

181179
float LinearLR::GetClosedFormLR() const {
182180
if (last_step_ >= total_iters_) {
@@ -210,8 +208,8 @@ float LinearLR::GetChainedFormLR() const {
210208

211209
LambdaLR::LambdaLR(std::shared_ptr<Optimizer> optimizer, std::function<float(int64_t)> lr_lambda, int64_t last_step)
212210
: LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) {
213-
CHECK(lr_lambda_) << "LambdaLR: lr_lambda must not be null.";
214-
}
211+
CHECK(lr_lambda_) << "LambdaLR: lr_lambda must not be null.";
212+
}
215213

216214
float LambdaLR::GetClosedFormLR() const { return base_lr_ * lr_lambda_(last_step_); }
217215

@@ -233,7 +231,6 @@ SequentialLR::SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std
233231
for (size_t i = 1; i < milestones_.size(); ++i) {
234232
CHECK_GT(milestones_[i], milestones_[i - 1]) << "Milestones must be strictly increasing.";
235233
}
236-
237234
}
238235

239236
void SequentialLR::InitialStep() {
@@ -266,7 +263,6 @@ void SequentialLR::Step() {
266263
} else {
267264
scheduler->Step();
268265
}
269-
270266
}
271267

272268
StateDict SequentialLR::State() const {
@@ -313,8 +309,7 @@ ChainedScheduler::ChainedScheduler(std::shared_ptr<Optimizer> optimizer,
313309
}
314310
}
315311

316-
void ChainedScheduler::InitialStep() {
317-
}
312+
void ChainedScheduler::InitialStep() {}
318313

319314
void ChainedScheduler::Step() {
320315
++last_step_;

test/lr_scheduler/test_chained_lr.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using namespace infini_train::lr_schedulers;
77
namespace {
88
constexpr float kBaseLR = 0.1f;
99
}
10-
// TC1: 单子调度器退化
10+
1111
void TestSingleScheduler() {
1212
std::cout << "[TC1] TestSingleScheduler" << std::endl;
1313
auto opt = MakeDummyOptimizer(kBaseLR);
@@ -23,7 +23,7 @@ void TestSingleScheduler() {
2323
ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps);
2424
}
2525

26-
// TC2: StepLR + LambdaLR 乘法叠加
26+
// TC2: StepLR + LambdaLR
2727
void TestMultiplicativeChain() {
2828
std::cout << "[TC2] TestMultiplicativeChain" << std::endl;
2929
auto opt = MakeDummyOptimizer(kBaseLR);
@@ -53,7 +53,7 @@ void TestMultiplicativeChain() {
5353
ASSERT_FLOAT_NEAR(sched->GetLR(), 0.07f, kEps);
5454
}
5555

56-
// TC3: ConstantLR + StepLR 叠加 (无穿插声明)
56+
// TC3: ConstantLR + StepLR
5757
void TestConstantPlusStep() {
5858
std::cout << "[TC3] TestConstantPlusStep" << std::endl;
5959
auto opt = MakeDummyOptimizer(kBaseLR);
@@ -86,7 +86,7 @@ void TestConstantPlusStep() {
8686
ASSERT_FLOAT_NEAR(sched->GetLR(), 0.01f, kEps);
8787
}
8888

89-
// TC4: ConstantLR + StepLR 叠加(有穿插声明
89+
// TC4: ConstantLR + StepLR (with extra unused scheduler
9090
void TestConstantPlusStepDLC() {
9191
std::cout << "[TC4] TestConstantPlusStepDLC" << std::endl;
9292
auto opt = MakeDummyOptimizer(kBaseLR);
@@ -129,7 +129,7 @@ void TestConstantPlusStepDLC() {
129129
ASSERT_FLOAT_NEAR(sched->GetLR(), 0.02f, kEps);
130130
}
131131

132-
// TC5: State/LoadState 往返
132+
// TC5: State/LoadState
133133
void TestStateRoundTrip() {
134134
std::cout << "[TC5] TestStateRoundTrip" << std::endl;
135135
auto opt = MakeDummyOptimizer(kBaseLR);
@@ -199,4 +199,4 @@ int main(int argc, char *argv[]) {
199199
std::cout << g_fail_count << " test(s) FAILED" << std::endl;
200200
}
201201
return g_fail_count > 0 ? 1 : 0;
202-
}
202+
}

test/lr_scheduler/test_constant_lr.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,4 @@ int main(int argc, char *argv[]) {
181181
std::cout << g_fail_count << " test(s) FAILED" << std::endl;
182182
}
183183
return g_fail_count > 0 ? 1 : 0;
184-
}
184+
}

test/lr_scheduler/test_helpers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ void Check(bool cond, const char *expr, int line) {
3232
#define ASSERT_FLOAT_EQ(a, b) Check(FloatNear((a), (b)), #a " == " #b, __LINE__)
3333
#define ASSERT_FLOAT_NEAR(a, b, eps) Check(FloatNear((a), (b), (eps)), #a "" #b, __LINE__)
3434

35-
} // namespace
35+
} // namespace

test/lr_scheduler/test_lambda_lr.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ void TestIdentityLambda() {
1414
.type = "lambda",
1515
.lambda_fn = [](int64_t) { return 1.0f; },
1616
});
17-
// 构造器内 Step() → last_step_=0, lr = 0.1 * 1.0 = 0.1
17+
// Step() → last_step_=0, lr = 0.1 * 1.0 = 0.1
1818
ASSERT_TRUE(sched->LastStep() == 0);
1919
ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps);
2020
ASSERT_FLOAT_NEAR(opt->GetLearningRate(), kBaseLR, kEps);
@@ -124,4 +124,4 @@ int main(int argc, char *argv[]) {
124124
std::cout << g_fail_count << " test(s) FAILED" << std::endl;
125125
}
126126
return g_fail_count > 0 ? 1 : 0;
127-
}
127+
}

test/lr_scheduler/test_linear_lr.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,4 @@ int main(int argc, char *argv[]) {
137137
std::cout << g_fail_count << " test(s) FAILED" << std::endl;
138138
}
139139
return g_fail_count > 0 ? 1 : 0;
140-
}
140+
}

test/lr_scheduler/test_lr_scheduler.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void TestLinearDecay() {
103103
ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f);
104104
}
105105

106-
// T4: State → LoadState
106+
// T4: State → LoadState
107107
void TestStateRoundTrip() {
108108
std::cout << "[T4] TestStateRoundTrip" << std::endl;
109109
constexpr int64_t kTotalSteps = 20;

test/lr_scheduler/test_lr_scheduler_validation.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ void TestSequentialLRRejectsMismatchedOptimizer() {
8484
.step_gamma = 0.5f,
8585
});
8686

87-
auto sched = LRScheduler::Create<SequentialLR>(
88-
opt1, std::vector<std::shared_ptr<LRScheduler>>{s1, s2}, std::vector<int64_t>{1});
87+
auto sched = LRScheduler::Create<SequentialLR>(opt1, std::vector<std::shared_ptr<LRScheduler>>{s1, s2},
88+
std::vector<int64_t>{1});
8989
(void)sched;
9090
}));
9191
}

test/lr_scheduler/test_sequential_lr.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,4 +219,4 @@ int main(int argc, char *argv[]) {
219219
std::cout << g_fail_count << " test(s) FAILED" << std::endl;
220220
}
221221
return g_fail_count > 0 ? 1 : 0;
222-
}
222+
}

0 commit comments

Comments (0)