
Commit 782d17b

Add start_lr() method to BaseLR for better encapsulation
- Add a start_lr() getter method to the BaseLR class
- Replace direct attribute access lr_exp.start_lr with the method call lr_exp.start_lr() in the pt and pd training modules
- Align the API with the existing TF implementation

This addresses the code review comment asking for a method to access start_lr rather than reading the attribute directly.
1 parent 4efc0c9 commit 782d17b

3 files changed: 19 additions & 8 deletions
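The change is the encapsulation pattern described in the commit message. A minimal standalone sketch of that pattern, using made-up names (ToyExpLR and its attributes are illustrative only, not the actual deepmd-kit classes): the schedule keeps its starting rate on a private attribute and exposes it through a start_lr() getter, so training code calls the method instead of reading the field.

class ToyExpLR:
    """Toy exponential schedule illustrating the start_lr() getter pattern."""

    def __init__(self, start_lr: float, stop_lr: float, num_steps: int) -> None:
        # Keep the starting rate on a private attribute so the public
        # surface is the getter, not the raw field.
        self._start_lr = start_lr
        self._decay = (stop_lr / start_lr) ** (1.0 / max(num_steps, 1))

    def start_lr(self) -> float:
        """Return the starting learning rate."""
        return self._start_lr

    def value(self, step: int) -> float:
        """Learning rate at a given training step."""
        return self._start_lr * self._decay**step


lr_exp = ToyExpLR(start_lr=1e-3, stop_lr=1e-5, num_steps=10000)
# Call sites use the getter rather than touching the attribute directly:
base = lr_exp.start_lr()           # instead of lr_exp.start_lr
factor = lr_exp.value(100) / base  # multiplicative factor for a scheduler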


deepmd/dpmodel/utils/learning_rate.py

Lines changed: 11 additions & 0 deletions
@@ -123,6 +123,17 @@ def __init__(
         # Decay phase covers (num_steps - warmup_steps) steps
         self.decay_num_steps = num_steps - self.warmup_steps
 
+    def start_lr(self) -> float:
+        """
+        Get the starting learning rate.
+
+        Returns
+        -------
+        float
+            The starting learning rate.
+        """
+        return self.start_lr
+
     @abstractmethod
     def _decay_value(self, step: int | Array) -> Array:
         """

deepmd/pd/train/training.py

Lines changed: 2 additions & 2 deletions
@@ -580,9 +580,9 @@ def single_model_finetune(
         # author: iProzd
         if self.opt_type == "Adam":
             self.scheduler = paddle.optimizer.lr.LambdaDecay(
-                learning_rate=self.lr_exp.start_lr,
+                learning_rate=self.lr_exp.start_lr(),
                 lr_lambda=lambda step: self.lr_exp.value(step + self.start_step)
-                / self.lr_exp.start_lr,
+                / self.lr_exp.start_lr(),
             )
             self.optimizer = paddle.optimizer.Adam(
                 learning_rate=self.scheduler, parameters=self.wrapper.parameters()
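For context on why the call sites divide by start_lr(): Paddle's LambdaDecay and PyTorch's LambdaLR both scale the base learning rate they were constructed with by whatever the lambda returns, so returning value(step) / start_lr() makes the effective rate work out to lr_exp.value(step). A small arithmetic sketch with toy numbers (standalone Python, standing in for the real lr_exp and scheduler):

# Toy numbers only; stands in for lr_exp.start_lr() and lr_exp.value(step).
start_lr = 1e-3

def value(step: int) -> float:
    return start_lr * 0.95 ** (step // 100)  # some decayed rate

base_lr = start_lr                  # what the scheduler is constructed with
step = 500
factor = value(step) / start_lr     # what the lambda returns
effective_lr = base_lr * factor     # what LambdaDecay / LambdaLR apply
assert abs(effective_lr - value(step)) < 1e-12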

deepmd/pt/train/training.py

Lines changed: 6 additions & 6 deletions
@@ -688,13 +688,13 @@ def single_model_finetune(
             if self.opt_type == "Adam":
                 self.optimizer = torch.optim.Adam(
                     self.wrapper.parameters(),
-                    lr=self.lr_exp.start_lr,
+                    lr=self.lr_exp.start_lr(),
                     fused=False if DEVICE.type == "cpu" else True,
                 )
             else:
                 self.optimizer = torch.optim.AdamW(
                     self.wrapper.parameters(),
-                    lr=self.lr_exp.start_lr,
+                    lr=self.lr_exp.start_lr(),
                     weight_decay=float(self.opt_param["weight_decay"]),
                     fused=False if DEVICE.type == "cpu" else True,
                 )
@@ -703,7 +703,7 @@ def single_model_finetune(
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: self.lr_exp.value(step + self.start_step)
-                / self.lr_exp.start_lr,
+                / self.lr_exp.start_lr(),
             )
         elif self.opt_type == "LKF":
             self.optimizer = LKFOptimizer(
@@ -712,7 +712,7 @@ def single_model_finetune(
         elif self.opt_type == "AdaMuon":
             self.optimizer = AdaMuonOptimizer(
                 self.wrapper.parameters(),
-                lr=self.lr_exp.start_lr,
+                lr=self.lr_exp.start_lr(),
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
                 adam_betas=(
@@ -725,7 +725,7 @@ def single_model_finetune(
         elif self.opt_type == "HybridMuon":
             self.optimizer = HybridMuonOptimizer(
                 self.wrapper.parameters(),
-                lr=self.lr_exp.start_lr,
+                lr=self.lr_exp.start_lr(),
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
                 adam_betas=(
@@ -742,7 +742,7 @@ def single_model_finetune(
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
                 lambda step: self.lr_exp.value(step + self.start_step)
-                / self.lr_exp.start_lr,
+                / self.lr_exp.start_lr(),
             )
         else:
             raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
