@@ -393,6 +393,189 @@ def _decay_value(self, step: int | Array) -> Array:
393393 return step_lr
394394
395395
@BaseLR.register("wsd")
class LearningRateWSD(BaseLR):
    r"""
    Warmup-stable-decay learning rate schedule with configurable decay rules.

    The schedule uses the shared warmup implementation from :class:`BaseLR`,
    then keeps the learning rate at ``start_lr`` during the stable phase, and
    finally applies one of the supported decay rules.

    Let :math:`\tau \in [0, 1]` denote the normalized progress within the
    decay phase.

    **Inverse-linear mode (``decay_type="inverse_linear"``):**

    .. math::

        lr(t) = \frac{1}{
            \tau / lr_{\text{stop}} + (1 - \tau) / lr_0
        }

    **Cosine mode (``decay_type="cosine"``):**

    .. math::

        lr(t) = lr_{\text{stop}} +
            \frac{lr_0 - lr_{\text{stop}}}{2}
            \left(1 + \cos(\pi \tau)\right)

    **Linear mode (``decay_type="linear"``):**

    .. math::

        lr(t) = lr_0 + \left(lr_{\text{stop}} - lr_0\right)\tau
    """

    # Single source of truth for the decay rules accepted by ``decay_type``;
    # also interpolated into the validation error message below.
    _SUPPORTED_DECAY_TYPES = ("inverse_linear", "cosine", "linear")

    def __init__(
        self,
        start_lr: float,
        num_steps: int,
        stop_lr: float | None = None,
        stop_lr_ratio: float | None = None,
        warmup_steps: int = 0,
        warmup_ratio: float | None = None,
        warmup_start_factor: float = 0.0,
        decay_phase_ratio: float = 0.1,
        decay_type: str = "inverse_linear",
        **kwargs: Any,
    ) -> None:
        """
        Construct a warmup-stable-decay learning rate schedule.

        Parameters
        ----------
        start_lr : float
            The learning rate at the start of the stable phase.
        num_steps : int
            The total training steps (including warmup).
        stop_lr : float, optional
            The final learning rate at the end of training.
            Mutually exclusive with stop_lr_ratio.
        stop_lr_ratio : float, optional
            The ratio of stop_lr to start_lr.
            Mutually exclusive with stop_lr.
        warmup_steps : int, optional
            The number of warmup steps.
            Mutually exclusive with warmup_ratio. Default is 0.
        warmup_ratio : float, optional
            The ratio of warmup steps to total training steps.
            Mutually exclusive with warmup_steps.
        warmup_start_factor : float, optional
            The factor of start_lr for the initial warmup learning rate.
            Default is 0.0.
        decay_phase_ratio : float, optional
            The ratio of the decay phase to total training steps.
            Default is 0.1.
        decay_type : str, optional
            The decay rule used in the decay phase.
            Supported values are ``inverse_linear``, ``cosine`` and ``linear``.
            Default is ``inverse_linear``.

        Raises
        ------
        ValueError
            If the learning rates are non-positive.
            If decay_phase_ratio is not in (0, 1].
            If decay_type is invalid.
            If the derived decay phase is empty or exceeds post-warmup steps.
        """
        super().__init__(
            start_lr=start_lr,
            stop_lr=stop_lr,
            stop_lr_ratio=stop_lr_ratio,
            num_steps=num_steps,
            warmup_steps=warmup_steps,
            warmup_ratio=warmup_ratio,
            warmup_start_factor=warmup_start_factor,
            **kwargs,
        )

        # === Validate WSD-specific invariants ===
        # Positive rates are required: the inverse-linear rule divides by both
        # start_lr and stop_lr.
        if self._start_lr <= 0:
            raise ValueError(f"start_lr ({self._start_lr}) must be positive.")
        if self.stop_lr <= 0:
            raise ValueError(f"stop_lr ({self.stop_lr}) must be positive.")
        if decay_phase_ratio <= 0 or decay_phase_ratio > 1:
            raise ValueError(
                f"decay_phase_ratio ({decay_phase_ratio}) must be in (0, 1]."
            )
        if decay_type not in self._SUPPORTED_DECAY_TYPES:
            raise ValueError(
                "decay_type must be one of "
                f"{self._SUPPORTED_DECAY_TYPES}. "
                f"Got decay_type={decay_type}."
            )

        # === Derive stable and decay phase lengths ===
        # The decay phase length is measured against the *total* step count
        # (including warmup), as documented above.
        self.decay_phase_ratio = decay_phase_ratio
        self.decay_type = decay_type
        self.decay_phase_steps = int(self.decay_phase_ratio * self.num_steps)
        if self.decay_phase_steps <= 0:
            raise ValueError(
                "decay_phase_ratio results in zero decay steps. "
                "Increase num_steps or decay_phase_ratio."
            )
        # self.decay_num_steps (from BaseLR) is the post-warmup step budget;
        # the decay phase must fit inside it.
        if self.decay_phase_steps > self.decay_num_steps:
            raise ValueError(
                "decay phase steps must not exceed the post-warmup steps. "
                f"Got decay_phase_steps={self.decay_phase_steps}, "
                f"post_warmup_steps={self.decay_num_steps}."
            )
        self.stable_steps = self.decay_num_steps - self.decay_phase_steps

    def _decay_value(self, step: int | Array) -> Array:
        """
        Get the warmup-stable-decay learning rate at the given step.

        Parameters
        ----------
        step : int or Array
            The step index relative to the end of warmup.

        Returns
        -------
        Array
            The learning rate (absolute value).
        """
        if not array_api_compat.is_array_api_obj(step):
            step = np.asarray(step)
        xp = array_api_compat.array_namespace(step)
        # Keep a floating step dtype; integer steps fall back to the global
        # precision configured for this namespace.
        step_dtype = (
            step.dtype
            if xp.isdtype(step.dtype, "real floating")
            else get_xp_precision(xp, "global")
        )

        # === Step 1. Build typed scalar constants ===
        typed_step = xp.astype(step, step_dtype)
        zero = xp.asarray(0.0, dtype=step_dtype)
        one = xp.asarray(1.0, dtype=step_dtype)
        start_lr = xp.asarray(self._start_lr, dtype=step_dtype)
        stop_lr = xp.asarray(self.stop_lr, dtype=step_dtype)
        stable_steps = xp.asarray(self.stable_steps, dtype=step_dtype)
        decay_phase_steps = xp.asarray(self.decay_phase_steps, dtype=step_dtype)

        # === Step 2. Keep a constant learning rate in the stable phase ===
        # tau clips to 0 during the stable phase and to 1 past the end of
        # training, so each decay rule below already plateaus at both ends.
        decay_progress = (typed_step - stable_steps) / decay_phase_steps
        tau = xp.clip(decay_progress, zero, one)

        # === Step 3. Apply the selected interpolation in the decay phase ===
        if self.decay_type == "inverse_linear":
            # Harmonic interpolation: lr follows 1/linear(1/lr) between
            # start_lr (tau=0) and stop_lr (tau=1).
            decay_lr = one / (tau / stop_lr + (one - tau) / start_lr)
        elif self.decay_type == "cosine":
            # xp.pi * tau already promotes to step_dtype; no extra cast needed.
            decay_lr = stop_lr + (start_lr - stop_lr) * 0.5 * (
                one + xp.cos(xp.pi * tau)
            )
        else:  # "linear" (validated in __init__)
            decay_lr = start_lr + (stop_lr - start_lr) * tau
        # Pin exact plateau values at the phase boundaries.
        step_lr = xp.where(step < self.stable_steps, start_lr, decay_lr)
        step_lr = xp.where(step >= self.decay_num_steps, stop_lr, step_lr)
        return step_lr
577+
578+
396579@BaseLR .register ("cosine" )
397580class LearningRateCosine (BaseLR ):
398581 r"""
0 commit comments