Skip to content

Commit 56de6eb

Browse files
committed
feat(pt): add force load
1 parent 1079027 commit 56de6eb

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

deepmd/pt/train/training.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,14 @@ def get_lr(lr_params):
420420
if missing_keys:
421421
target_state_dict = self.wrapper.state_dict()
422422
slim_keys = []
423+
out_shape_list = [
424+
"model.Default.atomic_model.out_bias",
425+
"model.Default.atomic_model.out_std",
426+
]
427+
for kk in out_shape_list:
428+
old_stat = state_dict[kk].clone().detach()
429+
state_dict[kk] = target_state_dict[kk].clone().detach()
430+
state_dict[kk][:1, :, :1] = old_stat
423431
for item in missing_keys:
424432
state_dict[item] = target_state_dict[item].clone().detach()
425433
new_key = True
@@ -428,7 +436,7 @@ def get_lr(lr_params):
428436
new_key = False
429437
break
430438
if new_key:
431-
tmp_keys = ".".join(item.split(".")[:3])
439+
tmp_keys = ".".join(item.split(".")[:-2])
432440
slim_keys.append(tmp_keys)
433441
slim_keys = [i + ".*" for i in slim_keys]
434442
log.warning(

0 commit comments

Comments
 (0)