Skip to content

Commit 0828604

Browse files
authored
feat: add WSD LR Scheduler (#5326)
The doc will be added through PR #5276 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added a "wsd" (warmup → stable → decay) learning-rate schedule with configurable warmup, stable duration, decay-phase ratio, and three decay modes: inverse_linear, cosine, linear. Validation prevents invalid parameter combinations and enforces sensible ranges. * Exposed the new schedule in the public API. * **Tests** * Added comprehensive tests for decay modes, warmup/stable behavior, edge cases, array/JIT inputs, and beyond-endstep behavior. * **Documentation** * Documented the new "wsd" schedule, parameters, examples, and mathematical definitions. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 8060865 commit 0828604

File tree

10 files changed

+799
-6
lines changed

10 files changed

+799
-6
lines changed

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,180 @@ def _decay_value(self, step: int | Array) -> Array:
393393
return step_lr
394394

395395

396+
@BaseLR.register("wsd")
class LearningRateWSD(BaseLR):
    r"""
    Warmup-stable-decay learning rate schedule with configurable decay rules.

    The schedule uses the shared warmup implementation from :class:`BaseLR`,
    then keeps the learning rate at ``start_lr`` during the stable phase, and
    finally applies one of the supported decay rules.

    Let :math:`\tau \in [0, 1]` denote the normalized progress within the
    decay phase.

    **Inverse-linear mode (``decay_type="inverse_linear"``):**

    .. math::

        lr(t) = \frac{1}{
            \tau / lr_{\text{stop}} + (1 - \tau) / lr_0
        }

    **Cosine mode (``decay_type="cosine"``):**

    .. math::

        lr(t) = lr_{\text{stop}} +
            \frac{lr_0 - lr_{\text{stop}}}{2}
            \left(1 + \cos(\pi \tau)\right)

    **Linear mode (``decay_type="linear"``):**

    .. math::

        lr(t) = lr_0 + \left(lr_{\text{stop}} - lr_0\right)\tau
    """

    def __init__(
        self,
        start_lr: float,
        num_steps: int,
        stop_lr: float | None = None,
        stop_lr_ratio: float | None = None,
        warmup_steps: int = 0,
        warmup_ratio: float | None = None,
        warmup_start_factor: float = 0.0,
        decay_phase_ratio: float = 0.1,
        decay_type: str = "inverse_linear",
        **kwargs: Any,
    ) -> None:
        """
        Construct a warmup-stable-decay learning rate schedule.

        Parameters
        ----------
        start_lr : float
            The learning rate at the start of the stable phase.
        num_steps : int
            The total training steps (including warmup).
        stop_lr : float, optional
            The final learning rate at the end of training.
            Mutually exclusive with stop_lr_ratio.
        stop_lr_ratio : float, optional
            The ratio of stop_lr to start_lr.
            Mutually exclusive with stop_lr.
        warmup_steps : int, optional
            The number of warmup steps.
            Mutually exclusive with warmup_ratio. Default is 0.
        warmup_ratio : float, optional
            The ratio of warmup steps to total training steps.
            Mutually exclusive with warmup_steps.
        warmup_start_factor : float, optional
            The factor of start_lr for the initial warmup learning rate.
            Default is 0.0.
        decay_phase_ratio : float, optional
            The ratio of the decay phase to total training steps.
            Default is 0.1.
        decay_type : str, optional
            The decay rule used in the decay phase.
            Supported values are ``inverse_linear``, ``cosine`` and ``linear``.
            Default is ``inverse_linear``.

        Raises
        ------
        ValueError
            If the learning rates are non-positive.
            If decay_phase_ratio is not in (0, 1].
            If decay_type is invalid.

        Notes
        -----
        The derived decay phase length is clamped to
        ``[1, decay_num_steps]`` rather than raising, so a too-small or
        too-large ``decay_phase_ratio * num_steps`` never produces an
        empty or oversized decay phase.
        """
        # Warmup/stop-lr normalization (including the stop_lr vs
        # stop_lr_ratio exclusivity) is handled by the BaseLR constructor.
        super().__init__(
            start_lr=start_lr,
            stop_lr=stop_lr,
            stop_lr_ratio=stop_lr_ratio,
            num_steps=num_steps,
            warmup_steps=warmup_steps,
            warmup_ratio=warmup_ratio,
            warmup_start_factor=warmup_start_factor,
            **kwargs,
        )

        # === Validate WSD-specific invariants ===
        # Positivity is required because inverse_linear divides by both
        # start_lr and stop_lr.
        if self._start_lr <= 0:
            raise ValueError(f"start_lr ({self._start_lr}) must be positive.")
        if self.stop_lr <= 0:
            raise ValueError(f"stop_lr ({self.stop_lr}) must be positive.")
        if decay_phase_ratio <= 0 or decay_phase_ratio > 1:
            raise ValueError(
                f"decay_phase_ratio ({decay_phase_ratio}) must be in (0, 1]."
            )
        if decay_type not in ("inverse_linear", "cosine", "linear"):
            raise ValueError(
                "decay_type must be one of "
                f"{('inverse_linear', 'cosine', 'linear')}. "
                f"Got decay_type={decay_type}."
            )

        # === Derive stable and decay phase lengths ===
        # decay_phase_ratio is expressed w.r.t. num_steps (total steps),
        # while the clamp bound decay_num_steps is a BaseLR attribute —
        # presumably the post-warmup step count; confirm against BaseLR.
        self.decay_phase_ratio = decay_phase_ratio
        self.decay_type = decay_type
        # Clamp decay_phase_steps to valid range [1, decay_num_steps]
        self.decay_phase_steps = max(
            1, min(int(self.decay_phase_ratio * self.num_steps), self.decay_num_steps)
        )
        # Stable phase occupies whatever post-warmup budget the decay
        # phase does not consume (may be 0).
        self.stable_steps = self.decay_num_steps - self.decay_phase_steps

    def _decay_value(self, step: int | Array) -> Array:
        """
        Get the warmup-stable-decay learning rate at the given step.

        Parameters
        ----------
        step : int or Array
            The step index relative to the end of warmup.

        Returns
        -------
        Array
            The learning rate (absolute value).
        """
        # Accept plain Python ints/lists by promoting them to numpy first,
        # then dispatch all math through the array-API namespace of `step`
        # so the same code runs on numpy / torch / jax inputs.
        if not array_api_compat.is_array_api_obj(step):
            step = np.asarray(step)
        xp = array_api_compat.array_namespace(step)
        # Keep a caller-supplied float dtype; otherwise fall back to the
        # configured global precision (integer steps must be promoted).
        step_dtype = (
            step.dtype
            if xp.isdtype(step.dtype, "real floating")
            else get_xp_precision(xp, "global")
        )

        # === Step 1. Build typed scalar constants ===
        # Materializing every constant at step_dtype avoids implicit
        # cross-dtype promotion in the arithmetic below.
        typed_step = xp.astype(step, step_dtype)
        zero = xp.asarray(0.0, dtype=step_dtype)
        one = xp.asarray(1.0, dtype=step_dtype)
        start_lr = xp.asarray(self._start_lr, dtype=step_dtype)
        stop_lr = xp.asarray(self.stop_lr, dtype=step_dtype)
        stable_steps = xp.asarray(self.stable_steps, dtype=step_dtype)
        decay_phase_steps = xp.asarray(self.decay_phase_steps, dtype=step_dtype)

        # === Step 2. Keep a constant learning rate in the stable phase ===
        # tau is the normalized decay progress; clipping to [0, 1] makes
        # the decay expressions well-defined even for steps outside the
        # decay window (the xp.where masks below pick the right branch).
        decay_progress = (typed_step - stable_steps) / decay_phase_steps
        tau = xp.clip(decay_progress, zero, one)

        # === Step 3. Apply the selected interpolation in the decay phase ===
        if self.decay_type == "inverse_linear":
            # Interpolate the reciprocal of the learning rate linearly
            # (harmonic interpolation between start_lr and stop_lr).
            decay_lr = one / (tau / stop_lr + (one - tau) / start_lr)
        elif self.decay_type == "cosine":
            # Cast xp.pi * tau back to step_dtype so cos stays in the
            # requested precision.
            decay_lr = stop_lr + (start_lr - stop_lr) * 0.5 * (
                one + xp.cos(xp.asarray(xp.pi * tau, dtype=step_dtype))
            )
        else:
            # "linear": straight interpolation from start_lr to stop_lr.
            decay_lr = start_lr + (stop_lr - start_lr) * tau
        # Stable phase: hold start_lr; beyond the schedule: hold stop_lr.
        step_lr = xp.where(step < self.stable_steps, start_lr, decay_lr)
        step_lr = xp.where(step >= self.decay_num_steps, stop_lr, step_lr)
        return step_lr
568+
569+
396570
@BaseLR.register("cosine")
397571
class LearningRateCosine(BaseLR):
398572
r"""

deepmd/pd/utils/learning_rate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from deepmd.dpmodel.utils.learning_rate import (
33
LearningRateExp,
4+
LearningRateWSD,
45
)
56

67
__all__ = [
78
"LearningRateExp",
9+
"LearningRateWSD",
810
]

deepmd/pt/utils/learning_rate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
BaseLR,
44
LearningRateCosine,
55
LearningRateExp,
6+
LearningRateWSD,
67
)
78

89
__all__ = [
910
"BaseLR",
1011
"LearningRateCosine",
1112
"LearningRateExp",
13+
"LearningRateWSD",
1214
]

deepmd/utils/argcheck.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,6 +2594,61 @@ def _check_decay_steps_args(data: dict[str, Any]) -> bool:
25942594
return True
25952595

25962596

2597+
def _check_wsd_args(data: dict[str, Any]) -> bool:
2598+
"""
2599+
Check WSD-specific learning rate arguments.
2600+
2601+
Parameters
2602+
----------
2603+
data : dict[str, Any]
2604+
The learning rate configuration dictionary.
2605+
2606+
Returns
2607+
-------
2608+
bool
2609+
True if validation passes.
2610+
2611+
Raises
2612+
------
2613+
ValueError
2614+
If the WSD-specific arguments are invalid.
2615+
"""
2616+
lr_type = data.get("type", "exp")
2617+
if lr_type != "wsd":
2618+
return True
2619+
2620+
start_lr = data.get("start_lr")
2621+
if start_lr is not None and start_lr <= 0:
2622+
raise ValueError(f"start_lr ({start_lr}) must be positive for WSD.")
2623+
2624+
stop_lr = data.get("stop_lr")
2625+
if stop_lr is not None and stop_lr <= 0:
2626+
raise ValueError(f"stop_lr ({stop_lr}) must be positive for WSD.")
2627+
2628+
stop_lr_ratio = data.get("stop_lr_ratio")
2629+
if stop_lr_ratio is not None and stop_lr_ratio <= 0:
2630+
raise ValueError(f"stop_lr_ratio ({stop_lr_ratio}) must be positive for WSD.")
2631+
2632+
decay_phase_ratio = data.get("decay_phase_ratio")
2633+
if decay_phase_ratio is not None and (
2634+
decay_phase_ratio <= 0 or decay_phase_ratio > 1
2635+
):
2636+
raise ValueError(f"decay_phase_ratio ({decay_phase_ratio}) must be in (0, 1].")
2637+
2638+
decay_type = data.get("decay_type")
2639+
if decay_type is not None and decay_type not in (
2640+
"inverse_linear",
2641+
"cosine",
2642+
"linear",
2643+
):
2644+
raise ValueError(
2645+
"decay_type must be one of "
2646+
f"{('inverse_linear', 'cosine', 'linear')}. "
2647+
f"Got decay_type={decay_type}."
2648+
)
2649+
return True
2650+
2651+
25972652
@lr_args_plugin.register("exp")
25982653
def learning_rate_exp() -> list[Argument]:
25992654
"""
@@ -2645,6 +2700,42 @@ def learning_rate_cosine() -> list[Argument]:
26452700
return []
26462701

26472702

2703+
@lr_args_plugin.register("wsd")
def learning_rate_wsd() -> list[Argument]:
    """
    Defines a warmup-stable-decay learning rate schedule with configurable
    decay rules.

    The learning rate stays at `start_lr` during the stable phase and then
    decays to `stop_lr` with the selected decay rule.
    """
    # WSD-specific knobs; the common arguments (start_lr, stop_lr, warmup
    # settings, ...) are declared outside this variant.
    return [
        Argument(
            "decay_phase_ratio",
            float,
            optional=True,
            default=0.1,
            doc=(
                "The ratio of the decay phase to total training steps. "
                "The remaining post-warmup steps are used as the stable phase. "
                "Default is 0.1."
            ),
        ),
        Argument(
            "decay_type",
            str,
            optional=True,
            default="inverse_linear",
            doc=(
                "The decay rule used in the decay phase. "
                "Supported values are `inverse_linear` (default), `cosine`, and `linear`."
            ),
        ),
    ]
2737+
2738+
26482739
def learning_rate_variant_type_args() -> Variant:
26492740
doc_lr = "The type of the learning rate."
26502741

@@ -2694,6 +2785,8 @@ def _check_lr_args(data: dict[str, Any]) -> bool:
26942785
_check_warmup_args(data)
26952786
# Check decay_steps and decay_rate
26962787
_check_decay_steps_args(data)
2788+
# Check WSD-specific arguments
2789+
_check_wsd_args(data)
26972790
return True
26982791

26992792
# Common arguments for all learning rate types (outside Variant)

0 commit comments

Comments
 (0)