Skip to content

Commit 773fb32

Browse files
committed
add init from direct model
1 parent 7f547b8 commit 773fb32

1 file changed

Lines changed: 24 additions & 2 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,30 @@ def collect_single_finetune_params(
541541
state_dict["_extra_state"] = self.wrapper.state_dict()[
542542
"_extra_state"
543543
]
544-
545-
self.wrapper.load_state_dict(state_dict)
544+
try:
545+
self.wrapper.load_state_dict(state_dict)
546+
except RuntimeError as e:
547+
# init from direct fitting
548+
rm_list = []
549+
for kk in state_dict:
550+
# delete direct heads
551+
if (
552+
"fitting_net.force_embed." in kk
553+
or "fitting_net.noise_embed" in kk
554+
):
555+
rm_list.append(kk)
556+
for kk in rm_list:
557+
state_dict.pop(kk)
558+
state_dict["_extra_state"] = self.wrapper.state_dict()[
559+
"_extra_state"
560+
]
561+
out_shape_list = [
562+
"model.Default.atomic_model.out_bias",
563+
"model.Default.atomic_model.out_std",
564+
]
565+
for kk in out_shape_list:
566+
state_dict[kk] = state_dict[kk][:1, :, :1]
567+
self.wrapper.load_state_dict(state_dict)
546568

547569
# change bias for fine-tuning
548570
if finetune_model is not None:

0 commit comments

Comments
 (0)