Commit 337b334

refactor: unify learning rate schedulers with array API
- Refactor BaseLR in dpmodel to use array_api_compat for backend-agnostic implementation
- Consolidate learning rate logic from TF/PT/PD backends into unified dpmodel layer
- Use array API operations (xp.where, xp.clip, etc.) for JIT compatibility across backends
- Add warmup support (warmup_steps, warmup_ratio, warmup_start_factor) during refactoring
- Add stop_ratio parameter as alternative to stop_lr for flexible configuration
- Implement mutual exclusion validation for stop_lr/stop_ratio and warmup_steps/warmup_ratio
- Update all backends to use unified BaseLR implementation
- Add comprehensive consistency tests across NumPy/PyTorch/JAX/array_api_strict backends
1 parent a0bd530 commit 337b334
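
The stop_lr/stop_ratio and warmup_steps/warmup_ratio options mentioned above are mutually exclusive pairs. As orientation only, a minimal sketch of that validation with a hypothetical helper: `resolve_lr_options` is not part of the commit, interpreting `stop_ratio` as `stop_lr = start_lr * stop_ratio` is an assumption, and the `warmup_ratio` handling mirrors the backend code removed further down rather than the actual BaseLR.

```python
def resolve_lr_options(
    start_lr: float,
    num_steps: int,
    stop_lr: float | None = None,
    stop_ratio: float | None = None,
    warmup_steps: int | None = None,
    warmup_ratio: float | None = None,
) -> tuple[float | None, int]:
    """Hypothetical sketch of the mutual-exclusion validation described above."""
    if stop_lr is not None and stop_ratio is not None:
        raise ValueError("stop_lr and stop_ratio are mutually exclusive")
    if warmup_steps is not None and warmup_ratio is not None:
        raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
    if stop_ratio is not None:
        stop_lr = start_lr * stop_ratio  # assumed meaning of stop_ratio
    if warmup_ratio is not None:
        if not 0 <= warmup_ratio < 1:
            raise ValueError(f"warmup_ratio must be in [0, 1), got {warmup_ratio}")
        warmup_steps = int(warmup_ratio * num_steps)  # as in the removed PT code below
    return stop_lr, warmup_steps or 0
```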

40 files changed

Lines changed: 1633 additions & 427 deletions

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 400 additions & 60 deletions
Large diffs are not rendered by default.
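
Since the main diff is hidden, here is a minimal sketch of the approach the commit message describes: one schedule written against the array API via array_api_compat, using `xp.where`/`xp.clip` instead of Python branching so the same code runs (and can be JIT-traced) on NumPy, PyTorch, JAX, and array_api_strict inputs. The function name, signature, and exact decay/warmup formulas are assumptions; the real BaseLR in this file will differ.

```python
import array_api_compat


def schedule_value(step, start_lr, stop_lr, decay_steps, num_steps,
                   warmup_steps=0, warmup_start_factor=0.0):
    """Sketch: exponential decay with linear warmup, written against the array API.

    ``step`` may be a NumPy/PyTorch/JAX/array_api_strict array; branching is
    expressed with ``xp.where`` so the function stays traceable under JIT.
    """
    xp = array_api_compat.array_namespace(step)
    step = xp.astype(step, xp.float64)
    # decay rate chosen so the lr reaches stop_lr by the end of training
    decay_rate = (stop_lr / start_lr) ** (decay_steps / max(num_steps - warmup_steps, 1))
    # linear warmup from warmup_start_factor * start_lr up to start_lr
    warmup_frac = xp.clip(step / max(warmup_steps, 1), 0.0, 1.0)
    warmup_lr = start_lr * (warmup_start_factor + (1.0 - warmup_start_factor) * warmup_frac)
    # exponential decay after warmup, never dropping below stop_lr
    decay_lr = start_lr * decay_rate ** ((step - warmup_steps) / decay_steps)
    decay_lr = xp.clip(decay_lr, stop_lr, start_lr)
    return xp.where(step < warmup_steps, warmup_lr, decay_lr)
```

Passing a NumPy `step` array returns a NumPy array, a JAX array returns a JAX array, and so on, which is what makes a single dpmodel implementation reusable by every backend.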

deepmd/pd/train/training.py

Lines changed: 5 additions & 16 deletions
@@ -243,7 +243,7 @@ def get_sample() -> dict[str, Any]:
             return get_sample

         def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-            lr_params["stop_steps"] = self.num_steps - self.warmup_steps
+            lr_params["num_steps"] = self.num_steps
             lr_schedule = BaseLR(**lr_params)
             return lr_schedule

@@ -391,11 +391,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
         )

         # Learning rate
-        self.warmup_steps = training_params.get("warmup_steps", 0)
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, (
-            "Warm up steps must be less than total training steps!"
-        )
         if self.multi_task and config.get("learning_rate_dict", None) is not None:
             self.lr_exp = {}
             for model_key in self.model_keys:
@@ -584,18 +580,14 @@ def single_model_finetune(

         # TODO add lr warmups for multitask
         # author: iProzd
-        def warm_up_linear(step: int, warmup_steps: int) -> float:
-            if step < warmup_steps:
-                return step / warmup_steps
-            else:
-                return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr
-
         # TODO add optimizers for multitask
         # author: iProzd
         if self.opt_type == "Adam":
             self.scheduler = paddle.optimizer.lr.LambdaDecay(
                 learning_rate=self.lr_exp.start_lr,
-                lr_lambda=lambda step: warm_up_linear(step, self.warmup_steps),
+                lr_lambda=lambda step: (
+                    self.lr_exp.value(step + self.start_step) / self.lr_exp.start_lr
+                ),
             )
             self.optimizer = paddle.optimizer.Adam(
                 learning_rate=self.scheduler, parameters=self.wrapper.parameters()
@@ -759,10 +751,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                     fout1.flush()
                 if self.opt_type == "Adam":
                     cur_lr = self.scheduler.get_lr()
-                    if _step_id < self.warmup_steps:
-                        pref_lr = _lr.start_lr
-                    else:
-                        pref_lr = cur_lr
+                    pref_lr = cur_lr

                 # disable synchronization in forward-backward manually
                 # as derivatives exist in model forward
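
The removed warmup branch is the visible effect of moving warmup into the unified schedule: the LambdaDecay factor is simply the schedule's absolute value divided by the base lr, so the effective learning rate equals lr_exp.value(step + start_step) at every step, warmup included. A standalone sketch of that wiring with a stand-in schedule (hypothetical `schedule_value`, not DeePMD code):

```python
import paddle


def schedule_value(step: int) -> float:      # stands in for BaseLR.value(step)
    return 1e-3 * 0.95 ** (step / 100.0)


start_lr = schedule_value(0)
start_step = 0
model = paddle.nn.Linear(4, 1)
scheduler = paddle.optimizer.lr.LambdaDecay(
    learning_rate=start_lr,
    # factor relative to start_lr; warmup and decay both live inside the schedule
    lr_lambda=lambda step: schedule_value(step + start_step) / start_lr,
)
optimizer = paddle.optimizer.Adam(
    learning_rate=scheduler, parameters=model.parameters()
)
for step in range(3):
    loss = paddle.mean(model(paddle.randn([8, 4])))
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    scheduler.step()  # lr for the next step becomes schedule_value(step + 1)
```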

deepmd/pd/utils/utils.py

Lines changed: 3 additions & 1 deletion
@@ -34,6 +34,7 @@

 from .env import (
     DEVICE,
+    GLOBAL_NP_FLOAT_PRECISION,
 )
 from .env import PRECISION_DICT as PD_PRECISION_DICT

@@ -257,7 +258,8 @@ def to_numpy_array(
 ):
     if xx is None:
         return None
-    assert xx is not None
+    if isinstance(xx, (float, int)):
+        return np.array(xx, dtype=GLOBAL_NP_FLOAT_PRECISION)
     # Create a reverse mapping of PD_PRECISION_DICT
     reverse_precision_dict = {v: k for k, v in PD_PRECISION_DICT.items()}
     # Use the reverse mapping to find keys with the desired value

deepmd/pt/train/training.py

Lines changed: 26 additions & 39 deletions
@@ -279,7 +279,7 @@ def get_sample() -> Any:
             return get_sample

         def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-            lr_params["stop_steps"] = self.num_steps - self.warmup_steps
+            lr_params["num_steps"] = self.num_steps
            lr_schedule = BaseLR(**lr_params)
             return lr_schedule

@@ -437,27 +437,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
         )

         # Learning rate
-        warmup_steps = training_params.get("warmup_steps", None)
-        warmup_ratio = training_params.get("warmup_ratio", None)
-        if warmup_steps is not None:
-            self.warmup_steps = warmup_steps
-        elif warmup_ratio is not None:
-            if not 0 <= warmup_ratio < 1:
-                raise ValueError(f"warmup_ratio must be in [0, 1), got {warmup_ratio}")
-            self.warmup_steps = int(warmup_ratio * self.num_steps)
-            if self.warmup_steps == 0 and warmup_ratio > 0:
-                log.warning(
-                    f"warmup_ratio {warmup_ratio} results in 0 warmup steps "
-                    f"due to truncation. Consider using a larger ratio or "
-                    f"specify warmup_steps directly."
-                )
-        else:
-            self.warmup_steps = 0
-        self.warmup_start_factor = training_params.get("warmup_start_factor", 0.0)
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, (
-            "Warm up steps must be less than total training steps!"
-        )
         if self.multi_task and config.get("learning_rate_dict", None) is not None:
             self.lr_exp = {}
             for model_key in self.model_keys:
@@ -702,44 +682,43 @@ def single_model_finetune(

         # TODO add lr warmups for multitask
         # author: iProzd
-        def warm_up_linear(step: int, warmup_steps: int) -> float:
-            if step < warmup_steps:
-                return self.warmup_start_factor + (1.0 - self.warmup_start_factor) * (
-                    step / warmup_steps
-                )
-            else:
-                return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr
-
         # TODO add optimizers for multitask
         # author: iProzd
         if self.opt_type in ["Adam", "AdamW"]:
+            # Initialize optimizer with the actual learning rate at start_step
+            # to ensure warmup is applied from the first step
+            initial_lr = self.lr_exp.value(self.start_step)
             if self.opt_type == "Adam":
                 self.optimizer = torch.optim.Adam(
                     self.wrapper.parameters(),
-                    lr=self.lr_exp.start_lr,
+                    lr=initial_lr,
                     fused=False if DEVICE.type == "cpu" else True,
                 )
             else:
                 self.optimizer = torch.optim.AdamW(
                     self.wrapper.parameters(),
-                    lr=self.lr_exp.start_lr,
+                    lr=initial_lr,
                     weight_decay=float(self.opt_param["weight_decay"]),
                     fused=False if DEVICE.type == "cpu" else True,
                 )
             if optimizer_state_dict is not None and self.restart_training:
                 self.optimizer.load_state_dict(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+                lambda step: self.lr_exp.value(step + self.start_step) / initial_lr,
+                last_epoch=self.start_step - 1,
             )
         elif self.opt_type == "LKF":
             self.optimizer = LKFOptimizer(
                 self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
             )
         elif self.opt_type == "AdaMuon":
+            # Initialize optimizer with the actual learning rate at start_step
+            # to ensure warmup is applied from the first step
+            initial_lr = self.lr_exp.value(self.start_step)
             self.optimizer = AdaMuonOptimizer(
                 self.wrapper.parameters(),
-                lr=self.lr_exp.start_lr,
+                lr=initial_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
                 adam_betas=(
@@ -749,10 +728,20 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 lr_adjust=float(self.opt_param["lr_adjust"]),
                 lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
             )
+            if optimizer_state_dict is not None and self.restart_training:
+                self.optimizer.load_state_dict(optimizer_state_dict)
+            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer,
+                lambda step: self.lr_exp.value(step + self.start_step) / initial_lr,
+                last_epoch=self.start_step - 1,
+            )
         elif self.opt_type == "HybridMuon":
+            # Initialize optimizer with the actual learning rate at start_step
+            # to ensure warmup is applied from the first step
+            initial_lr = self.lr_exp.value(self.start_step)
             self.optimizer = HybridMuonOptimizer(
                 self.wrapper.parameters(),
-                lr=self.lr_exp.start_lr,
+                lr=initial_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
                 adam_betas=(
@@ -768,7 +757,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 self.optimizer.load_state_dict(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+                lambda step: self.lr_exp.value(step + self.start_step) / initial_lr,
+                last_epoch=self.start_step - 1,
             )
         else:
             raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
@@ -883,10 +873,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                     fout1.flush()
                 if self.opt_type in ["Adam", "AdamW", "AdaMuon", "HybridMuon"]:
                     cur_lr = self.scheduler.get_last_lr()[0]
-                    if _step_id < self.warmup_steps:
-                        pref_lr = _lr.start_lr
-                    else:
-                        pref_lr = cur_lr
+                    pref_lr = cur_lr
                 model_pred, loss, more_loss = self.wrapper(
                     **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
                 )
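
torch.optim.lr_scheduler.LambdaLR treats the lambda's return value as a multiplicative factor on the lr the optimizer was constructed with, which is why the optimizer is now built with `initial_lr = self.lr_exp.value(self.start_step)` and the lambda divides by it, and why `last_epoch=self.start_step - 1` is passed to restore the scheduler position on restart. A standalone sketch of the factor semantics with a stand-in schedule (hypothetical `schedule_value`, not DeePMD code):

```python
import torch


def schedule_value(step: int) -> float:      # stands in for BaseLR.value(step)
    return 1e-3 * 0.95 ** (step / 100.0)


initial_lr = schedule_value(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    # LambdaLR multiplies the construction-time lr by this factor, so dividing
    # the absolute schedule value by that lr reproduces the schedule exactly
    lambda step: schedule_value(step) / initial_lr,
)
for step in range(3):
    optimizer.step()                          # gradients omitted in this sketch
    scheduler.step()
    assert abs(scheduler.get_last_lr()[0] - schedule_value(step + 1)) < 1e-12
```

In the trainer above the same idea is combined with `initial_lr` taken at `start_step` and `last_epoch=self.start_step - 1`, so a restarted run resumes the schedule exactly where it left off.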

deepmd/pt/utils/utils.py

Lines changed: 15 additions & 2 deletions
@@ -16,6 +16,7 @@

 from .env import (
     DEVICE,
+    GLOBAL_NP_FLOAT_PRECISION,
 )
 from .env import PRECISION_DICT as PT_PRECISION_DICT

@@ -218,6 +219,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise RuntimeError(f"activation function {self.activation} not supported")


+@overload
+def to_numpy_array(xx: np.ndarray) -> np.ndarray: ...
+
+
+@overload
+def to_numpy_array(xx: float) -> np.ndarray: ...
+
+
 @overload
 def to_numpy_array(xx: torch.Tensor) -> np.ndarray: ...

@@ -227,18 +236,22 @@ def to_numpy_array(xx: None) -> None: ...


 def to_numpy_array(
-    xx: torch.Tensor | None,
+    xx: torch.Tensor | np.ndarray | float | None,
 ) -> np.ndarray | None:
     if xx is None:
         return None
-    assert xx is not None
+    if isinstance(xx, (float, int)):
+        return np.array(xx, dtype=GLOBAL_NP_FLOAT_PRECISION)
+    if isinstance(xx, np.ndarray):
+        return xx.astype(GLOBAL_NP_FLOAT_PRECISION)
     # Create a reverse mapping of PT_PRECISION_DICT
     reverse_precision_dict = {v: k for k, v in PT_PRECISION_DICT.items()}
     # Use the reverse mapping to find keys with the desired value
     prec = reverse_precision_dict.get(xx.dtype, None)
     prec = NP_PRECISION_DICT.get(prec, None)
     if prec is None:
         raise ValueError(f"unknown precision {xx.dtype}")
+    assert isinstance(xx, torch.Tensor)
     if xx.dtype == torch.bfloat16:
         # https://github.com/pytorch/pytorch/issues/109873
         xx = xx.float()
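
A short usage sketch of the widened helper (assuming a DeePMD-kit install with the PyTorch backend): the float/int and ndarray branches are the new part, presumably to accept values coming from the NumPy-based unified BaseLR, while the tensor and None paths behave as before.

```python
import numpy as np
import torch

from deepmd.pt.utils.utils import to_numpy_array

to_numpy_array(1e-3)                             # float -> 0-d array in GLOBAL_NP_FLOAT_PRECISION
to_numpy_array(np.arange(3, dtype=np.float32))   # ndarray -> cast to GLOBAL_NP_FLOAT_PRECISION
to_numpy_array(torch.ones(2))                    # torch.Tensor path unchanged
to_numpy_array(None)                             # still returns None
```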

deepmd/tf/fit/dipole.py

Lines changed: 1 addition & 1 deletion
@@ -402,7 +402,7 @@ def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss:
         ----------
         loss : dict
             the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
             the learning rate

         Returns

deepmd/tf/fit/dos.py

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@
 )

 if TYPE_CHECKING:
-    from deepmd.tf.train.learning_rate import (
+    from deepmd.tf.utils.learning_rate import (
         LearningRateExp,
     )
     from deepmd.utils.version import (
@@ -668,7 +668,7 @@ def get_loss(self, loss: dict, lr: "LearningRateExp") -> Loss:
         ----------
         loss : dict
             the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
             the learning rate

         Returns

deepmd/tf/fit/ener.py

Lines changed: 1 addition & 1 deletion
@@ -864,7 +864,7 @@ def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss:
         ----------
         loss : dict
             The loss function parameters.
-        lr : LearningRateExp
+        lr : LearningRateSchedule
             The learning rate.

         Returns

deepmd/tf/fit/fitting.py

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss:
         ----------
         loss : dict
             the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
             the learning rate

         Returns

deepmd/tf/fit/polar.py

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@
 )

 if TYPE_CHECKING:
-    from deepmd.tf.train.learning_rate import (
+    from deepmd.tf.utils.learning_rate import (
         LearningRateExp,
     )

@@ -880,7 +880,7 @@ def get_loss(self, loss: dict, lr: "LearningRateExp") -> Loss:
         ----------
         loss : dict
             the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
             the learning rate

         Returns
