Skip to content

Commit 66b088c

Browse files
committed
fix(pt): base LambdaLR on configured start_lr
1 parent 0a481de commit 66b088c

1 file changed

Lines changed: 18 additions & 5 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -939,12 +939,10 @@ def single_model_finetune(
939939
**extra,
940940
)
941941
self._load_optimizer_state(optimizer_state_dict)
942-
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
942+
self.scheduler = self._create_lr_scheduler(
943943
self.optimizer,
944-
lambda step: (
945-
self.lr_schedule.value(step + self.start_step) / initial_lr
946-
),
947-
last_epoch=self.start_step - 1,
944+
self.lr_schedule,
945+
self.start_step,
948946
)
949947

950948
if self.zero_stage > 0 and self.rank == 0:
@@ -975,6 +973,21 @@ def single_model_finetune(
975973
if self.rank == 0:
976974
self._log_parameter_count()
977975

976+
@staticmethod
977+
def _create_lr_scheduler(
978+
optimizer: torch.optim.Optimizer,
979+
lr_schedule: BaseLR,
980+
start_step: int,
981+
) -> torch.optim.lr_scheduler.LambdaLR:
982+
base_lr = float(lr_schedule.start_lr)
983+
for group in optimizer.param_groups:
984+
group["initial_lr"] = base_lr
985+
return torch.optim.lr_scheduler.LambdaLR(
986+
optimizer,
987+
lambda step: lr_schedule.value(step) / base_lr,
988+
last_epoch=start_step - 1,
989+
)
990+
978991
def _create_full_validator(
979992
self,
980993
*,

0 commit comments

Comments
 (0)