Commit 8060865

fix(pt): use strict=False when loading .pt checkpoints for inference (#5353)

Authored by Han Wang (wanghan-iapcm)
## Summary

- Use `strict=False` in `load_state_dict` when loading `.pt` checkpoints in the `no_jit=True` inference path
- Log warnings for missing/unexpected keys instead of crashing

## Root cause

PR #5323 and #5057 added new buffers/parameters to `DescrptBlockSeAtten.__init__`:

- `type_embd_data` (buffer, for type embedding compression)
- `compress_info` (ParameterList, for geometric compression)
- `compress_data` (ParameterList, for geometric compression)

These are initialized to zero-size tensors (the uncompressed state). Old checkpoints saved before these PRs don't contain these keys, so `load_state_dict(strict=True)` raises `RuntimeError: Missing key(s)`.

## Fix

Use `strict=False` with warning logs, matching the pattern already used in `training.py:740`. The zero-initialized defaults correctly represent the uncompressed state.

## Test plan

- [x] Pre-existing tests pass (no behavioral change for new checkpoints)
- [x] Fixes loading of old DPA-2 checkpoints (`dpa-2.4-7M.pt`) with `no_jit=True`

Reported by @njzjz.

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Parent: cdd9b1a · Commit: 8060865

File tree

1 file changed (+12, −1 lines)

deepmd/pt/infer/deep_eval.py (12 additions, 1 deletion)

```diff
@@ -170,7 +170,18 @@ def __init__(
             if not self.input_param.get("hessian_mode") and not no_jit:
                 model = torch.jit.script(model)
             self.dp = ModelWrapper(model)
-            self.dp.load_state_dict(state_dict)
+            missing, unexpected = self.dp.load_state_dict(state_dict, strict=False)
+            if missing:
+                log.warning(
+                    "Checkpoint loaded with missing keys (likely from an older "
+                    "version): %s",
+                    missing,
+                )
+            if unexpected:
+                log.warning(
+                    "Checkpoint loaded with unexpected keys: %s",
+                    unexpected,
+                )
         elif str(self.model_path).endswith(".pth"):
             extra_files = {"data_modifier.pth": ""}
             model = torch.jit.load(
```
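The fix relies on the fact that `torch.nn.Module.load_state_dict(strict=False)` returns a named tuple of `missing_keys` (entries the model expects but the checkpoint lacks) and `unexpected_keys` (the reverse), instead of raising. The key reconciliation itself can be illustrated without PyTorch; this is a minimal pure-Python sketch in which plain dicts stand in for state dicts, and the key names are illustrative, not the exact keys in the real model:

```python
# Pure-Python sketch of the key reconciliation that
# torch.nn.Module.load_state_dict(strict=False) performs internally.
# Plain dicts stand in for the model's and checkpoint's state dicts;
# the key names below are illustrative, not the real model's keys.

def reconcile_keys(model_state, checkpoint_state):
    """Return (missing, unexpected) key lists, as strict=False would."""
    missing = sorted(k for k in model_state if k not in checkpoint_state)
    unexpected = sorted(k for k in checkpoint_state if k not in model_state)
    return missing, unexpected

# A hypothetical "new" model expects buffers added by PR #5323 / #5057.
model_state = {
    "se_atten.type_embd_data": ...,       # buffer added later
    "se_atten.compress_info.0": ...,      # ParameterList entry added later
    "se_atten.filter_layers.0.weight": ...,
}

# An "old" checkpoint predates those PRs and lacks the new keys.
old_checkpoint = {
    "se_atten.filter_layers.0.weight": ...,
}

missing, unexpected = reconcile_keys(model_state, old_checkpoint)
print(missing)     # keys the old checkpoint does not provide
print(unexpected)  # empty here: the checkpoint has no extra keys
```

With `strict=True` the missing keys would abort loading; with `strict=False` the new buffers keep their zero-size defaults, which is exactly the uncompressed state the old checkpoint implies, so only a warning is needed.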
