Skip to content

Commit 9b1df92

Browse files
authored
feat(pt/dp): add cosine LR & BaseLR (#5142)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added a cosine-annealing learning-rate schedule alongside the existing exponential option. * **Configuration** * Training can now select between exponential and cosine schedules; selection and error handling improved. * Both variants are exposed via the argument-registration system for configuration. * **Tests** * Added unit tests validating the cosine curve (start, end, midpoint, and steady final value). * **Refactor** * Introduced a common learning-rate schedule base and refactored the exponential schedule to use it. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 8f021aa commit 9b1df92

File tree

5 files changed

+133
-11
lines changed

5 files changed

+133
-11
lines changed

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 85 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,56 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from abc import (
3+
ABC,
4+
abstractmethod,
5+
)
26
from typing import (
37
Any,
48
)
59

610
import numpy as np
711

12+
from deepmd.common import (
13+
j_get_type,
14+
)
15+
from deepmd.utils.plugin import (
16+
PluginVariant,
17+
make_plugin_registry,
18+
)
19+
20+
21+
class BaseLR(ABC, PluginVariant, make_plugin_registry("lr")):
22+
def __new__(cls: type, *args: Any, **kwargs: Any) -> Any:
23+
if cls is BaseLR:
24+
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
25+
return super().__new__(cls)
26+
27+
def __init__(
28+
self, start_lr: float, stop_lr: float, stop_steps: int, **kwargs: Any
29+
) -> None:
30+
"""
31+
Base class for learning rate schedules.
32+
33+
Parameters
34+
----------
35+
start_lr
36+
The initial learning rate.
37+
stop_lr
38+
The final learning rate.
39+
stop_steps
40+
The total training steps for learning rate scheduler.
41+
"""
42+
self.start_lr = start_lr
43+
self.stop_lr = stop_lr
44+
self.stop_steps = stop_steps
45+
46+
@abstractmethod
47+
def value(self, step: int) -> np.float64:
48+
"""Get the learning rate at the given step."""
49+
pass
50+
851

9-
class LearningRateExp:
52+
@BaseLR.register("exp")
53+
class LearningRateExp(BaseLR):
1054
def __init__(
1155
self,
1256
start_lr: float,
@@ -37,7 +81,7 @@ def __init__(
3781
If provided, the decay rate will be set instead of
3882
calculating it through interpolation between start_lr and stop_lr.
3983
"""
40-
self.start_lr = start_lr
84+
super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
4185
default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1
4286
self.decay_steps = decay_steps
4387
if self.decay_steps >= stop_steps:
@@ -47,11 +91,49 @@ def __init__(
4791
)
4892
if decay_rate is not None:
4993
self.decay_rate = decay_rate
50-
self.min_lr = stop_lr
94+
self.min_lr = self.stop_lr
5195

5296
def value(self, step: int) -> np.float64:
5397
"""Get the learning rate at the given step."""
5498
step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps)
5599
if step_lr < self.min_lr:
56100
step_lr = self.min_lr
57101
return step_lr
102+
103+
104+
@BaseLR.register("cosine")
105+
class LearningRateCosine(BaseLR):
106+
def __init__(
107+
self,
108+
start_lr: float,
109+
stop_lr: float,
110+
stop_steps: int,
111+
**kwargs: Any,
112+
) -> None:
113+
"""
114+
Defines a cosine annealing learning rate schedule.
115+
The learning rate starts at `start_lr` and gradually decreases to `stop_lr`
116+
following a cosine curve over the training steps.
117+
118+
Parameters
119+
----------
120+
start_lr
121+
The initial learning rate at the beginning of training.
122+
stop_lr
123+
The final learning rate at the end of training.
124+
stop_steps
125+
The total number of training steps over which the learning rate
126+
will be annealed from start_lr to stop_lr.
127+
"""
128+
super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
129+
self.lr_min_factor = stop_lr / start_lr
130+
131+
def value(self, step: int) -> np.float64:
132+
if step >= self.stop_steps:
133+
return self.start_lr * self.lr_min_factor
134+
return self.start_lr * (
135+
self.lr_min_factor
136+
+ 0.5
137+
* (1 - self.lr_min_factor)
138+
* (1 + np.cos(np.pi * (step / self.stop_steps)))
139+
)

deepmd/pt/train/training.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@
6363
SAMPLER_RECORD,
6464
)
6565
from deepmd.pt.utils.learning_rate import (
66-
LearningRateExp,
66+
BaseLR,
6767
)
6868
from deepmd.pt.utils.stat import (
6969
make_stat_input,
@@ -266,13 +266,10 @@ def get_sample() -> Any:
266266
_stat_file_path.root.close()
267267
return get_sample
268268

269-
def get_lr(lr_params: dict[str, Any]) -> LearningRateExp:
270-
assert lr_params.get("type", "exp") == "exp", (
271-
"Only learning rate `exp` is supported!"
272-
)
269+
def get_lr(lr_params: dict[str, Any]) -> BaseLR:
273270
lr_params["stop_steps"] = self.num_steps - self.warmup_steps
274-
lr_exp = LearningRateExp(**lr_params)
275-
return lr_exp
271+
lr_schedule = BaseLR(**lr_params)
272+
return lr_schedule
276273

277274
# Optimizer
278275
if self.multi_task and training_params.get("optim_dict", None) is not None:

deepmd/pt/utils/learning_rate.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,12 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from deepmd.dpmodel.utils.learning_rate import (
3+
BaseLR,
4+
LearningRateCosine,
35
LearningRateExp,
46
)
57

68
__all__ = [
9+
"BaseLR",
10+
"LearningRateCosine",
711
"LearningRateExp",
812
]

deepmd/utils/argcheck.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2477,6 +2477,10 @@ def linear_ener_model_args() -> Argument:
24772477

24782478

24792479
# --- Learning rate configurations: --- #
2480+
lr_args_plugin = ArgsPlugin()
2481+
2482+
2483+
@lr_args_plugin.register("exp")
24802484
def learning_rate_exp() -> list[Argument]:
24812485
doc_start_lr = "The learning rate at the start of the training."
24822486
doc_stop_lr = (
@@ -2509,12 +2513,30 @@ def learning_rate_exp() -> list[Argument]:
25092513
return args
25102514

25112515

2516+
@lr_args_plugin.register("cosine", doc=doc_only_pt_supported)
2517+
def learning_rate_cosine() -> list[Argument]:
2518+
"""
2519+
Defines a cosine annealing learning rate schedule.
2520+
2521+
The learning rate starts at `start_lr` and gradually decreases to `stop_lr`
2522+
following a cosine curve over the training steps.
2523+
"""
2524+
doc_start_lr = "The learning rate at the start of the training."
2525+
doc_stop_lr = "The desired learning rate at the end of the training. "
2526+
2527+
args = [
2528+
Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr),
2529+
Argument("stop_lr", float, optional=True, default=1e-5, doc=doc_stop_lr),
2530+
]
2531+
return args
2532+
2533+
25122534
def learning_rate_variant_type_args() -> Variant:
25132535
doc_lr = "The type of the learning rate."
25142536

25152537
return Variant(
25162538
"type",
2517-
[Argument("exp", dict, learning_rate_exp())],
2539+
lr_args_plugin.get_all_argument(),
25182540
optional=True,
25192541
default_tag="exp",
25202542
doc=doc_lr,

source/tests/pt/test_lr.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
tf.disable_eager_execution()
88

99
from deepmd.pt.utils.learning_rate import (
10+
LearningRateCosine,
1011
LearningRateExp,
1112
)
1213
from deepmd.tf.utils import (
@@ -102,5 +103,21 @@ def decay_rate_pt(self) -> None:
102103
)
103104

104105

106+
class TestLearningRateCosine(unittest.TestCase):
107+
def test_basic_curve(self) -> None:
108+
start_lr = 1.0
109+
stop_lr = 0.1
110+
stop_steps = 10
111+
lr = LearningRateCosine(start_lr, stop_lr, stop_steps)
112+
113+
self.assertTrue(np.allclose(lr.value(0), start_lr))
114+
self.assertTrue(np.allclose(lr.value(stop_steps), stop_lr))
115+
self.assertTrue(np.allclose(lr.value(stop_steps + 5), stop_lr))
116+
117+
mid_step = stop_steps // 2
118+
expected_mid = stop_lr + (start_lr - stop_lr) * 0.5
119+
self.assertTrue(np.allclose(lr.value(mid_step), expected_mid))
120+
121+
105122
if __name__ == "__main__":
106123
unittest.main()

0 commit comments

Comments (0)