
Commit 315ebbb

style: fix setter/getter name
1 parent 355d1ef

9 files changed: 44 additions & 46 deletions


infini_train/include/lr_scheduler.h

Lines changed: 4 additions & 6 deletions
@@ -142,9 +142,8 @@ class SequentialLR : public LRScheduler {
     void LoadState(const StateDict &state) override;

 protected:
-    float GetClosedFormLR() const override {
-        return base_lr_;
-    } // FIXME: SequentialLR should not have a closed-form LR, but we need to implement this pure virtual function.
+    // FIXME: SequentialLR should not have a closed-form LR, but we need to implement this pure virtual function.
+    float GetClosedFormLR() const override { return base_lr_; }
     void UndoChildInitialSteps();

 private:
@@ -165,9 +164,8 @@ class ChainedScheduler : public LRScheduler {
     void LoadState(const StateDict &state) override;

 protected:
-    float GetClosedFormLR() const override {
-        return base_lr_;
-    } // FIXME: ChainedScheduler should not have a closed-form LR, but we need to implement this pure virtual function.
+    // FIXME: ChainedScheduler should not have a closed-form LR, but we need to implement this pure virtual function.
+    float GetClosedFormLR() const override { return base_lr_; }

 private:
     std::vector<std::shared_ptr<LRScheduler>> schedulers_;
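
The header change only moves the FIXME comment above the one-line override; the underlying pattern, a subclass satisfying a pure-virtual hook it has no meaningful answer for by returning base_lr_, is unchanged. A minimal, self-contained illustration of that pattern (every name except GetClosedFormLR and base_lr_ is invented for this sketch):

class ToyScheduler {
public:
    virtual ~ToyScheduler() = default;
protected:
    virtual float GetClosedFormLR() const = 0; // every concrete scheduler must provide this
    float base_lr_ = 0.1f;
};

class ToyCompositeScheduler : public ToyScheduler {
protected:
    // FIXME-style placeholder: no real closed-form LR, so fall back to base_lr_.
    float GetClosedFormLR() const override { return base_lr_; }
};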

infini_train/include/nn/parallel/ddp/distributed_optimizer.h

Lines changed: 2 additions & 2 deletions
@@ -34,8 +34,8 @@ class DistributedOptimizer final : public infini_train::Optimizer {
     void StartParamSync(bool force_sync = false);
     void FinishParamSync(bool skip_next_bucket_dispatch = false);

-    virtual void SetLearningRate(float lr) override;
-    virtual float GetLearningRate() const override;
+    virtual void set_learning_rate(float lr) override;
+    virtual float learning_rate() const override;

 private:
     void BuildShardParamsAndBindGrads();

infini_train/include/optimizer.h

Lines changed: 4 additions & 4 deletions
@@ -21,13 +21,13 @@ class Optimizer {

     virtual void Step() = 0;

-    virtual void SetLearningRate(float lr);
+    virtual void set_learning_rate(float lr);

-    virtual float GetLearningRate() const;
+    virtual float learning_rate() const;

-    float GetInitialLearningRate() const;
+    float initial_learning_rate() const;

-    void SetInitialLearningRate(float lr);
+    void set_initial_learning_rate(float lr);

 protected:
     std::vector<std::shared_ptr<Tensor>> params_;
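
The renamed pair follows the widespread C++ convention (e.g., Google style) of snake_case names for cheap accessors and mutators, while other member functions keep CamelCase. A minimal usage sketch of the renamed interface, with a toy class standing in for the project's Optimizer:

#include <cassert>

class ToyOptimizer {
public:
    void set_learning_rate(float lr) { learning_rate_ = lr; } // mutator: "set_" prefix, snake_case
    float learning_rate() const { return learning_rate_; }    // accessor: bare snake_case, no "Get"

private:
    float learning_rate_ = 0.0f;
};

int main() {
    ToyOptimizer opt;
    opt.set_learning_rate(1e-3f);
    assert(opt.learning_rate() == 1e-3f); // exact: the same float is stored and read back
    return 0;
}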

infini_train/src/lr_scheduler.cc

Lines changed: 13 additions & 13 deletions
@@ -18,7 +18,7 @@ std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimi
     }

     CHECK(optimizer) << "CreateLRScheduler: optimizer must not be null.";
-    const float max_lr = config.lr != 0.0f ? config.lr : optimizer->GetLearningRate();
+    const float max_lr = config.lr != 0.0f ? config.lr : optimizer->learning_rate();
     CHECK_GT(max_lr, 0.0f) << "CreateLRScheduler: max_lr must be > 0.";
     CHECK_GE(config.lr_warmup_init, 0.0f) << "CreateLRScheduler: lr_warmup_init must be >= 0.";
     CHECK_GE(config.min_lr, 0.0f) << "CreateLRScheduler: min_lr must be >= 0.";
@@ -86,8 +86,8 @@ std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimi
 LRScheduler::LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step)
     : optimizer_(std::move(optimizer)), last_step_(last_step), base_lr_(0.0f) {
     CHECK(optimizer_) << "LRScheduler: optimizer must not be null.";
-    optimizer_->SetInitialLearningRate(optimizer_->GetLearningRate());
-    base_lr_ = optimizer_->GetInitialLearningRate();
+    optimizer_->set_initial_learning_rate(optimizer_->learning_rate());
+    base_lr_ = optimizer_->initial_learning_rate();
 }

 void LRScheduler::Step() {
@@ -106,11 +106,11 @@ void LRScheduler::InitialStep() {
     is_initial_ = false;
 }

-void LRScheduler::ApplyLR(float lr) { optimizer_->SetLearningRate(lr); }
+void LRScheduler::ApplyLR(float lr) { optimizer_->set_learning_rate(lr); }

 float LRScheduler::GetChainedFormLR() const { return GetClosedFormLR(); }

-float LRScheduler::GetLR() const { return optimizer_->GetLearningRate(); }
+float LRScheduler::GetLR() const { return optimizer_->learning_rate(); }

 float LRScheduler::BaseLR() const { return base_lr_; }

@@ -123,7 +123,7 @@ void LRScheduler::ResetStep(int64_t step) { last_step_ = step; }
 StateDict LRScheduler::State() const {
     return {
         {"last_step", last_step_},
-        {"recover_lr", optimizer_->GetLearningRate()},
+        {"recover_lr", optimizer_->learning_rate()},
         {"base_lr", base_lr_},
     };
 }
@@ -132,7 +132,7 @@ void LRScheduler::LoadState(const StateDict &state) {
     last_step_ = std::get<int64_t>(state.at("last_step"));
     recover_lr_ = std::get<float>(state.at("recover_lr"));
     base_lr_ = std::get<float>(state.at("base_lr"));
-    optimizer_->SetLearningRate(recover_lr_);
+    optimizer_->set_learning_rate(recover_lr_);
 }

 // Concrete LR Schedulers
@@ -150,7 +150,7 @@ ConstantLR::ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor, int t
 float ConstantLR::GetClosedFormLR() const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; }

 float ConstantLR::GetChainedFormLR() const {
-    const float lr = optimizer_->GetLearningRate();
+    const float lr = optimizer_->learning_rate();
     if (last_step_ == 0) {
         return lr * factor_;
     } else if (last_step_ < total_iters_) {
@@ -175,7 +175,7 @@ float StepLR::GetClosedFormLR() const {
 }

 float StepLR::GetChainedFormLR() const {
-    const float lr = optimizer_->GetLearningRate();
+    const float lr = optimizer_->learning_rate();
     if (last_step_ == 0 || (last_step_ % step_size_) != 0) {
         return lr;
     }
@@ -203,7 +203,7 @@ float LinearLR::GetClosedFormLR() const {
 }

 float LinearLR::GetChainedFormLR() const {
-    const float lr = optimizer_->GetLearningRate();
+    const float lr = optimizer_->learning_rate();
     if (last_step_ == 0) {
         return lr * start_factor_;
     }
@@ -252,7 +252,7 @@ SequentialLR::SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std

 void SequentialLR::InitialStep() {

-    optimizer_->SetLearningRate(schedulers_[0]->BaseLR());
+    optimizer_->set_learning_rate(schedulers_[0]->BaseLR());

     UndoChildInitialSteps();

@@ -285,7 +285,7 @@ void SequentialLR::Step() {
 StateDict SequentialLR::State() const {
     StateDict state;
     state["last_step"] = last_step_;
-    state["recover_lr"] = optimizer_->GetLearningRate();
+    state["recover_lr"] = optimizer_->learning_rate();
     state["base_lr"] = base_lr_;
     for (size_t i = 0; i < schedulers_.size(); ++i) {
         auto sub_state = schedulers_[i]->State();
@@ -311,7 +311,7 @@ void SequentialLR::LoadState(const StateDict &state) {
             schedulers_[i]->LoadState(sub_state);
         }
     }
-    optimizer_->SetLearningRate(recover_lr_);
+    optimizer_->set_learning_rate(recover_lr_);
 }

 ChainedScheduler::ChainedScheduler(std::shared_ptr<Optimizer> optimizer,
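
For context on the State()/LoadState() hunks: the scheduler records the optimizer's currently applied LR under "recover_lr" and pushes it back via set_learning_rate() when the state is loaded. A small, self-contained sketch of that round trip, assuming StateDict is roughly a string-keyed map of variants (inferred from the std::get calls in the diff; the real alias may differ):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <variant>

// Assumed shape of StateDict for this sketch.
using StateDict = std::map<std::string, std::variant<int64_t, float>>;

int main() {
    // "Save": capture the step count and the LR currently applied to the optimizer.
    const float applied_lr = 0.05f;
    StateDict state{{"last_step", int64_t{7}}, {"recover_lr", applied_lr}, {"base_lr", 0.1f}};

    // "Load": restore the step and re-apply recover_lr to the optimizer.
    const auto last_step = std::get<int64_t>(state.at("last_step"));
    const auto recover_lr = std::get<float>(state.at("recover_lr"));
    std::cout << "resume at step " << last_step << " with lr " << recover_lr << "\n";
    return 0;
}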

infini_train/src/nn/parallel/ddp/distributed_optimizer.cc

Lines changed: 6 additions & 6 deletions
@@ -114,18 +114,18 @@ void DistributedOptimizer::ZeroGrad(bool set_to_none) {
     }
 }

-void DistributedOptimizer::SetLearningRate(float lr) {
-    Optimizer::SetLearningRate(lr);
+void DistributedOptimizer::set_learning_rate(float lr) {
+    Optimizer::set_learning_rate(lr);
     if (base_optimizer_) {
-        base_optimizer_->SetLearningRate(lr);
+        base_optimizer_->set_learning_rate(lr);
     }
 }

-float DistributedOptimizer::GetLearningRate() const {
+float DistributedOptimizer::learning_rate() const {
     if (base_optimizer_) {
-        return base_optimizer_->GetLearningRate();
+        return base_optimizer_->learning_rate();
     }
-    return Optimizer::GetLearningRate();
+    return Optimizer::learning_rate();
 }

 void DistributedOptimizer::Step() {
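
Behaviour is preserved under the new names: the distributed optimizer keeps its own copy of the LR in sync and forwards both reads and writes to the wrapped base optimizer when one is present. A rough, self-contained sketch of that wrapper-and-delegate pattern (the Toy* names are illustrative, not the project's classes):

#include <iostream>
#include <memory>

class ToyOptimizer {
public:
    virtual ~ToyOptimizer() = default;
    virtual void set_learning_rate(float lr) { lr_ = lr; }
    virtual float learning_rate() const { return lr_; }

protected:
    float lr_ = 0.0f;
};

class ToyDistributedOptimizer final : public ToyOptimizer {
public:
    explicit ToyDistributedOptimizer(std::shared_ptr<ToyOptimizer> base) : base_(std::move(base)) {}

    void set_learning_rate(float lr) override {
        ToyOptimizer::set_learning_rate(lr);         // keep the wrapper's own copy in sync
        if (base_) { base_->set_learning_rate(lr); } // ...and forward to the wrapped optimizer
    }
    float learning_rate() const override {
        return base_ ? base_->learning_rate() : ToyOptimizer::learning_rate();
    }

private:
    std::shared_ptr<ToyOptimizer> base_;
};

int main() {
    auto base = std::make_shared<ToyOptimizer>();
    ToyDistributedOptimizer ddp(base);
    ddp.set_learning_rate(0.01f);
    std::cout << base->learning_rate() << "\n"; // 0.01, applied via the wrapper
    return 0;
}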

infini_train/src/optimizer.cc

Lines changed: 4 additions & 4 deletions
@@ -15,17 +15,17 @@ void Optimizer::ZeroGrad(bool set_to_none) {
     for (auto param : params_) { param->ZeroGrad(set_to_none); }
 }

-void Optimizer::SetLearningRate(float lr) { learning_rate_ = lr; }
+void Optimizer::set_learning_rate(float lr) { learning_rate_ = lr; }

-float Optimizer::GetLearningRate() const { return learning_rate_; }
+float Optimizer::learning_rate() const { return learning_rate_; }

-float Optimizer::GetInitialLearningRate() const {
+float Optimizer::initial_learning_rate() const {
     CHECK(initial_lr_set_) << "Optimizer: initial_learning_rate not set. "
                               "Use with an LRScheduler first.";
     return initial_learning_rate_;
 }

-void Optimizer::SetInitialLearningRate(float lr) {
+void Optimizer::set_initial_learning_rate(float lr) {
     if (!initial_lr_set_) {
         initial_learning_rate_ = lr;
         initial_lr_set_ = true;
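
One behavioural detail worth noting in this hunk: set_initial_learning_rate() is effectively write-once (guarded by initial_lr_set_), and initial_learning_rate() asserts that it was set before use. A hedged sketch of that latch, with a plain assert standing in for the CHECK macro:

#include <cassert>

class ToyOptimizer {
public:
    void set_initial_learning_rate(float lr) {
        if (!initial_lr_set_) { // only the first call takes effect
            initial_learning_rate_ = lr;
            initial_lr_set_ = true;
        }
    }
    float initial_learning_rate() const {
        assert(initial_lr_set_ && "use with an LRScheduler first");
        return initial_learning_rate_;
    }

private:
    float initial_learning_rate_ = 0.0f;
    bool initial_lr_set_ = false;
};

int main() {
    ToyOptimizer opt;
    opt.set_initial_learning_rate(0.1f);
    opt.set_initial_learning_rate(0.5f); // ignored: the latch is already set
    assert(opt.initial_learning_rate() == 0.1f);
    return 0;
}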

test/lr_scheduler/test_constant_lr.cc

Lines changed: 4 additions & 4 deletions
@@ -13,15 +13,15 @@ void TestInitialState() {
     auto sched = LRScheduler::Create<ConstantLR>(opt, /*factor=*/0.5f, /*total_iters=*/3);
     ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f);
     ASSERT_TRUE(sched->LastStep() == 0);
-    ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f);
+    ASSERT_FLOAT_EQ(opt->learning_rate(), 0.05f);
 }

 void TestFirstStepAppliesFactor() {
     auto opt = MakeDummyOptimizer(kBaseLR);
     auto sched = LRScheduler::Create<ConstantLR>(opt, /*factor=*/0.5f, /*total_iters=*/3);
     sched->Step(); // last_step_ = 0
     ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f);
-    ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f);
+    ASSERT_FLOAT_EQ(opt->learning_rate(), 0.05f);
     ASSERT_TRUE(sched->LastStep() == 1);
 }

@@ -38,7 +38,7 @@ void TestBeyondTotalIters() {
     auto sched = LRScheduler::Create<ConstantLR>(opt, /*factor=*/0.5f, /*total_iters=*/3);
     for (int i = 0; i < 10; ++i) { sched->Step(); }
     ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR);
-    ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR);
+    ASSERT_FLOAT_EQ(opt->learning_rate(), kBaseLR);
 }

 void TestPyTorchAlignment() {
@@ -63,7 +63,7 @@ void TestStateRoundTrip() {

     ASSERT_TRUE(sched2->LastStep() == sched->LastStep());
     ASSERT_FLOAT_EQ(sched2->GetLR(), sched->GetLR());
-    ASSERT_FLOAT_EQ(opt2->GetLearningRate(), sched->GetLR());
+    ASSERT_FLOAT_EQ(opt2->learning_rate(), sched->GetLR());
 }

 void TestResumeConsistency() {

test/lr_scheduler/test_lambda_lr.cc

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ void TestIdentityLambda() {
     // Step() → last_step_=0, lr = 0.1 * 1.0 = 0.1
     ASSERT_TRUE(sched->LastStep() == 0);
     ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps);
-    ASSERT_FLOAT_NEAR(opt->GetLearningRate(), kBaseLR, kEps);
+    ASSERT_FLOAT_NEAR(opt->learning_rate(), kBaseLR, kEps);
 }

 void TestLinearDecayLambda() {

test/lr_scheduler/test_lr_scheduler.cc

Lines changed: 6 additions & 6 deletions
@@ -69,7 +69,7 @@ void TestInitialState() {

     ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR);
     ASSERT_TRUE(sched->LastStep() == 0);
-    ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR);
+    ASSERT_FLOAT_EQ(opt->learning_rate(), kBaseLR);
 }

 // T2: SingleStep
@@ -82,7 +82,7 @@ void TestSingleStep() {

     ASSERT_TRUE(sched->LastStep() == 1);
     ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR);
-    ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR);
+    ASSERT_FLOAT_EQ(opt->learning_rate(), kBaseLR);
 }

 // T3: ComputeLR
@@ -92,15 +92,15 @@ void TestLinearDecay() {
     auto opt = MakeDummyOptimizer(kBaseLR);
     auto sched = LRScheduler::Create<LinearDecayScheduler>(opt, /*total_steps=*/kTotalSteps);
     ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR);
-    ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR);
+    ASSERT_FLOAT_EQ(opt->learning_rate(), kBaseLR);

     sched->Step(); // last_step = 1 -> 0.09
     ASSERT_FLOAT_EQ(sched->GetLR(), 0.09f);
-    ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.09f);
+    ASSERT_FLOAT_EQ(opt->learning_rate(), 0.09f);

     for (int i = 0; i < 4; ++i) { sched->Step(); } // last_step = 5
     ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f);
-    ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f);
+    ASSERT_FLOAT_EQ(opt->learning_rate(), 0.05f);
 }

 // T4: State → LoadState
@@ -124,7 +124,7 @@ void TestStateRoundTrip() {

     ASSERT_TRUE(sched2->LastStep() == 7);
     ASSERT_FLOAT_EQ(sched2->GetLR(), sched->GetLR());
-    ASSERT_FLOAT_EQ(opt2->GetLearningRate(), sched->GetLR());
+    ASSERT_FLOAT_EQ(opt2->learning_rate(), sched->GetLR());
 }

 // T5: resume Step
