@@ -102,6 +102,8 @@ float LRScheduler::BaseLR() const { return base_lr_; }
102102
103103int64_t LRScheduler::LastStep () const { return last_step_; }
104104
105+ bool LRScheduler::SharesOptimizerWith (const std::shared_ptr<Optimizer> &opt) const { return optimizer_ == opt; }
106+
105107void LRScheduler::ResetStep (int64_t step) { last_step_ = step; }
106108
107109StateDict LRScheduler::State () const {
@@ -126,7 +128,10 @@ namespace lr_schedulers {
126128// --- ConstantLR ---
127129
128130ConstantLR::ConstantLR (std::shared_ptr<Optimizer> optimizer, float factor, int total_iters, int64_t last_step)
129- : LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) {}
131+ : LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) {
132+ CHECK_GE (factor_, 0 .0f ) << " ConstantLR: factor must be >= 0." ;
133+ CHECK_LE (factor_, 1 .0f ) << " ConstantLR: factor must be <= 1." ;
134+ }
130135
131136float ConstantLR::GetClosedFormLR () const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; }
132137
@@ -145,7 +150,9 @@ float ConstantLR::GetChainedFormLR() const {
145150// --- StepLR ---
146151
147152StepLR::StepLR (std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma, int64_t last_step)
148- : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) {}
153+ : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) {
154+ CHECK_GT (step_size_, 0 ) << " StepLR: step_size must be > 0." ;
155+ }
149156
150157float StepLR::GetClosedFormLR () const {
151158 return base_lr_
@@ -163,7 +170,13 @@ float StepLR::GetChainedFormLR() const {
163170LinearLR::LinearLR (std::shared_ptr<Optimizer> optimizer, float start_factor, float end_factor, int64_t total_iters,
164171 int64_t last_step)
165172 : LRScheduler(std::move(optimizer), last_step), start_factor_(start_factor), end_factor_(end_factor),
166- total_iters_ (total_iters) {}
173+ total_iters_ (total_iters) {
174+ CHECK_GT (start_factor_, 0 .0f ) << " LinearLR: start_factor must be > 0." ;
175+ CHECK_LE (start_factor_, 1 .0f ) << " LinearLR: start_factor must be <= 1." ;
176+ CHECK_GE (end_factor_, 0 .0f ) << " LinearLR: end_factor must be >= 0." ;
177+ CHECK_LE (end_factor_, 1 .0f ) << " LinearLR: end_factor must be <= 1." ;
178+ CHECK_GT (total_iters_, 0 ) << " LinearLR: total_iters must be > 0." ;
179+ }
167180
168181float LinearLR::GetClosedFormLR () const {
169182 if (last_step_ >= total_iters_) {
@@ -196,24 +209,35 @@ float LinearLR::GetChainedFormLR() const {
196209}
197210
198211LambdaLR::LambdaLR (std::shared_ptr<Optimizer> optimizer, std::function<float (int64_t )> lr_lambda, int64_t last_step)
199- : LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) {}
212+ : LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) {
213+ CHECK (lr_lambda_) << " LambdaLR: lr_lambda must not be null." ;
214+ }
200215
201216float LambdaLR::GetClosedFormLR () const { return base_lr_ * lr_lambda_ (last_step_); }
202217
203218SequentialLR::SequentialLR (std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
204219 std::vector<int64_t > milestones, int64_t last_step)
205220 : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)),
206- milestones_(std::move(milestones)) {}
207-
208- void SequentialLR::InitialStep () {
221+ milestones_(std::move(milestones)) {
209222 CHECK (!schedulers_.empty ()) << " SequentialLR requires at least one scheduler." ;
223+
224+ for (size_t i = 0 ; i < schedulers_.size (); ++i) {
225+ CHECK (schedulers_[i]) << " SequentialLR: scheduler at index " << i << " must not be null." ;
226+ CHECK (schedulers_[i]->SharesOptimizerWith (optimizer_))
227+ << " SequentialLR: scheduler at index " << i << " must share the same optimizer." ;
228+ }
229+
210230 CHECK_EQ (milestones_.size (), schedulers_.size () - 1 )
211231 << " SequentialLR: milestones count must be schedulers count - 1." ;
212232
213233 for (size_t i = 1 ; i < milestones_.size (); ++i) {
214234 CHECK_GT (milestones_[i], milestones_[i - 1 ]) << " Milestones must be strictly increasing." ;
215235 }
216236
237+ }
238+
239+ void SequentialLR::InitialStep () {
240+
217241 optimizer_->SetLearningRate (schedulers_[0 ]->BaseLR ());
218242
219243 UndoChildInitialSteps ();
@@ -279,11 +303,17 @@ void SequentialLR::LoadState(const StateDict &state) {
279303
280304ChainedScheduler::ChainedScheduler (std::shared_ptr<Optimizer> optimizer,
281305 std::vector<std::shared_ptr<LRScheduler>> schedulers, int64_t last_step)
282- : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) {}
283-
284- void ChainedScheduler::InitialStep () {
306+ : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) {
285307 CHECK (!schedulers_.empty ()) << " ChainedScheduler requires at least one scheduler." ;
286308
309+ for (size_t i = 0 ; i < schedulers_.size (); ++i) {
310+ CHECK (schedulers_[i]) << " ChainedScheduler: scheduler at index " << i << " must not be null." ;
311+ CHECK (schedulers_[i]->SharesOptimizerWith (optimizer_))
312+ << " ChainedScheduler: scheduler at index " << i << " must share the same optimizer." ;
313+ }
314+ }
315+
316+ void ChainedScheduler::InitialStep () {
287317}
288318
289319void ChainedScheduler::Step () {
0 commit comments