@@ -511,18 +511,10 @@ def __init__(
511511 # === Derive stable and decay phase lengths ===
512512 self .decay_phase_ratio = decay_phase_ratio
513513 self .decay_type = decay_type
514- self .decay_phase_steps = int (self .decay_phase_ratio * self .num_steps )
515- if self .decay_phase_steps <= 0 :
516- raise ValueError (
517- "decay_phase_ratio results in zero decay steps. "
518- "Increase num_steps or decay_phase_ratio."
519- )
520- if self .decay_phase_steps > self .decay_num_steps :
521- raise ValueError (
522- "decay phase steps must not exceed the post-warmup steps. "
523- f"Got decay_phase_steps={ self .decay_phase_steps } , "
524- f"post_warmup_steps={ self .decay_num_steps } ."
525- )
514+ # Clamp decay_phase_steps to valid range [1, decay_num_steps]
515+ self .decay_phase_steps = max (
516+ 1 , min (int (self .decay_phase_ratio * self .num_steps ), self .decay_num_steps )
517+ )
526518 self .stable_steps = self .decay_num_steps - self .decay_phase_steps
527519
528520 def _decay_value (self , step : int | Array ) -> Array :
@@ -556,7 +548,6 @@ def _decay_value(self, step: int | Array) -> Array:
556548 stop_lr = xp .asarray (self .stop_lr , dtype = step_dtype )
557549 stable_steps = xp .asarray (self .stable_steps , dtype = step_dtype )
558550 decay_phase_steps = xp .asarray (self .decay_phase_steps , dtype = step_dtype )
559- decay_num_steps = xp .asarray (self .decay_num_steps , dtype = step_dtype )
560551
561552 # === Step 2. Keep a constant learning rate in the stable phase ===
562553 decay_progress = (typed_step - stable_steps ) / decay_phase_steps
0 commit comments