Skip to content

Commit 72fef38

Browse files
committed
Fix pt model loading with non-matching state dictionaries (init from tf2pt model)

- Add try-catch block for loading frozen models
- Use strict=False when state_dict keys don't match
- Log warnings for model state_dict mismatches
- Prevent crashes when loading models with different architectures
1 parent 0f53edc commit 72fef38

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

deepmd/pt/train/training.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,15 @@ def single_model_finetune(
613613

614614
if init_frz_model is not None:
615615
frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
616-
self.model.load_state_dict(frz_model.state_dict())
616+
try:
617+
self.model.load_state_dict(frz_model.state_dict())
618+
except RuntimeError as e:
619+
if "Missing key(s) in state_dict" in str(e):
620+
self.model.load_state_dict(frz_model.state_dict(), strict=False)
621+
log.warning("Use strict=False to ignore non-matching keys.")
622+
log.warning(f"Model state_dict mismatch detected: {e}")
623+
else:
624+
raise e
617625

618626
# Get model prob for multi-task
619627
if self.multi_task:

0 commit comments

Comments
 (0)