fix(pt): use strict=False when loading .pt checkpoints for inference#5353
Conversation
Old checkpoints (pre-deepmodeling#5323/deepmodeling#5057) lack compression buffers (type_embd_data, compress_info, compress_data) that are now registered in DescrptBlockSeAtten.__init__. strict=True rejects these checkpoints entirely. Use strict=False with warnings, matching the pattern already used in training.py.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d5ec8938b7
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
📝 WalkthroughWalkthroughWhen loading PyTorch checkpoints in the inference module, Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5353 +/- ##
==========================================
- Coverage 82.26% 82.26% -0.01%
==========================================
Files 799 799
Lines 82563 82567 +4
Branches 4066 4066
==========================================
+ Hits 67924 67926 +2
- Misses 13424 13426 +2
Partials 1215 1215 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Summary
strict=Falseinload_state_dictwhen loading.ptcheckpoints in theno_jit=Trueinference pathRoot cause
PR #5323 and #5057 added new buffers/parameters to
DescrptBlockSeAtten.__init__:type_embd_data(buffer, for type embedding compression)compress_info(ParameterList, for geometric compression)compress_data(ParameterList, for geometric compression)These are initialized to zero-size tensors (uncompressed state). Old checkpoints saved before these PRs don't contain these keys, so
load_state_dict(strict=True)raisesRuntimeError: Missing key(s).Fix
Use
strict=Falsewith warning logs, matching the pattern already used intraining.py:740. The zero-initialized defaults correctly represent uncompressed state.Test plan
dpa-2.4-7M.pt) withno_jit=TrueReported by @njzjz .
Summary by CodeRabbit