Skip to content

Commit dec0429

Browse files
committed
revise err capture in loading stat_dict
1 parent bfcbd8a commit dec0429

1 file changed

Lines changed: 7 additions & 10 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -613,16 +613,13 @@ def single_model_finetune(
613613

614614
if init_frz_model is not None:
615615
frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
616-
try:
617-
self.model.load_state_dict(frz_model.state_dict())
618-
except RuntimeError as err_msg:
619-
if "Missing key(s) in state_dict" in str(
620-
err_msg
621-
) or "Unexpected key(s) in state_dict" in str(err_msg):
622-
self.model.load_state_dict(frz_model.state_dict(), strict=False)
623-
log.warning("Loaded with strict=False to ignore non-matching keys.")
624-
else:
625-
raise
616+
state = frz_model.state_dict()
617+
missing, unexpected = self.model.load_state_dict(state, strict=False)
618+
if missing or unexpected:
619+
log.warning(
620+
"Checkpoint loaded non-strictly. "
621+
f"Missing keys: {missing}, Unexpected keys: {unexpected}"
622+
)
626623

627624
# Get model prob for multi-task
628625
if self.multi_task:

0 commit comments

Comments
 (0)