We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent eff5320 commit 3e7d9e9Copy full SHA for 3e7d9e9
1 file changed
deepmd/pt/train/training.py
@@ -570,6 +570,9 @@ def collect_single_finetune_params(
570
state_dict["_extra_state"] = self.wrapper.state_dict()[
571
"_extra_state"
572
]
573
+ old_model_params = self.wrapper.state_dict()["_extra_state"][
574
+ "model_params"
575
+ ]
576
try:
577
self.wrapper.load_state_dict(state_dict)
578
except RuntimeError as e:
@@ -584,9 +587,7 @@ def collect_single_finetune_params(
584
587
rm_list.append(kk)
585
588
for kk in rm_list:
586
589
state_dict.pop(kk)
- state_dict["_extra_state"] = self.wrapper.state_dict()[
- "_extra_state"
- ]
590
+ state_dict["_extra_state"]["model_params"] = old_model_params
591
out_shape_list = [
592
"model.Default.atomic_model.out_bias",
593
"model.Default.atomic_model.out_std",
0 commit comments