Skip to content

Commit 80bd704

Browse files
committed
refactor: unify learning rate schedulers with array API
- Refactor BaseLR in dpmodel to use array_api_compat for backend-agnostic implementation
- Consolidate learning rate logic from TF/PT/PD backends into unified dpmodel layer
- Use array API operations (xp.where, xp.clip, etc.) for JIT compatibility across backends
- Add warmup support (warmup_steps, warmup_ratio, warmup_start_factor) during refactoring
- Add stop_ratio parameter as alternative to stop_lr for flexible configuration
- Implement mutual exclusion validation for stop_lr/stop_ratio and warmup_steps/warmup_ratio
- Update all backends to use unified BaseLR implementation
- Add comprehensive consistency tests across NumPy/PyTorch/JAX/array_api_strict backends
1 parent 367e626 commit 80bd704

40 files changed

Lines changed: 1638 additions & 429 deletions

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 415 additions & 60 deletions
Large diffs are not rendered by default.

deepmd/pd/train/training.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def get_sample() -> dict[str, Any]:
243243
return get_sample
244244

245245
def get_lr(lr_params: dict[str, Any]) -> BaseLR:
246-
lr_params["stop_steps"] = self.num_steps - self.warmup_steps
246+
lr_params["num_steps"] = self.num_steps
247247
lr_schedule = BaseLR(**lr_params)
248248
return lr_schedule
249249

@@ -391,11 +391,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
391391
)
392392

393393
# Learning rate
394-
self.warmup_steps = training_params.get("warmup_steps", 0)
395394
self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
396-
assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, (
397-
"Warm up steps must be less than total training steps!"
398-
)
399395
if self.multi_task and config.get("learning_rate_dict", None) is not None:
400396
self.lr_exp = {}
401397
for model_key in self.model_keys:
@@ -584,18 +580,14 @@ def single_model_finetune(
584580

585581
# TODO add lr warmups for multitask
586582
# author: iProzd
587-
def warm_up_linear(step: int, warmup_steps: int) -> float:
588-
if step < warmup_steps:
589-
return step / warmup_steps
590-
else:
591-
return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr
592-
593583
# TODO add optimizers for multitask
594584
# author: iProzd
595585
if self.opt_type == "Adam":
596586
self.scheduler = paddle.optimizer.lr.LambdaDecay(
597587
learning_rate=self.lr_exp.start_lr,
598-
lr_lambda=lambda step: warm_up_linear(step, self.warmup_steps),
588+
lr_lambda=lambda step: (
589+
self.lr_exp.value(step + self.start_step) / self.lr_exp.start_lr
590+
),
599591
)
600592
self.optimizer = paddle.optimizer.Adam(
601593
learning_rate=self.scheduler, parameters=self.wrapper.parameters()
@@ -759,10 +751,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
759751
fout1.flush()
760752
if self.opt_type == "Adam":
761753
cur_lr = self.scheduler.get_lr()
762-
if _step_id < self.warmup_steps:
763-
pref_lr = _lr.start_lr
764-
else:
765-
pref_lr = cur_lr
754+
pref_lr = cur_lr
766755

767756
# disable synchronization in forward-backward manually
768757
# as derivatives exist in model forward

deepmd/pd/utils/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
from .env import (
3636
DEVICE,
37+
GLOBAL_NP_FLOAT_PRECISION,
3738
)
3839
from .env import PRECISION_DICT as PD_PRECISION_DICT
3940

@@ -257,7 +258,8 @@ def to_numpy_array(
257258
):
258259
if xx is None:
259260
return None
260-
assert xx is not None
261+
if isinstance(xx, (float, int)):
262+
return np.array(xx, dtype=GLOBAL_NP_FLOAT_PRECISION)
261263
# Create a reverse mapping of PD_PRECISION_DICT
262264
reverse_precision_dict = {v: k for k, v in PD_PRECISION_DICT.items()}
263265
# Use the reverse mapping to find keys with the desired value

deepmd/pt/train/training.py

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def get_sample() -> Any:
299299
return get_sample
300300

301301
def get_lr(lr_params: dict[str, Any]) -> BaseLR:
302-
lr_params["stop_steps"] = self.num_steps - self.warmup_steps
302+
lr_params["num_steps"] = self.num_steps
303303
lr_schedule = BaseLR(**lr_params)
304304
return lr_schedule
305305

@@ -463,27 +463,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
463463
)
464464

465465
# Learning rate
466-
warmup_steps = training_params.get("warmup_steps", None)
467-
warmup_ratio = training_params.get("warmup_ratio", None)
468-
if warmup_steps is not None:
469-
self.warmup_steps = warmup_steps
470-
elif warmup_ratio is not None:
471-
if not 0 <= warmup_ratio < 1:
472-
raise ValueError(f"warmup_ratio must be in [0, 1), got {warmup_ratio}")
473-
self.warmup_steps = int(warmup_ratio * self.num_steps)
474-
if self.warmup_steps == 0 and warmup_ratio > 0:
475-
log.warning(
476-
f"warmup_ratio {warmup_ratio} results in 0 warmup steps "
477-
f"due to truncation. Consider using a larger ratio or "
478-
f"specify warmup_steps directly."
479-
)
480-
else:
481-
self.warmup_steps = 0
482-
self.warmup_start_factor = training_params.get("warmup_start_factor", 0.0)
483466
self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
484-
assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, (
485-
"Warm up steps must be less than total training steps!"
486-
)
487467
if self.multi_task and config.get("learning_rate_dict", None) is not None:
488468
self.lr_exp = {}
489469
for model_key in self.model_keys:
@@ -738,34 +718,30 @@ def single_model_finetune(
738718

739719
# TODO add lr warmups for multitask
740720
# author: iProzd
741-
def warm_up_linear(step: int, warmup_steps: int) -> float:
742-
if step < warmup_steps:
743-
return self.warmup_start_factor + (1.0 - self.warmup_start_factor) * (
744-
step / warmup_steps
745-
)
746-
else:
747-
return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr
748-
749721
# TODO add optimizers for multitask
750722
# author: iProzd
723+
initial_lr = self.lr_exp.value(self.start_step)
751724
if self.opt_type in ["Adam", "AdamW"]:
725+
# Initialize optimizer with the actual learning rate at start_step
726+
# to ensure warmup is applied from the first step
752727
if self.opt_type == "Adam":
753728
self.optimizer = self._create_optimizer(
754729
torch.optim.Adam,
755-
lr=self.lr_exp.start_lr,
730+
lr=initial_lr,
756731
fused=DEVICE.type != "cpu",
757732
)
758733
else:
759734
self.optimizer = self._create_optimizer(
760735
torch.optim.AdamW,
761-
lr=self.lr_exp.start_lr,
736+
lr=initial_lr,
762737
weight_decay=float(self.opt_param["weight_decay"]),
763738
fused=DEVICE.type != "cpu",
764739
)
765740
self._load_optimizer_state(optimizer_state_dict)
766741
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
767742
self.optimizer,
768-
lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
743+
lambda step: self.lr_exp.value(step + self.start_step) / initial_lr,
744+
last_epoch=self.start_step - 1,
769745
)
770746
elif self.opt_type == "LKF":
771747
self.optimizer = LKFOptimizer(
@@ -774,7 +750,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
774750
elif self.opt_type == "AdaMuon":
775751
self.optimizer = self._create_optimizer(
776752
AdaMuonOptimizer,
777-
lr=self.lr_exp.start_lr,
753+
lr=initial_lr,
778754
momentum=float(self.opt_param["momentum"]),
779755
weight_decay=float(self.opt_param["weight_decay"]),
780756
adam_betas=(
@@ -784,10 +760,17 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
784760
lr_adjust=float(self.opt_param["lr_adjust"]),
785761
lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
786762
)
763+
if optimizer_state_dict is not None and self.restart_training:
764+
self.optimizer.load_state_dict(optimizer_state_dict)
765+
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
766+
self.optimizer,
767+
lambda step: self.lr_exp.value(step + self.start_step) / initial_lr,
768+
last_epoch=self.start_step - 1,
769+
)
787770
elif self.opt_type == "HybridMuon":
788771
self.optimizer = self._create_optimizer(
789772
HybridMuonOptimizer,
790-
lr=self.lr_exp.start_lr,
773+
lr=initial_lr,
791774
momentum=float(self.opt_param["momentum"]),
792775
weight_decay=float(self.opt_param["weight_decay"]),
793776
adam_betas=(
@@ -802,7 +785,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
802785
self._load_optimizer_state(optimizer_state_dict)
803786
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
804787
self.optimizer,
805-
lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
788+
lambda step: self.lr_exp.value(step + self.start_step) / initial_lr,
789+
last_epoch=self.start_step - 1,
806790
)
807791
else:
808792
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
@@ -980,10 +964,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
980964
fout1.flush()
981965
if self.opt_type in ["Adam", "AdamW", "AdaMuon", "HybridMuon"]:
982966
cur_lr = self.scheduler.get_last_lr()[0]
983-
if _step_id < self.warmup_steps:
984-
pref_lr = _lr.start_lr
985-
else:
986-
pref_lr = cur_lr
967+
pref_lr = cur_lr
987968
model_pred, loss, more_loss = self.wrapper(
988969
**input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
989970
)

deepmd/pt/utils/utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from .env import (
1818
DEVICE,
19+
GLOBAL_NP_FLOAT_PRECISION,
1920
)
2021
from .env import PRECISION_DICT as PT_PRECISION_DICT
2122

@@ -218,6 +219,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
218219
raise RuntimeError(f"activation function {self.activation} not supported")
219220

220221

222+
@overload
223+
def to_numpy_array(xx: np.ndarray) -> np.ndarray: ...
224+
225+
226+
@overload
227+
def to_numpy_array(xx: float) -> np.ndarray: ...
228+
229+
221230
@overload
222231
def to_numpy_array(xx: torch.Tensor) -> np.ndarray: ...
223232

@@ -227,18 +236,22 @@ def to_numpy_array(xx: None) -> None: ...
227236

228237

229238
def to_numpy_array(
230-
xx: torch.Tensor | None,
239+
xx: torch.Tensor | np.ndarray | float | None,
231240
) -> np.ndarray | None:
232241
if xx is None:
233242
return None
234-
assert xx is not None
243+
if isinstance(xx, (float, int)):
244+
return np.array(xx, dtype=GLOBAL_NP_FLOAT_PRECISION)
245+
if isinstance(xx, np.ndarray):
246+
return xx.astype(GLOBAL_NP_FLOAT_PRECISION)
235247
# Create a reverse mapping of PT_PRECISION_DICT
236248
reverse_precision_dict = {v: k for k, v in PT_PRECISION_DICT.items()}
237249
# Use the reverse mapping to find keys with the desired value
238250
prec = reverse_precision_dict.get(xx.dtype, None)
239251
prec = NP_PRECISION_DICT.get(prec, None)
240252
if prec is None:
241253
raise ValueError(f"unknown precision {xx.dtype}")
254+
assert isinstance(xx, torch.Tensor)
242255
if xx.dtype == torch.bfloat16:
243256
# https://github.com/pytorch/pytorch/issues/109873
244257
xx = xx.float()

deepmd/tf/fit/dipole.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss:
402402
----------
403403
loss : dict
404404
the loss dict
405-
lr : LearningRateExp
405+
lr : LearningRateSchedule
406406
the learning rate
407407
408408
Returns

deepmd/tf/fit/dos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
)
5151

5252
if TYPE_CHECKING:
53-
from deepmd.tf.train.learning_rate import (
53+
from deepmd.tf.utils.learning_rate import (
5454
LearningRateExp,
5555
)
5656
from deepmd.utils.version import (
@@ -668,7 +668,7 @@ def get_loss(self, loss: dict, lr: "LearningRateExp") -> Loss:
668668
----------
669669
loss : dict
670670
the loss dict
671-
lr : LearningRateExp
671+
lr : LearningRateSchedule
672672
the learning rate
673673
674674
Returns

deepmd/tf/fit/ener.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss:
864864
----------
865865
loss : dict
866866
The loss function parameters.
867-
lr : LearningRateExp
867+
lr : LearningRateSchedule
868868
The learning rate.
869869
870870
Returns

deepmd/tf/fit/fitting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss:
8585
----------
8686
loss : dict
8787
the loss dict
88-
lr : LearningRateExp
88+
lr : LearningRateSchedule
8989
the learning rate
9090
9191
Returns

deepmd/tf/fit/polar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
)
5151

5252
if TYPE_CHECKING:
53-
from deepmd.tf.train.learning_rate import (
53+
from deepmd.tf.utils.learning_rate import (
5454
LearningRateExp,
5555
)
5656

@@ -880,7 +880,7 @@ def get_loss(self, loss: dict, lr: "LearningRateExp") -> Loss:
880880
----------
881881
loss : dict
882882
the loss dict
883-
lr : LearningRateExp
883+
lr : LearningRateSchedule
884884
the learning rate
885885
886886
Returns

0 commit comments

Comments (0)