feat(opt): validate loaded modelopt state files#1471
Conversation
Add validation to load_modelopt_state() to verify the loaded object is a dict with the expected schema (modelopt_state_dict list and modelopt_version str). Raises TypeError/ValueError with clear messages when the file is malformed, and detects full checkpoints passed by mistake, pointing users to mto.restore(). Closes #1041 Co-authored-by: Keval Morabia <kevalmorabia97@users.noreply.github.com> Signed-off-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
📝 WalkthroughWalkthroughAdds _validate_modelopt_state and calls it from load_modelopt_state immediately after safe_load(map_location="cpu"); invalid shapes now raise TypeError/ValueError. New tests cover valid load and three failure modes (non-dict, missing keys, full-checkpoint shape). ChangesSchema validation for modelopt state loading
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes 🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
|
/ok to test 163a682 |
|
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/opt/conversion.py`:
- Around line 545-555: Add per-item schema validation for
state["modelopt_state_dict"]: iterate the list and for each entry ensure it's a
2-tuple or 2-list, raise TypeError if not a tuple/list or if its length != 2;
ensure the first element (mode name) is a str and raise TypeError otherwise;
ensure the second element is a dict that contains keys "config" and "metadata"
(and optionally validate those are dicts), raising ValueError/TypeError with
clear messages when missing or wrong-typed. Place these checks immediately after
the existing top-level checks that validate state["modelopt_state_dict"] and
state["modelopt_version"] so malformed files fail fast during load/restore.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 29621597-0dab-405d-8d3b-921c026a392c
📒 Files selected for processing (2)
modelopt/torch/opt/conversion.pytests/unit/torch/opt/test_load_modelopt_state.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1471 +/- ##
==========================================
- Coverage 76.95% 76.92% -0.04%
==========================================
Files 478 478
Lines 51648 51661 +13
==========================================
- Hits 39744 39738 -6
- Misses 11904 11923 +19
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tests/unit/torch/opt/test_load_modelopt_state.py (1)
22-48: ⚡ Quick winAdd schema type-validation tests for required keys.
Current coverage checks key presence/shape but not wrong value types (
modelopt_state_dictnon-list,modelopt_versionnon-str), which are part of the new contract.✅ Suggested test additions
def test_load_modelopt_state_full_checkpoint(tmp_path): path = tmp_path / "ckpt.pt" torch.save({"modelopt_state": {}, "model_state_dict": {}}, path) with pytest.raises(ValueError, match="full checkpoint"): load_modelopt_state(path) + + +def test_load_modelopt_state_invalid_modelopt_state_dict_type(tmp_path): + path = tmp_path / "bad_state_dict_type.pt" + torch.save({"modelopt_state_dict": {}, "modelopt_version": "1.0.0"}, path) + with pytest.raises((TypeError, ValueError), match="modelopt_state_dict"): + load_modelopt_state(path) + + +def test_load_modelopt_state_invalid_modelopt_version_type(tmp_path): + path = tmp_path / "bad_version_type.pt" + torch.save({"modelopt_state_dict": [], "modelopt_version": 100}, path) + with pytest.raises((TypeError, ValueError), match="modelopt_version"): + load_modelopt_state(path)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/torch/opt/test_load_modelopt_state.py` around lines 22 - 48, Add unit tests to validate the types of required keys when calling load_modelopt_state: create cases where saved state dict has modelopt_state_dict with a non-list value and where modelopt_version has a non-str value, then assert load_modelopt_state() raises TypeError (or the project’s chosen exception) with a message indicating the incorrect type; add these tests alongside the existing ones (e.g., in tests/unit/torch/opt/test_load_modelopt_state.py) referencing load_modelopt_state so they catch schema-type violations for those keys.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@tests/unit/torch/opt/test_load_modelopt_state.py`:
- Around line 22-48: Add unit tests to validate the types of required keys when
calling load_modelopt_state: create cases where saved state dict has
modelopt_state_dict with a non-list value and where modelopt_version has a
non-str value, then assert load_modelopt_state() raises TypeError (or the
project’s chosen exception) with a message indicating the incorrect type; add
these tests alongside the existing ones (e.g., in
tests/unit/torch/opt/test_load_modelopt_state.py) referencing
load_modelopt_state so they catch schema-type violations for those keys.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 725382bd-9383-4f7a-a346-112a6505d448
📒 Files selected for processing (2)
modelopt/torch/opt/conversion.pytests/unit/torch/opt/test_load_modelopt_state.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/opt/conversion.py
|
/ok to test bc743fc |
Add validation to load_modelopt_state() to verify the loaded object is a dict with the expected schema (modelopt_state_dict list and modelopt_version str). Raises TypeError/ValueError with clear messages when the file is malformed, and detects full checkpoints passed by mistake, pointing users to mto.restore(). Closes #1041 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Added strict validation for model state files to surface format errors with clear messages. * Malformed or invalid state files now fail fast instead of being returned silently. * Improved detection to prevent accidental loading of full checkpoints when only state dicts are expected. * **Tests** * New unit tests covering validation and loading behavior for various malformed and valid state files. [](https://app.coderabbit.ai/change-stack/NVIDIA/Model-Optimizer/pull/1471) <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Keval Morabia <kevalmorabia97@users.noreply.github.com>
Add validation to load_modelopt_state() to verify the loaded object is a dict with the expected schema (modelopt_state_dict list and modelopt_version str). Raises TypeError/ValueError with clear messages when the file is malformed, and detects full checkpoints passed by mistake, pointing users to mto.restore().
Closes #1041
Summary by CodeRabbit
Bug Fixes
Tests