Skip to content

Commit 3e7d9e9

Browse files
committed
Update training.py
1 parent eff5320 commit 3e7d9e9

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,9 @@ def collect_single_finetune_params(
570570
state_dict["_extra_state"] = self.wrapper.state_dict()[
571571
"_extra_state"
572572
]
573+
old_model_params = self.wrapper.state_dict()["_extra_state"][
574+
"model_params"
575+
]
573576
try:
574577
self.wrapper.load_state_dict(state_dict)
575578
except RuntimeError as e:
@@ -584,9 +587,7 @@ def collect_single_finetune_params(
584587
rm_list.append(kk)
585588
for kk in rm_list:
586589
state_dict.pop(kk)
587-
state_dict["_extra_state"] = self.wrapper.state_dict()[
588-
"_extra_state"
589-
]
590+
state_dict["_extra_state"]["model_params"] = old_model_params
590591
out_shape_list = [
591592
"model.Default.atomic_model.out_bias",
592593
"model.Default.atomic_model.out_std",

0 commit comments

Comments
 (0)