We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 773fb32 commit a18fd72Copy full SHA for a18fd72
1 file changed
deepmd/pt/train/training.py
@@ -541,6 +541,9 @@ def collect_single_finetune_params(
541
state_dict["_extra_state"] = self.wrapper.state_dict()[
542
"_extra_state"
543
]
544
+ old_model_params = self.wrapper.state_dict()["_extra_state"][
545
+ "model_params"
546
+ ]
547
try:
548
self.wrapper.load_state_dict(state_dict)
549
except RuntimeError as e:
@@ -555,9 +558,7 @@ def collect_single_finetune_params(
555
558
rm_list.append(kk)
556
559
for kk in rm_list:
557
560
state_dict.pop(kk)
- state_dict["_extra_state"] = self.wrapper.state_dict()[
- "_extra_state"
- ]
561
+ state_dict["_extra_state"]["model_params"] = old_model_params
562
out_shape_list = [
563
"model.Default.atomic_model.out_bias",
564
"model.Default.atomic_model.out_std",
0 commit comments