Skip to content

Commit f271bf8

Browse files
committed
feat: add WSD scheduler
1 parent b2805fb commit f271bf8

File tree

9 files changed

+650
-1
lines changed

9 files changed

+650
-1
lines changed

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,189 @@ def _decay_value(self, step: int | Array) -> Array:
393393
return step_lr
394394

395395

396+
@BaseLR.register("wsd")
397+
class LearningRateWSD(BaseLR):
398+
r"""
399+
Warmup-stable-decay learning rate schedule with configurable decay rules.
400+
401+
The schedule uses the shared warmup implementation from :class:`BaseLR`,
402+
then keeps the learning rate at ``start_lr`` during the stable phase, and
403+
finally applies one of the supported decay rules.
404+
405+
Let :math:`\tau \in [0, 1]` denote the normalized progress within the
406+
decay phase.
407+
408+
**Inverse-linear mode (``decay_type="inverse_linear"``):**
409+
410+
.. math::
411+
412+
lr(t) = \frac{1}{
413+
\tau / lr_{\text{stop}} + (1 - \tau) / lr_0
414+
}
415+
416+
**Cosine mode (``decay_type="cosine"``):**
417+
418+
.. math::
419+
420+
lr(t) = lr_{\text{stop}} +
421+
\frac{lr_0 - lr_{\text{stop}}}{2}
422+
\left(1 + \cos(\pi \tau)\right)
423+
424+
**Linear mode (``decay_type="linear"``):**
425+
426+
.. math::
427+
428+
lr(t) = lr_0 + \left(lr_{\text{stop}} - lr_0\right)\tau
429+
"""
430+
431+
def __init__(
432+
self,
433+
start_lr: float,
434+
num_steps: int,
435+
stop_lr: float | None = None,
436+
stop_lr_ratio: float | None = None,
437+
warmup_steps: int = 0,
438+
warmup_ratio: float | None = None,
439+
warmup_start_factor: float = 0.0,
440+
decay_phase_ratio: float = 0.1,
441+
decay_type: str = "inverse_linear",
442+
**kwargs: Any,
443+
) -> None:
444+
"""
445+
Construct a warmup-stable-decay learning rate schedule.
446+
447+
Parameters
448+
----------
449+
start_lr : float
450+
The learning rate at the start of the stable phase.
451+
num_steps : int
452+
The total training steps (including warmup).
453+
stop_lr : float, optional
454+
The final learning rate at the end of training.
455+
Mutually exclusive with stop_lr_ratio.
456+
stop_lr_ratio : float, optional
457+
The ratio of stop_lr to start_lr.
458+
Mutually exclusive with stop_lr.
459+
warmup_steps : int, optional
460+
The number of warmup steps.
461+
Mutually exclusive with warmup_ratio. Default is 0.
462+
warmup_ratio : float, optional
463+
The ratio of warmup steps to total training steps.
464+
Mutually exclusive with warmup_steps.
465+
warmup_start_factor : float, optional
466+
The factor of start_lr for the initial warmup learning rate.
467+
Default is 0.0.
468+
decay_phase_ratio : float, optional
469+
The ratio of the decay phase to total training steps.
470+
Default is 0.1.
471+
decay_type : str, optional
472+
The decay rule used in the decay phase.
473+
Supported values are ``inverse_linear``, ``cosine`` and ``linear``.
474+
Default is ``inverse_linear``.
475+
476+
Raises
477+
------
478+
ValueError
479+
If the learning rates are non-positive.
480+
If decay_phase_ratio is not in (0, 1].
481+
If decay_type is invalid.
482+
If the derived decay phase is empty or exceeds post-warmup steps.
483+
"""
484+
super().__init__(
485+
start_lr=start_lr,
486+
stop_lr=stop_lr,
487+
stop_lr_ratio=stop_lr_ratio,
488+
num_steps=num_steps,
489+
warmup_steps=warmup_steps,
490+
warmup_ratio=warmup_ratio,
491+
warmup_start_factor=warmup_start_factor,
492+
**kwargs,
493+
)
494+
495+
# === Validate WSD-specific invariants ===
496+
if self._start_lr <= 0:
497+
raise ValueError(f"start_lr ({self._start_lr}) must be positive.")
498+
if self.stop_lr <= 0:
499+
raise ValueError(f"stop_lr ({self.stop_lr}) must be positive.")
500+
if decay_phase_ratio <= 0 or decay_phase_ratio > 1:
501+
raise ValueError(
502+
f"decay_phase_ratio ({decay_phase_ratio}) must be in (0, 1]."
503+
)
504+
if decay_type not in ("inverse_linear", "cosine", "linear"):
505+
raise ValueError(
506+
"decay_type must be one of "
507+
f"{('inverse_linear', 'cosine', 'linear')}. "
508+
f"Got decay_type={decay_type}."
509+
)
510+
511+
# === Derive stable and decay phase lengths ===
512+
self.decay_phase_ratio = decay_phase_ratio
513+
self.decay_type = decay_type
514+
self.decay_phase_steps = int(self.decay_phase_ratio * self.num_steps)
515+
if self.decay_phase_steps <= 0:
516+
raise ValueError(
517+
"decay_phase_ratio results in zero decay steps. "
518+
"Increase num_steps or decay_phase_ratio."
519+
)
520+
if self.decay_phase_steps > self.decay_num_steps:
521+
raise ValueError(
522+
"decay phase steps must not exceed the post-warmup steps. "
523+
f"Got decay_phase_steps={self.decay_phase_steps}, "
524+
f"post_warmup_steps={self.decay_num_steps}."
525+
)
526+
self.stable_steps = self.decay_num_steps - self.decay_phase_steps
527+
528+
def _decay_value(self, step: int | Array) -> Array:
529+
"""
530+
Get the warmup-stable-decay learning rate at the given step.
531+
532+
Parameters
533+
----------
534+
step : int or Array
535+
The step index relative to the end of warmup.
536+
537+
Returns
538+
-------
539+
Array
540+
The learning rate (absolute value).
541+
"""
542+
if not array_api_compat.is_array_api_obj(step):
543+
step = np.asarray(step)
544+
xp = array_api_compat.array_namespace(step)
545+
step_dtype = (
546+
step.dtype
547+
if xp.isdtype(step.dtype, "real floating")
548+
else get_xp_precision(xp, "global")
549+
)
550+
551+
# === Step 1. Build typed scalar constants ===
552+
typed_step = xp.astype(step, step_dtype)
553+
zero = xp.asarray(0.0, dtype=step_dtype)
554+
one = xp.asarray(1.0, dtype=step_dtype)
555+
start_lr = xp.asarray(self._start_lr, dtype=step_dtype)
556+
stop_lr = xp.asarray(self.stop_lr, dtype=step_dtype)
557+
stable_steps = xp.asarray(self.stable_steps, dtype=step_dtype)
558+
decay_phase_steps = xp.asarray(self.decay_phase_steps, dtype=step_dtype)
559+
decay_num_steps = xp.asarray(self.decay_num_steps, dtype=step_dtype)
560+
561+
# === Step 2. Keep a constant learning rate in the stable phase ===
562+
decay_progress = (typed_step - stable_steps) / decay_phase_steps
563+
tau = xp.clip(decay_progress, zero, one)
564+
565+
# === Step 3. Apply the selected interpolation in the decay phase ===
566+
if self.decay_type == "inverse_linear":
567+
decay_lr = one / (tau / stop_lr + (one - tau) / start_lr)
568+
elif self.decay_type == "cosine":
569+
decay_lr = stop_lr + (start_lr - stop_lr) * 0.5 * (
570+
one + xp.cos(xp.asarray(xp.pi * tau, dtype=step_dtype))
571+
)
572+
else:
573+
decay_lr = start_lr + (stop_lr - start_lr) * tau
574+
step_lr = xp.where(step < self.stable_steps, start_lr, decay_lr)
575+
step_lr = xp.where(step >= self.decay_num_steps, stop_lr, step_lr)
576+
return step_lr
577+
578+
396579
@BaseLR.register("cosine")
397580
class LearningRateCosine(BaseLR):
398581
r"""

deepmd/pd/utils/learning_rate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from deepmd.dpmodel.utils.learning_rate import (
33
LearningRateExp,
4+
LearningRateWSD,
45
)
56

67
__all__ = [
78
"LearningRateExp",
9+
"LearningRateWSD",
810
]

deepmd/pt/utils/learning_rate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
BaseLR,
44
LearningRateCosine,
55
LearningRateExp,
6+
LearningRateWSD,
67
)
78

89
__all__ = [
910
"BaseLR",
1011
"LearningRateCosine",
1112
"LearningRateExp",
13+
"LearningRateWSD",
1214
]

deepmd/utils/argcheck.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,6 +2594,61 @@ def _check_decay_steps_args(data: dict[str, Any]) -> bool:
25942594
return True
25952595

25962596

2597+
def _check_wsd_args(data: dict[str, Any]) -> bool:
2598+
"""
2599+
Check WSD-specific learning rate arguments.
2600+
2601+
Parameters
2602+
----------
2603+
data : dict[str, Any]
2604+
The learning rate configuration dictionary.
2605+
2606+
Returns
2607+
-------
2608+
bool
2609+
True if validation passes.
2610+
2611+
Raises
2612+
------
2613+
ValueError
2614+
If the WSD-specific arguments are invalid.
2615+
"""
2616+
lr_type = data.get("type", "exp")
2617+
if lr_type != "wsd":
2618+
return True
2619+
2620+
start_lr = data.get("start_lr")
2621+
if start_lr is not None and start_lr <= 0:
2622+
raise ValueError(f"start_lr ({start_lr}) must be positive for WSD.")
2623+
2624+
stop_lr = data.get("stop_lr")
2625+
if stop_lr is not None and stop_lr <= 0:
2626+
raise ValueError(f"stop_lr ({stop_lr}) must be positive for WSD.")
2627+
2628+
stop_lr_ratio = data.get("stop_lr_ratio")
2629+
if stop_lr_ratio is not None and stop_lr_ratio <= 0:
2630+
raise ValueError(f"stop_lr_ratio ({stop_lr_ratio}) must be positive for WSD.")
2631+
2632+
decay_phase_ratio = data.get("decay_phase_ratio")
2633+
if decay_phase_ratio is not None and (
2634+
decay_phase_ratio <= 0 or decay_phase_ratio > 1
2635+
):
2636+
raise ValueError(f"decay_phase_ratio ({decay_phase_ratio}) must be in (0, 1].")
2637+
2638+
decay_type = data.get("decay_type")
2639+
if decay_type is not None and decay_type not in (
2640+
"inverse_linear",
2641+
"cosine",
2642+
"linear",
2643+
):
2644+
raise ValueError(
2645+
"decay_type must be one of "
2646+
f"{('inverse_linear', 'cosine', 'linear')}. "
2647+
f"Got decay_type={decay_type}."
2648+
)
2649+
return True
2650+
2651+
25972652
@lr_args_plugin.register("exp")
25982653
def learning_rate_exp() -> list[Argument]:
25992654
"""
@@ -2645,6 +2700,42 @@ def learning_rate_cosine() -> list[Argument]:
26452700
return []
26462701

26472702

2703+
@lr_args_plugin.register("wsd")
def learning_rate_wsd() -> list[Argument]:
    """
    Defines a warmup-stable-decay learning rate schedule with configurable
    decay rules.

    The learning rate stays at `start_lr` during the stable phase and then
    decays to `stop_lr` with the selected decay rule.
    """
    # NOTE: the doc= strings below are emitted verbatim into the generated
    # user documentation; keep them in sync with the scheduler docstring.
    return [
        Argument(
            "decay_phase_ratio",
            float,
            optional=True,
            default=0.1,
            doc=(
                "The ratio of the decay phase to total training steps. "
                "The remaining post-warmup steps are used as the stable phase. "
                "Default is 0.1."
            ),
        ),
        Argument(
            "decay_type",
            str,
            optional=True,
            default="inverse_linear",
            doc=(
                "The decay rule used in the decay phase. "
                "Supported values are `inverse_linear` (default), `cosine`, and `linear`."
            ),
        ),
    ]
2737+
2738+
26482739
def learning_rate_variant_type_args() -> Variant:
26492740
doc_lr = "The type of the learning rate."
26502741

@@ -2694,6 +2785,8 @@ def _check_lr_args(data: dict[str, Any]) -> bool:
26942785
_check_warmup_args(data)
26952786
# Check decay_steps and decay_rate
26962787
_check_decay_steps_args(data)
2788+
# Check WSD-specific arguments
2789+
_check_wsd_args(data)
26972790
return True
26982791

26992792
# Common arguments for all learning rate types (outside Variant)

source/tests/consistent/test_learning_rate.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,32 @@
5151
"num_steps": 1000000,
5252
"warmup_steps": 10000,
5353
},
54+
{
55+
"type": "wsd",
56+
"start_lr": 1e-3,
57+
"stop_lr": 1e-8,
58+
"num_steps": 1000000,
59+
"warmup_steps": 10000,
60+
"decay_phase_ratio": 0.1,
61+
},
62+
{
63+
"type": "wsd",
64+
"start_lr": 1e-3,
65+
"stop_lr": 1e-8,
66+
"num_steps": 1000000,
67+
"warmup_steps": 10000,
68+
"decay_phase_ratio": 0.1,
69+
"decay_type": "cosine",
70+
},
71+
{
72+
"type": "wsd",
73+
"start_lr": 1e-3,
74+
"stop_lr": 1e-8,
75+
"num_steps": 1000000,
76+
"warmup_steps": 10000,
77+
"decay_phase_ratio": 0.1,
78+
"decay_type": "linear",
79+
},
5480
),
5581
)
5682
class TestLearningRateConsistent(unittest.TestCase):
@@ -59,7 +85,14 @@ class TestLearningRateConsistent(unittest.TestCase):
5985
def setUp(self) -> None:
6086
(lr_param,) = self.param
6187
self.lr = BaseLR(**lr_param)
62-
self.step = 500000
88+
if hasattr(self.lr, "stable_steps") and hasattr(self.lr, "decay_phase_steps"):
89+
self.step = (
90+
self.lr.warmup_steps
91+
+ self.lr.stable_steps
92+
+ self.lr.decay_phase_steps // 2
93+
)
94+
else:
95+
self.step = 500000
6396
self.ref = self.lr.value(self.step)
6497
self.warmup_step = None
6598
self.warmup_ref = None

0 commit comments

Comments
 (0)