Commit 314b946

refactor: unify learning rate schedulers with array API (#5154)
- Refactor BaseLR in dpmodel to use array_api_compat for backend-agnostic implementation
- Consolidate learning rate logic from TF/PT/PD backends into unified dpmodel layer
- Use array API operations (xp.where, xp.clip, etc.) for JIT compatibility across backends
- Add warmup support (warmup_steps, warmup_ratio, warmup_start_factor) during refactoring
- Add stop_ratio parameter as alternative to stop_lr for flexible configuration
- Implement mutual exclusion validation for stop_lr/stop_ratio and warmup_steps/warmup_ratio
- Update all backends to use unified BaseLR implementation
- Add comprehensive consistency tests across NumPy/PyTorch/JAX/array_api_strict backends

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Added comprehensive warmup support for learning rate schedules with configurable warmup steps, ratios, and start factors.
  * Enhanced learning rate scheduling with unified configuration across TensorFlow, PyTorch, and Paddle backends.
  * Introduced flexible stop learning rate configuration using either absolute values or ratios.
* **Improvements**
  * Moved warmup configuration from training to learning rate settings for consistency.
  * Added automatic migration of legacy warmup settings for backward compatibility.
  * Expanded cosine annealing schedule support with proper warmup integration.
* **Documentation**
  * Added comprehensive learning rate scheduling documentation with examples.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e471690 commit 314b946

40 files changed

Lines changed: 1664 additions & 446 deletions
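To make the new options concrete before the per-file diffs, here is a sketch of a `learning_rate` section using the parameters named in the commit message; the surrounding keys, schedule type, and values are illustrative assumptions, not taken from this diff.

```python
# Hypothetical learning_rate block exercising the options added in this commit.
# Only the parameter names come from the commit message; the values are made up.
learning_rate = {
    "type": "exp",                # schedule type (assumed)
    "start_lr": 1.0e-3,
    "stop_ratio": 1.0e-5,         # alternative to stop_lr: stop_lr = start_lr * stop_ratio
    # stop_lr/stop_ratio are mutually exclusive, as are the two warmup keys below
    "warmup_ratio": 0.01,         # warm up over 1% of the total training steps
    "warmup_start_factor": 0.0,   # LR multiplier at the very first warmup step
}
```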

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 415 additions & 60 deletions
Large diffs are not rendered by default.
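Since the diff for this file is not rendered, the following is only a minimal sketch of the backend-agnostic pattern the commit message describes (array_api_compat plus xp.where/xp.clip for JIT-friendly branching); the real BaseLR class very likely differs in structure and parameters.

```python
# Minimal sketch, NOT the actual BaseLR: exponential decay with linear warmup,
# written against the array API so the same code runs on NumPy, PyTorch, JAX, etc.
import array_api_compat
import numpy as np


def lr_value(step, start_lr, stop_lr, warmup_steps, decay_steps, decay_rate):
    xp = array_api_compat.array_namespace(step)
    step = xp.astype(step, xp.float64)
    # linear warmup factor, clipped to [0, 1]
    warmup = xp.clip(step / max(warmup_steps, 1), 0.0, 1.0)
    # exponential decay measured from the end of the warmup phase
    decayed = start_lr * decay_rate ** ((step - warmup_steps) / decay_steps)
    decayed = xp.clip(decayed, stop_lr, start_lr)
    # xp.where instead of a Python `if` keeps the expression traceable under JIT
    return xp.where(step < warmup_steps, warmup * start_lr, decayed)


# works with any array-API backend; shown here with NumPy
print(lr_value(np.asarray(500.0), 1e-3, 1e-8, 1000, 5000, 0.95))
```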

deepmd/pd/train/training.py

Lines changed: 16 additions & 24 deletions
```diff
@@ -251,7 +251,7 @@ def get_sample() -> dict[str, Any]:
            return get_sample

        def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-            lr_params["stop_steps"] = self.num_steps - self.warmup_steps
+            lr_params["num_steps"] = self.num_steps
            lr_schedule = BaseLR(**lr_params)
            return lr_schedule

@@ -475,17 +475,15 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
        )

        # Learning rate
-        self.warmup_steps = training_params.get("warmup_steps", 0)
        self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, (
-            "Warm up steps must be less than total training steps!"
-        )
        if self.multi_task and config.get("learning_rate_dict", None) is not None:
-            self.lr_exp = {}
+            self.lr_schedule = {}
            for model_key in self.model_keys:
-                self.lr_exp[model_key] = get_lr(config["learning_rate_dict"][model_key])
+                self.lr_schedule[model_key] = get_lr(
+                    config["learning_rate_dict"][model_key]
+                )
        else:
-            self.lr_exp = get_lr(config["learning_rate"])
+            self.lr_schedule = get_lr(config["learning_rate"])

        # JIT
        if JIT:
@@ -668,18 +666,15 @@ def single_model_finetune(

        # TODO add lr warmups for multitask
        # author: iProzd
-        def warm_up_linear(step: int, warmup_steps: int) -> float:
-            if step < warmup_steps:
-                return step / warmup_steps
-            else:
-                return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr
-
        # TODO add optimizers for multitask
        # author: iProzd
        if self.opt_type == "Adam":
            self.scheduler = paddle.optimizer.lr.LambdaDecay(
-                learning_rate=self.lr_exp.start_lr,
-                lr_lambda=lambda step: warm_up_linear(step, self.warmup_steps),
+                learning_rate=self.lr_schedule.start_lr,
+                lr_lambda=lambda step: (
+                    self.lr_schedule.value(step + self.start_step)
+                    / self.lr_schedule.start_lr
+                ),
            )
            self.optimizer = paddle.optimizer.Adam(
                learning_rate=self.scheduler, parameters=self.wrapper.parameters()
@@ -811,10 +806,10 @@ def step(_step_id: int, task_key: str = "Default") -> None:
            # Paddle Profiler
            if enable_profiling:
                core.nvprof_nvtx_push(f"Training step {_step_id}")
-            if isinstance(self.lr_exp, dict):
-                _lr = self.lr_exp[task_key]
+            if isinstance(self.lr_schedule, dict):
+                _lr = self.lr_schedule[task_key]
            else:
-                _lr = self.lr_exp
+                _lr = self.lr_schedule
            cur_lr = _lr.value(_step_id)
            pref_lr = cur_lr
@@ -828,10 +823,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                    fout1.flush()
            if self.opt_type == "Adam":
                cur_lr = self.scheduler.get_lr()
-                if _step_id < self.warmup_steps:
-                    pref_lr = _lr.start_lr
-                else:
-                    pref_lr = cur_lr
+                pref_lr = cur_lr

            # disable synchronization in forward-backward manually
            # as derivatives exist in model forward
@@ -1072,7 +1064,7 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                _bias_adjust_mode="change-by-statistic",
            )
            self.latest_model = Path(self.save_ckpt + f"-{self.num_steps}.pd")
-            cur_lr = self.lr_exp.value(self.num_steps - 1)
+            cur_lr = self.lr_schedule.value(self.num_steps - 1)
            self.save_model(self.latest_model, lr=cur_lr, step=self.num_steps - 1)
            log.info(f"Saved model to {self.latest_model}")
            symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
```
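The Paddle trainer now relies on the schedule itself for warmup (per the commit message, BaseLR handles it internally), so the multiplier handed to paddle.optimizer.lr.LambdaDecay collapses to a plain ratio. A hedged restatement of the lambda from the diff above, as a standalone helper:

```python
def make_lr_lambda(lr_schedule, start_step: int = 0):
    """Build the multiplier passed to paddle.optimizer.lr.LambdaDecay.

    Assumes lr_schedule.value(step) already applies warmup and decay, so no
    explicit warm_up_linear branch is needed (mirrors the diff above).
    """
    return lambda step: lr_schedule.value(step + start_step) / lr_schedule.start_lr
```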

deepmd/pd/utils/utils.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -34,6 +34,7 @@

 from .env import (
     DEVICE,
+    GLOBAL_NP_FLOAT_PRECISION,
 )
 from .env import PRECISION_DICT as PD_PRECISION_DICT

@@ -257,7 +258,8 @@ def to_numpy_array(
 ):
     if xx is None:
         return None
-    assert xx is not None
+    if isinstance(xx, (float, int)):
+        return np.array(xx, dtype=GLOBAL_NP_FLOAT_PRECISION)
     # Create a reverse mapping of PD_PRECISION_DICT
     reverse_precision_dict = {v: k for k, v in PD_PRECISION_DICT.items()}
     # Use the reverse mapping to find keys with the desired value
```

deepmd/pt/train/training.py

Lines changed: 35 additions & 46 deletions
```diff
@@ -308,7 +308,7 @@ def get_sample() -> Any:
            return get_sample

        def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-            lr_params["stop_steps"] = self.num_steps - self.warmup_steps
+            lr_params["num_steps"] = self.num_steps
            lr_schedule = BaseLR(**lr_params)
            return lr_schedule

@@ -547,33 +547,15 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
        )

        # Learning rate
-        warmup_steps = training_params.get("warmup_steps", None)
-        warmup_ratio = training_params.get("warmup_ratio", None)
-        if warmup_steps is not None:
-            self.warmup_steps = warmup_steps
-        elif warmup_ratio is not None:
-            if not 0 <= warmup_ratio < 1:
-                raise ValueError(f"warmup_ratio must be in [0, 1), got {warmup_ratio}")
-            self.warmup_steps = int(warmup_ratio * self.num_steps)
-            if self.warmup_steps == 0 and warmup_ratio > 0:
-                log.warning(
-                    f"warmup_ratio {warmup_ratio} results in 0 warmup steps "
-                    f"due to truncation. Consider using a larger ratio or "
-                    f"specify warmup_steps directly."
-                )
-        else:
-            self.warmup_steps = 0
-        self.warmup_start_factor = training_params.get("warmup_start_factor", 0.0)
        self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, (
-            "Warm up steps must be less than total training steps!"
-        )
        if self.multi_task and config.get("learning_rate_dict", None) is not None:
-            self.lr_exp = {}
+            self.lr_schedule = {}
            for model_key in self.model_keys:
-                self.lr_exp[model_key] = get_lr(config["learning_rate_dict"][model_key])
+                self.lr_schedule[model_key] = get_lr(
+                    config["learning_rate_dict"][model_key]
+                )
        else:
-            self.lr_exp = get_lr(config["learning_rate"])
+            self.lr_schedule = get_lr(config["learning_rate"])

        # JIT
        if JIT:
@@ -807,34 +789,32 @@ def single_model_finetune(

        # TODO add lr warmups for multitask
        # author: iProzd
-        def warm_up_linear(step: int, warmup_steps: int) -> float:
-            if step < warmup_steps:
-                return self.warmup_start_factor + (1.0 - self.warmup_start_factor) * (
-                    step / warmup_steps
-                )
-            else:
-                return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr
-
        # TODO add optimizers for multitask
        # author: iProzd
+        initial_lr = self.lr_schedule.value(self.start_step)
        if self.opt_type in ["Adam", "AdamW"]:
+            # Initialize optimizer with the actual learning rate at start_step
+            # to ensure warmup is applied from the first step
            if self.opt_type == "Adam":
                self.optimizer = self._create_optimizer(
                    torch.optim.Adam,
-                    lr=self.lr_exp.start_lr,
+                    lr=initial_lr,
                    fused=DEVICE.type != "cpu",
                )
            else:
                self.optimizer = self._create_optimizer(
                    torch.optim.AdamW,
-                    lr=self.lr_exp.start_lr,
+                    lr=initial_lr,
                    weight_decay=float(self.opt_param["weight_decay"]),
                    fused=DEVICE.type != "cpu",
                )
            self._load_optimizer_state(optimizer_state_dict)
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+                lambda step: (
+                    self.lr_schedule.value(step + self.start_step) / initial_lr
+                ),
+                last_epoch=self.start_step - 1,
            )
        elif self.opt_type == "LKF":
            self.optimizer = LKFOptimizer(
@@ -843,7 +823,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
        elif self.opt_type == "AdaMuon":
            self.optimizer = self._create_optimizer(
                AdaMuonOptimizer,
-                lr=self.lr_exp.start_lr,
+                lr=initial_lr,
                momentum=float(self.opt_param["momentum"]),
                weight_decay=float(self.opt_param["weight_decay"]),
                adam_betas=(
@@ -853,10 +833,19 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                lr_adjust=float(self.opt_param["lr_adjust"]),
                lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
            )
+            if optimizer_state_dict is not None and self.restart_training:
+                self.optimizer.load_state_dict(optimizer_state_dict)
+            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer,
+                lambda step: (
+                    self.lr_schedule.value(step + self.start_step) / initial_lr
+                ),
+                last_epoch=self.start_step - 1,
+            )
        elif self.opt_type == "HybridMuon":
            self.optimizer = self._create_optimizer(
                HybridMuonOptimizer,
-                lr=self.lr_exp.start_lr,
+                lr=initial_lr,
                momentum=float(self.opt_param["momentum"]),
                weight_decay=float(self.opt_param["weight_decay"]),
                adam_betas=(
@@ -872,7 +861,10 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
            self._load_optimizer_state(optimizer_state_dict)
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+                lambda step: (
+                    self.lr_schedule.value(step + self.start_step) / initial_lr
+                ),
+                last_epoch=self.start_step - 1,
            )
        else:
            raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
@@ -1034,10 +1026,10 @@ def step(_step_id: int, task_key: str = "Default") -> None:
            # PyTorch Profiler
            if self.enable_profiler or self.profiling:
                prof.step()
-            if isinstance(self.lr_exp, dict):
-                _lr = self.lr_exp[task_key]
+            if isinstance(self.lr_schedule, dict):
+                _lr = self.lr_schedule[task_key]
            else:
-                _lr = self.lr_exp
+                _lr = self.lr_schedule
            cur_lr = _lr.value(_step_id)
            pref_lr = cur_lr
            self.optimizer.zero_grad(set_to_none=True)
@@ -1050,10 +1042,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                    fout1.flush()
            if self.opt_type in ["Adam", "AdamW", "AdaMuon", "HybridMuon"]:
                cur_lr = self.scheduler.get_last_lr()[0]
-                if _step_id < self.warmup_steps:
-                    pref_lr = _lr.start_lr
-                else:
-                    pref_lr = cur_lr
+                pref_lr = cur_lr
            model_pred, loss, more_loss = self.wrapper(
                **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
            )
@@ -1446,7 +1435,7 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                _bias_adjust_mode="change-by-statistic",
            )
            self.latest_model = Path(self.save_ckpt + f"-{self.num_steps}.pt")
-            cur_lr = self.lr_exp.value(self.num_steps - 1)
+            cur_lr = self.lr_schedule.value(self.num_steps - 1)
            self.save_model(self.latest_model, lr=cur_lr, step=self.num_steps - 1)
            log.info(f"Saved model to {self.latest_model}")
            symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
```
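For readers unfamiliar with the pattern, below is a self-contained PyTorch sketch of the scheduler wiring used above, with a toy schedule standing in for BaseLR (which is not shown in this diff); parameter values are arbitrary.

```python
# Toy illustration only: LambdaLR rescales the optimizer's base LR by
# schedule.value(step) / initial_lr, so the schedule fully controls the LR.
import torch


class ToySchedule:
    """Stand-in for BaseLR: plain exponential decay, no warmup."""

    start_lr = 1.0e-3

    def value(self, step: int) -> float:
        return self.start_lr * 0.95 ** (step / 100)


schedule = ToySchedule()
start_step = 0  # a restarted run would pass its checkpoint step here
initial_lr = schedule.value(start_step)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=initial_lr)
# When resuming (start_step > 0) the diff passes last_epoch=start_step - 1,
# which additionally requires "initial_lr" to be present in each param group.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: schedule.value(step + start_step) / initial_lr,
)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr()[0])
```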

deepmd/pt/utils/utils.py

Lines changed: 15 additions & 2 deletions
```diff
@@ -16,6 +16,7 @@

 from .env import (
     DEVICE,
+    GLOBAL_NP_FLOAT_PRECISION,
 )
 from .env import PRECISION_DICT as PT_PRECISION_DICT

@@ -218,6 +219,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise RuntimeError(f"activation function {self.activation} not supported")


+@overload
+def to_numpy_array(xx: np.ndarray) -> np.ndarray: ...
+
+
+@overload
+def to_numpy_array(xx: float) -> np.ndarray: ...
+
+
 @overload
 def to_numpy_array(xx: torch.Tensor) -> np.ndarray: ...

@@ -227,18 +236,22 @@ def to_numpy_array(xx: None) -> None: ...


 def to_numpy_array(
-    xx: torch.Tensor | None,
+    xx: torch.Tensor | np.ndarray | float | None,
 ) -> np.ndarray | None:
     if xx is None:
         return None
-    assert xx is not None
+    if isinstance(xx, (float, int)):
+        return np.array(xx, dtype=GLOBAL_NP_FLOAT_PRECISION)
+    if isinstance(xx, np.ndarray):
+        return xx.astype(GLOBAL_NP_FLOAT_PRECISION)
     # Create a reverse mapping of PT_PRECISION_DICT
     reverse_precision_dict = {v: k for k, v in PT_PRECISION_DICT.items()}
     # Use the reverse mapping to find keys with the desired value
     prec = reverse_precision_dict.get(xx.dtype, None)
     prec = NP_PRECISION_DICT.get(prec, None)
     if prec is None:
         raise ValueError(f"unknown precision {xx.dtype}")
+    assert isinstance(xx, torch.Tensor)
     if xx.dtype == torch.bfloat16:
         # https://github.com/pytorch/pytorch/issues/109873
         xx = xx.float()
```
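A brief, hedged illustration of what the widened helper accepts after this change (the module path matches the file above; the exact global precision depends on the build configuration):

```python
import numpy as np
import torch

from deepmd.pt.utils.utils import to_numpy_array

print(to_numpy_array(None))                           # None passes through unchanged
print(to_numpy_array(1.0e-3))                         # Python scalar -> array in global precision
print(to_numpy_array(np.ones(3, dtype=np.float32)))   # ndarray -> cast to global precision
print(to_numpy_array(torch.zeros(2)))                 # tensor -> precision-mapped NumPy array
```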

deepmd/tf/fit/dipole.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -402,7 +402,7 @@ def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss:
        ----------
        loss : dict
            the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
            the learning rate

        Returns
```

deepmd/tf/fit/dos.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -50,7 +50,7 @@
 )

 if TYPE_CHECKING:
-    from deepmd.tf.train.learning_rate import (
+    from deepmd.tf.utils.learning_rate import (
        LearningRateExp,
    )
    from deepmd.utils.version import (
@@ -668,7 +668,7 @@ def get_loss(self, loss: dict, lr: "LearningRateExp") -> Loss:
        ----------
        loss : dict
            the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
            the learning rate

        Returns
```

deepmd/tf/fit/ener.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -864,7 +864,7 @@ def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss:
        ----------
        loss : dict
            The loss function parameters.
-        lr : LearningRateExp
+        lr : LearningRateSchedule
            The learning rate.

        Returns
```

deepmd/tf/fit/fitting.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -85,7 +85,7 @@ def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss:
        ----------
        loss : dict
            the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
            the learning rate

        Returns
```

deepmd/tf/fit/polar.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -50,7 +50,7 @@
 )

 if TYPE_CHECKING:
-    from deepmd.tf.train.learning_rate import (
+    from deepmd.tf.utils.learning_rate import (
        LearningRateExp,
    )

@@ -880,7 +880,7 @@ def get_loss(self, loss: dict, lr: "LearningRateExp") -> Loss:
        ----------
        loss : dict
            the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
            the learning rate

        Returns
```
