Skip to content

Commit eff5320

Browse files
committed
add init from direct model
1 parent fcd034a commit eff5320

1 file changed

Lines changed: 24 additions & 2 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,8 +570,30 @@ def collect_single_finetune_params(
570570
state_dict["_extra_state"] = self.wrapper.state_dict()[
571571
"_extra_state"
572572
]
573-
574-
self.wrapper.load_state_dict(state_dict)
573+
try:
574+
self.wrapper.load_state_dict(state_dict)
575+
except RuntimeError as e:
576+
# init from direct fitting
577+
rm_list = []
578+
for kk in state_dict:
579+
# delete direct heads
580+
if (
581+
"fitting_net.force_embed." in kk
582+
or "fitting_net.noise_embed" in kk
583+
):
584+
rm_list.append(kk)
585+
for kk in rm_list:
586+
state_dict.pop(kk)
587+
state_dict["_extra_state"] = self.wrapper.state_dict()[
588+
"_extra_state"
589+
]
590+
out_shape_list = [
591+
"model.Default.atomic_model.out_bias",
592+
"model.Default.atomic_model.out_std",
593+
]
594+
for kk in out_shape_list:
595+
state_dict[kk] = state_dict[kk][:1, :, :1]
596+
self.wrapper.load_state_dict(state_dict)
575597

576598
# change bias for fine-tuning
577599
if finetune_model is not None:

0 commit comments

Comments
 (0)