Commit 8060865
fix(pt): use strict=False when loading .pt checkpoints for inference (#5353)
## Summary
- Use `strict=False` in `load_state_dict` when loading `.pt` checkpoints
in the `no_jit=True` inference path
- Log warnings for missing/unexpected keys instead of crashing
## Root cause
PR #5323 and #5057 added new buffers/parameters to
`DescrptBlockSeAtten.__init__`:
- `type_embd_data` (buffer, for type embedding compression)
- `compress_info` (ParameterList, for geometric compression)
- `compress_data` (ParameterList, for geometric compression)
These are initialized to zero-size tensors (uncompressed state). Old
checkpoints saved before these PRs don't contain these keys, so
`load_state_dict(strict=True)` raises `RuntimeError: Missing key(s)`.
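The failure mode above can be reproduced with a minimal sketch (toy module names; the real class is `DescrptBlockSeAtten`): a subclass registers a new zero-size buffer, so a state dict saved from the old class is missing that key and the default `strict=True` load raises.

```python
import torch

class OldBlock(torch.nn.Module):
    """Stand-in for the descriptor block before the compression PRs."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

class NewBlock(OldBlock):
    """Stand-in for the block after the PRs: adds a new buffer."""
    def __init__(self):
        super().__init__()
        # Zero-size buffer, analogous to `type_embd_data`;
        # zero size represents the uncompressed state.
        self.register_buffer("type_embd_data", torch.zeros(0))

old_state = OldBlock().state_dict()  # "old checkpoint" lacks the new key
new_model = NewBlock()

try:
    new_model.load_state_dict(old_state)  # strict=True is the default
except RuntimeError as e:
    print("strict load failed:", e)  # Missing key(s): "type_embd_data"
```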
## Fix
Use `strict=False` with warning logs, matching the pattern already used
in `training.py:740`. The zero-initialized defaults correctly represent
uncompressed state.
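A minimal sketch of the pattern (the helper name and checkpoint layout are assumptions, not the repository's actual code): `load_state_dict` returns a named tuple of `missing_keys` and `unexpected_keys`, which get logged instead of raising.

```python
import logging

import torch

log = logging.getLogger(__name__)

def load_checkpoint_nonstrict(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
    """Hypothetical helper mirroring the fix: tolerate key mismatches.

    Missing keys (e.g. compression buffers absent from old checkpoints)
    keep their zero-initialized defaults, which denote the uncompressed state.
    """
    state_dict = torch.load(ckpt_path, map_location="cpu")
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        log.warning("Missing keys in checkpoint: %s", result.missing_keys)
    if result.unexpected_keys:
        log.warning("Unexpected keys in checkpoint: %s", result.unexpected_keys)
    return model
```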
## Test plan
- [x] Pre-existing tests pass (no behavioral change for new checkpoints)
- [x] Fixes loading of old DPA-2 checkpoints (`dpa-2.4-7M.pt`) with
`no_jit=True`
Reported by @njzjz.
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>

1 parent cdd9b1a · commit 8060865
1 file changed (+12, −1 lines)