Skip to content

Commit 06c113d

Browse files
committed
add linear lr
1 parent c1c4234 commit 06c113d

4 files changed

Lines changed: 109 additions & 0 deletions

File tree

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,95 @@ def value(self, step) -> np.float64:
109109
return self.start_lr - decay_rate * (
110110
step - self.decay_start_rate * self.stop_steps
111111
)
112+
113+
114+
class LearningRateLinear:
    """Linear learning-rate schedule whose value changes only at discrete
    update boundaries (every ``decay_steps`` steps)."""

    def __init__(
        self,
        start_lr: float,
        stop_steps: int,
        decay_steps: int,
        start_factor: float = 1.0,
        end_factor: float = 1.0,
        **kwargs,
    ) -> None:
        """Piecewise-constant linear LR schedule updated every `decay_steps`.

        The LR factor interpolates linearly from `start_factor` (at step 0)
        to `end_factor` (once step >= stop_steps), but only moves at update
        boundaries, i.e. at multiples of `decay_steps`.

        Parameters
        ----------
        start_lr : float
            Base learning rate; the schedule multiplies it by a factor.
        stop_steps : int
            Total number of training steps covered by this scheduler.
        decay_steps : int
            Interval (in steps) between LR updates, e.g. 1k or 10k.
        start_factor : float
            Multiplicative factor applied at step 0.
        end_factor : float
            Multiplicative factor applied at and after step >= stop_steps.

        Notes
        -----
        With k = floor(step / decay_steps) and U = stop_steps / decay_steps
        (possibly non-integer), the factor is
        ``start_factor + (end_factor - start_factor) * clamp(k / U, 0, 1)``.
        If `decay_steps` is not smaller than `stop_steps`, it is replaced by
        a default so that several updates still occur. This mirrors
        torch.optim.lr_scheduler.LinearLR, treating each update interval as
        an "epoch".
        """
        self.base_lr = float(start_lr)
        self.start_factor = float(start_factor)
        self.end_factor = float(end_factor)
        self.stop_steps = int(stop_steps)

        # Sanitize the update interval: non-positive input falls back to 1.
        interval = int(decay_steps)
        if interval <= 0:
            interval = 1
        if interval >= self.stop_steps:
            # An interval this large would permit at most one update, so
            # substitute a default that yields multiple updates.
            # NOTE(review): the `// 10` vs `// 100` asymmetry below is kept
            # verbatim from the original — confirm it is intentional.
            fallback = (
                100 if self.stop_steps // 10 > 100 else self.stop_steps // 100 + 1
            )
            interval = max(1, int(fallback))
        self.decay_steps = interval

        # Number of update "buckets" over the horizon (float, may be fractional).
        self.total_updates = self.stop_steps / self.decay_steps

    def value(self, step: int) -> np.float64:
        """Return the learning rate at `step`.

        - The value changes only at multiples of `decay_steps`.
        - Saturates at `end_factor` once step >= stop_steps.
        - Non-positive steps yield the initial value.
        """
        if step <= 0:
            return np.float64(self.base_lr * self.start_factor)
        if step >= self.stop_steps:
            return np.float64(self.base_lr * self.end_factor)

        # Integer count of completed updates, normalized to [0, 1].
        completed = step // self.decay_steps
        progress = completed / self.total_updates
        progress = min(max(progress, 0.0), 1.0)  # guard numerical drift

        factor = self.start_factor + (self.end_factor - self.start_factor) * progress
        # Monotone clamp: rounding must never overshoot end_factor.
        if self.end_factor < self.start_factor:
            factor = max(factor, self.end_factor)
        else:
            factor = min(factor, self.end_factor)
        return np.float64(self.base_lr * factor)

deepmd/pt/train/training.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from deepmd.pt.utils.learning_rate import (
6262
LearningRateCosine,
6363
LearningRateExp,
64+
LearningRateLinear,
6465
LearningRateWSD,
6566
)
6667
from deepmd.pt.utils.stat import (
@@ -252,6 +253,8 @@ def get_lr(lr_params):
252253
lr_schedule = LearningRateCosine(**lr_params)
253254
elif lr_type == "wsd":
254255
lr_schedule = LearningRateWSD(**lr_params)
256+
elif lr_type == "linear":
257+
lr_schedule = LearningRateLinear(**lr_params)
255258
else:
256259
raise ValueError(f"Not supported learning rate type '{lr_type}'!")
257260
return lr_schedule

deepmd/pt/utils/learning_rate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from deepmd.dpmodel.utils.learning_rate import (
33
LearningRateCosine,
44
LearningRateExp,
5+
LearningRateLinear,
56
LearningRateWSD,
67
)
78

89
__all__ = [
910
"LearningRateCosine",
1011
"LearningRateExp",
12+
"LearningRateLinear",
1113
"LearningRateWSD",
1214
]

deepmd/utils/argcheck.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3066,6 +3066,17 @@ def learning_rate_wsd():
30663066
return args
30673067

30683068

3069+
def learning_rate_linear():
    # Argument schema for the "linear" learning-rate variant; every argument
    # carries a doc string, consistent with the other learning_rate_* builders.
    doc_start_lr = "The learning rate at the start of the training."
    doc_start_factor = (
        "The multiplicative factor applied to the learning rate at the start "
        "of the training."
    )
    doc_end_factor = (
        "The multiplicative factor applied to the learning rate at and after "
        "`stop_steps` training steps."
    )
    doc_decay_steps = "The interval (in training steps) between learning rate updates."
    args = [
        Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr),
        Argument(
            "start_factor", float, optional=True, default=1.0, doc=doc_start_factor
        ),
        Argument("end_factor", float, optional=True, default=1e-3, doc=doc_end_factor),
        Argument("decay_steps", int, optional=True, default=1000, doc=doc_decay_steps),
    ]
    return args
3078+
3079+
30693080
def learning_rate_variant_type_args():
30703081
doc_lr = "The type of the learning rate."
30713082

@@ -3075,6 +3086,7 @@ def learning_rate_variant_type_args():
30753086
Argument("exp", dict, learning_rate_exp()),
30763087
Argument("cosine", dict, learning_rate_cosine()),
30773088
Argument("wsd", dict, learning_rate_wsd()),
3089+
Argument("linear", dict, learning_rate_linear()),
30783090
],
30793091
optional=True,
30803092
default_tag="exp",

0 commit comments

Comments
 (0)