Skip to content

Commit c96476a

Browse files
committed
add unit test for pb2pth model
1 parent 335d53e commit c96476a

3 files changed

Lines changed: 19841 additions & 7 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -616,13 +616,14 @@ def single_model_finetune(
616616
frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
617617
try:
618618
self.model.load_state_dict(frz_model.state_dict())
619-
except RuntimeError as e:
620-
if "Missing key(s) in state_dict" in str(e):
619+
except RuntimeError as err_msg:
620+
if "Missing key(s) in state_dict" in str(
621+
err_msg
622+
) or "Unexpected key(s) in state_dict" in str(err_msg):
621623
self.model.load_state_dict(frz_model.state_dict(), strict=False)
622-
log.warning("Use strict=False to ignore non-matching keys.")
623-
log.warning(f"Model state_dict mismatch detected: {e}")
624+
log.warning("Loaded with strict=False to ignore non-matching keys.")
624625
else:
625-
raise e
626+
raise
626627

627628
# Get model prob for multi-task
628629
if self.multi_task:

0 commit comments

Comments
 (0)