Skip to content

Commit a18fd72

Browse files
committed
Update training.py
1 parent 773fb32 commit a18fd72

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,9 @@ def collect_single_finetune_params(
541541
state_dict["_extra_state"] = self.wrapper.state_dict()[
542542
"_extra_state"
543543
]
544+
old_model_params = self.wrapper.state_dict()["_extra_state"][
545+
"model_params"
546+
]
544547
try:
545548
self.wrapper.load_state_dict(state_dict)
546549
except RuntimeError as e:
@@ -555,9 +558,7 @@ def collect_single_finetune_params(
555558
rm_list.append(kk)
556559
for kk in rm_list:
557560
state_dict.pop(kk)
558-
state_dict["_extra_state"] = self.wrapper.state_dict()[
559-
"_extra_state"
560-
]
561+
state_dict["_extra_state"]["model_params"] = old_model_params
561562
out_shape_list = [
562563
"model.Default.atomic_model.out_bias",
563564
"model.Default.atomic_model.out_std",

0 commit comments

Comments
 (0)