Skip to content

Commit 335d53e

Browse files
committed
Fix pt model loading with non-matching state dictionaries (init from
tf2pt model) - Add try-catch block for loading frozen models - Use strict=False when state_dict keys don't match - Log warnings for model state_dict mismatches - Prevent crashes when loading models with different architectures"
1 parent d774ede commit 335d53e

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

deepmd/pt/train/training.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,15 @@ def single_model_finetune(
614614

615615
if init_frz_model is not None:
616616
frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
617-
self.model.load_state_dict(frz_model.state_dict())
617+
try:
618+
self.model.load_state_dict(frz_model.state_dict())
619+
except RuntimeError as e:
620+
if "Missing key(s) in state_dict" in str(e):
621+
self.model.load_state_dict(frz_model.state_dict(), strict=False)
622+
log.warning("Use strict=False to ignore non-matching keys.")
623+
log.warning(f"Model state_dict mismatch detected: {e}")
624+
else:
625+
raise e
618626

619627
# Get model prob for multi-task
620628
if self.multi_task:

0 commit comments

Comments
 (0)