Skip to content

Commit 938ebe1

Browse files
committed
refactor opt initialization
1 parent 523abb0 commit 938ebe1

1 file changed

Lines changed: 34 additions & 59 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 34 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -784,71 +784,48 @@ def single_model_finetune(
784784
# TODO add optimizers for multitask
785785
# author: iProzd
786786
initial_lr = self.lr_schedule.value(self.start_step)
787-
if self.opt_type in ["Adam", "AdamW"]:
787+
if self.opt_type == "LKF":
788+
self.optimizer = LKFOptimizer(
789+
self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
790+
)
791+
else:
792+
# === Common path for gradient-based optimizers ===
788793
adam_betas = (
789794
float(self.opt_param["adam_beta1"]),
790795
float(self.opt_param["adam_beta2"]),
791796
)
792797
weight_decay = float(self.opt_param["weight_decay"])
793-
optimizer_class = (
794-
torch.optim.Adam if self.opt_type == "Adam" else torch.optim.AdamW
795-
)
798+
799+
if self.opt_type in ("Adam", "AdamW"):
800+
cls = torch.optim.Adam if self.opt_type == "Adam" else torch.optim.AdamW
801+
extra = {"betas": adam_betas, "fused": DEVICE.type != "cpu"}
802+
elif self.opt_type == "AdaMuon":
803+
cls = AdaMuonOptimizer
804+
extra = {
805+
"adam_betas": adam_betas,
806+
"momentum": float(self.opt_param["momentum"]),
807+
"lr_adjust": float(self.opt_param["lr_adjust"]),
808+
"lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
809+
}
810+
elif self.opt_type == "HybridMuon":
811+
cls = HybridMuonOptimizer
812+
extra = {
813+
"adam_betas": adam_betas,
814+
"momentum": float(self.opt_param["momentum"]),
815+
"lr_adjust": float(self.opt_param["lr_adjust"]),
816+
"lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
817+
"muon_2d_only": bool(self.opt_param["muon_2d_only"]),
818+
"min_2d_dim": int(self.opt_param["min_2d_dim"]),
819+
"flash_muon": bool(self.opt_param["flash_muon"]),
820+
}
821+
else:
822+
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
823+
796824
self.optimizer = self._create_optimizer(
797-
optimizer_class,
825+
cls,
798826
lr=initial_lr,
799-
betas=adam_betas,
800827
weight_decay=weight_decay,
801-
fused=DEVICE.type != "cpu",
802-
)
803-
self._load_optimizer_state(optimizer_state_dict)
804-
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
805-
self.optimizer,
806-
lambda step: (
807-
self.lr_schedule.value(step + self.start_step) / initial_lr
808-
),
809-
last_epoch=self.start_step - 1,
810-
)
811-
elif self.opt_type == "LKF":
812-
self.optimizer = LKFOptimizer(
813-
self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
814-
)
815-
elif self.opt_type == "AdaMuon":
816-
self.optimizer = self._create_optimizer(
817-
AdaMuonOptimizer,
818-
lr=initial_lr,
819-
momentum=float(self.opt_param["momentum"]),
820-
weight_decay=float(self.opt_param["weight_decay"]),
821-
adam_betas=(
822-
float(self.opt_param["adam_beta1"]),
823-
float(self.opt_param["adam_beta2"]),
824-
),
825-
lr_adjust=float(self.opt_param["lr_adjust"]),
826-
lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
827-
)
828-
if optimizer_state_dict is not None and self.restart_training:
829-
self.optimizer.load_state_dict(optimizer_state_dict)
830-
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
831-
self.optimizer,
832-
lambda step: (
833-
self.lr_schedule.value(step + self.start_step) / initial_lr
834-
),
835-
last_epoch=self.start_step - 1,
836-
)
837-
elif self.opt_type == "HybridMuon":
838-
self.optimizer = self._create_optimizer(
839-
HybridMuonOptimizer,
840-
lr=initial_lr,
841-
momentum=float(self.opt_param["momentum"]),
842-
weight_decay=float(self.opt_param["weight_decay"]),
843-
adam_betas=(
844-
float(self.opt_param["adam_beta1"]),
845-
float(self.opt_param["adam_beta2"]),
846-
),
847-
lr_adjust=float(self.opt_param["lr_adjust"]),
848-
lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
849-
muon_2d_only=bool(self.opt_param["muon_2d_only"]),
850-
min_2d_dim=int(self.opt_param["min_2d_dim"]),
851-
flash_muon=bool(self.opt_param["flash_muon"]),
828+
**extra,
852829
)
853830
self._load_optimizer_state(optimizer_state_dict)
854831
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
@@ -858,8 +835,6 @@ def single_model_finetune(
858835
),
859836
last_epoch=self.start_step - 1,
860837
)
861-
else:
862-
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
863838

864839
if self.zero_stage > 0 and self.rank == 0:
865840
if self.zero_stage == 1:

0 commit comments

Comments (0)