@@ -109,3 +109,95 @@ def value(self, step) -> np.float64:
109109 return self .start_lr - decay_rate * (
110110 step - self .decay_start_rate * self .stop_steps
111111 )
112+
113+
class LearningRateLinear:
    """Piecewise-constant linear learning-rate schedule.

    The LR factor moves linearly from ``start_factor`` (at step 0) to
    ``end_factor`` (at and after ``stop_steps``), but the returned value
    only changes at discrete boundaries: multiples of ``decay_steps``.
    This mirrors the spirit of ``torch.optim.lr_scheduler.LinearLR`` with
    each update interval treated as an "epoch".
    """

    def __init__(
        self,
        start_lr: float,
        stop_steps: int,
        decay_steps: int,
        start_factor: float = 1.0,
        end_factor: float = 1.0,
        **kwargs,
    ) -> None:
        """Set up the schedule.

        Parameters
        ----------
        start_lr : float
            Base learning rate; multiplied by the interpolated factor.
        stop_steps : int
            Total number of training steps covered by this scheduler.
        decay_steps : int
            Interval (in steps) between LR updates; e.g., 1k or 10k.
            If non-positive or >= ``stop_steps``, a safe default is chosen
            so the schedule still updates multiple times over the run.
        start_factor : float
            Multiplicative factor applied at step 0 (and for negative steps).
        end_factor : float
            Multiplicative factor at and after ``step >= stop_steps``.

        Notes
        -----
        With k = floor(step / decay_steps) and U = stop_steps / decay_steps,
        factor(step) = start_factor + (end_factor - start_factor)
                       * clamp(k / U, 0, 1),
        saturating at ``end_factor`` once ``step >= stop_steps``.
        """
        self.base_lr = float(start_lr)
        self.start_factor = float(start_factor)
        self.end_factor = float(end_factor)
        self.stop_steps = int(stop_steps)

        # Reject zero/negative intervals outright.
        interval = int(decay_steps)
        self.decay_steps = interval if interval > 0 else 1
        # An interval spanning the whole run would mean a single update;
        # swap in a default that yields many updates over the horizon.
        if self.decay_steps >= self.stop_steps:
            fallback = (
                100 if self.stop_steps // 10 > 100 else self.stop_steps // 100 + 1
            )
            self.decay_steps = max(1, int(fallback))

        # How many update buckets fit in the horizon (may be fractional).
        self.total_updates = self.stop_steps / self.decay_steps

    def value(self, step: int) -> np.float64:
        """Return the learning rate at ``step``.

        Negative steps behave like step 0; the value saturates at
        ``end_factor`` once ``step >= stop_steps``; in between, the factor
        changes only when ``step`` crosses a multiple of ``decay_steps``.
        """
        if step <= 0:
            return np.float64(self.base_lr * self.start_factor)
        if step >= self.stop_steps:
            return np.float64(self.base_lr * self.end_factor)

        # Count completed update intervals and turn that into a progress
        # fraction in [0, 1] (clamped against numerical drift).
        buckets_done = step // self.decay_steps
        progress = min(1.0, max(0.0, buckets_done / self.total_updates))
        span = self.end_factor - self.start_factor
        factor = self.start_factor + span * progress

        # Never overshoot end_factor because of floating-point rounding,
        # whichever direction the schedule moves in.
        if self.end_factor < self.start_factor:
            factor = max(factor, self.end_factor)
        else:
            factor = min(factor, self.end_factor)

        return np.float64(self.base_lr * factor)
0 commit comments