Skip to content

feat(opt): validate loaded modelopt state files#1471

Merged
kevalmorabia97 merged 2 commits into
mainfrom
claude/issue-1041-20260512-1627
May 12, 2026
Merged

feat(opt): validate loaded modelopt state files#1471
kevalmorabia97 merged 2 commits into
mainfrom
claude/issue-1041-20260512-1627

Conversation

@kevalmorabia97

@kevalmorabia97 kevalmorabia97 commented May 12, 2026

Copy link
Copy Markdown
Collaborator

Add validation to load_modelopt_state() to verify the loaded object is a dict with the expected schema (modelopt_state_dict list and modelopt_version str). Raises TypeError/ValueError with clear messages when the file is malformed, and detects full checkpoints passed by mistake, pointing users to mto.restore().

Closes #1041

Summary by CodeRabbit

  • Bug Fixes

    • Added strict validation for model state files to surface format errors with clear messages.
    • Malformed or invalid state files now fail fast instead of being returned silently.
    • Improved detection to prevent accidental loading of full checkpoints when only state dicts are expected.
  • Tests

    • New unit tests covering validation and loading behavior for various malformed and valid state files.

Review Change Stack

Add validation to load_modelopt_state() to verify the loaded object is a
dict with the expected schema (modelopt_state_dict list and
modelopt_version str). Raises TypeError/ValueError with clear messages
when the file is malformed, and detects full checkpoints passed by
mistake, pointing users to mto.restore().

Closes #1041

Co-authored-by: Keval Morabia <kevalmorabia97@users.noreply.github.com>
Signed-off-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 requested a review from a team as a code owner May 12, 2026 16:42
@kevalmorabia97 kevalmorabia97 requested a review from ChenhanYu May 12, 2026 16:42
@copy-pr-bot

copy-pr-bot Bot commented May 12, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented May 12, 2026

Copy link
Copy Markdown
Contributor
📝 Walkthrough

Walkthrough

Adds _validate_modelopt_state and calls it from load_modelopt_state immediately after safe_load(map_location="cpu"); invalid shapes now raise TypeError/ValueError. New tests cover valid load and three failure modes (non-dict, missing keys, full-checkpoint shape).

Changes

Schema validation for modelopt state loading

Layer / File(s) Summary
Validation function and load integration
modelopt/torch/opt/conversion.py
Introduces _validate_modelopt_state that enforces modelopt_state_dict (list) and modelopt_version (string), detects full-checkpoint shapes, and is invoked by load_modelopt_state after safe_load(..., map_location="cpu").
Validation and loading tests
tests/unit/torch/opt/test_load_modelopt_state.py
New pytest module verifying successful round-trip for a valid state dict and that loading non-dict, missing-key dicts, and full-checkpoint-shaped dicts raise TypeError/ValueError with descriptive messages.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main change: adding validation to loaded modelopt state files, which is the primary focus of the PR.
Linked Issues check ✅ Passed The PR implements all requirements from #1041: validates the loaded object is a dict, checks for expected schema (modelopt_state_dict list and modelopt_version str), raises clear errors with descriptive messages, and detects full checkpoint mistakenly provided.
Out of Scope Changes check ✅ Passed All changes are directly related to the validation objective in #1041; the conversion module receives validation logic and tests are added to verify the validator works correctly with no unrelated modifications.
Security Anti-Patterns ✅ Passed No security anti-patterns detected. PR uses safe_load() with PyTorch 2.8+ defaults (weights_only=True). No unsafe patterns, eval/exec, trust_remote_code, or new non-permissive dependencies found.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch claude/issue-1041-20260512-1627

Comment @coderabbitai help to get the list of available commands and usage tips.

@kevalmorabia97

Copy link
Copy Markdown
Collaborator Author

/ok to test 163a682

@github-actions

github-actions Bot commented May 12, 2026

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-05-12 19:27 UTC

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/opt/conversion.py`:
- Around line 545-555: Add per-item schema validation for
state["modelopt_state_dict"]: iterate the list and for each entry ensure it's a
2-tuple or 2-list, raise TypeError if not a tuple/list or if its length != 2;
ensure the first element (mode name) is a str and raise TypeError otherwise;
ensure the second element is a dict that contains keys "config" and "metadata"
(and optionally validate those are dicts), raising ValueError/TypeError with
clear messages when missing or wrong-typed. Place these checks immediately after
the existing top-level checks that validate state["modelopt_state_dict"] and
state["modelopt_version"] so malformed files fail fast during load/restore.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 29621597-0dab-405d-8d3b-921c026a392c

📥 Commits

Reviewing files that changed from the base of the PR and between 794a4e3 and 163a682.

📒 Files selected for processing (2)
  • modelopt/torch/opt/conversion.py
  • tests/unit/torch/opt/test_load_modelopt_state.py

Comment thread modelopt/torch/opt/conversion.py

@shengliangxu shengliangxu left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@codecov

codecov Bot commented May 12, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 85.71429% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.92%. Comparing base (d30ebbd) to head (bc743fc).
⚠️ Report is 5 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/opt/conversion.py 85.71% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1471      +/-   ##
==========================================
- Coverage   76.95%   76.92%   -0.04%     
==========================================
  Files         478      478              
  Lines       51648    51661      +13     
==========================================
- Hits        39744    39738       -6     
- Misses      11904    11923      +19     
Flag Coverage Δ
examples 41.63% <78.57%> (+0.91%) ⬆️
gpu 59.60% <7.14%> (-0.60%) ⬇️
regression 15.16% <78.57%> (+0.09%) ⬆️
unit 52.74% <85.71%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 enabled auto-merge (squash) May 12, 2026 17:01

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
tests/unit/torch/opt/test_load_modelopt_state.py (1)

22-48: ⚡ Quick win

Add schema type-validation tests for required keys.

Current coverage checks key presence/shape but not wrong value types (modelopt_state_dict non-list, modelopt_version non-str), which are part of the new contract.

✅ Suggested test additions
 def test_load_modelopt_state_full_checkpoint(tmp_path):
     path = tmp_path / "ckpt.pt"
     torch.save({"modelopt_state": {}, "model_state_dict": {}}, path)
     with pytest.raises(ValueError, match="full checkpoint"):
         load_modelopt_state(path)
+
+
+def test_load_modelopt_state_invalid_modelopt_state_dict_type(tmp_path):
+    path = tmp_path / "bad_state_dict_type.pt"
+    torch.save({"modelopt_state_dict": {}, "modelopt_version": "1.0.0"}, path)
+    with pytest.raises((TypeError, ValueError), match="modelopt_state_dict"):
+        load_modelopt_state(path)
+
+
+def test_load_modelopt_state_invalid_modelopt_version_type(tmp_path):
+    path = tmp_path / "bad_version_type.pt"
+    torch.save({"modelopt_state_dict": [], "modelopt_version": 100}, path)
+    with pytest.raises((TypeError, ValueError), match="modelopt_version"):
+        load_modelopt_state(path)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/opt/test_load_modelopt_state.py` around lines 22 - 48, Add
unit tests to validate the types of required keys when calling
load_modelopt_state: create cases where saved state dict has modelopt_state_dict
with a non-list value and where modelopt_version has a non-str value, then
assert load_modelopt_state() raises TypeError (or the project’s chosen
exception) with a message indicating the incorrect type; add these tests
alongside the existing ones (e.g., in
tests/unit/torch/opt/test_load_modelopt_state.py) referencing
load_modelopt_state so they catch schema-type violations for those keys.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tests/unit/torch/opt/test_load_modelopt_state.py`:
- Around line 22-48: Add unit tests to validate the types of required keys when
calling load_modelopt_state: create cases where saved state dict has
modelopt_state_dict with a non-list value and where modelopt_version has a
non-str value, then assert load_modelopt_state() raises TypeError (or the
project’s chosen exception) with a message indicating the incorrect type; add
these tests alongside the existing ones (e.g., in
tests/unit/torch/opt/test_load_modelopt_state.py) referencing
load_modelopt_state so they catch schema-type violations for those keys.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 725382bd-9383-4f7a-a346-112a6505d448

📥 Commits

Reviewing files that changed from the base of the PR and between 163a682 and bc743fc.

📒 Files selected for processing (2)
  • modelopt/torch/opt/conversion.py
  • tests/unit/torch/opt/test_load_modelopt_state.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/opt/conversion.py

@kevalmorabia97

Copy link
Copy Markdown
Collaborator Author

/ok to test bc743fc

@kevalmorabia97 kevalmorabia97 merged commit d738995 into main May 12, 2026
49 checks passed
@kevalmorabia97 kevalmorabia97 deleted the claude/issue-1041-20260512-1627 branch May 12, 2026 19:27
jenchen13 pushed a commit that referenced this pull request May 27, 2026
Add validation to load_modelopt_state() to verify the loaded object is a
dict with the expected schema (modelopt_state_dict list and
modelopt_version str). Raises TypeError/ValueError with clear messages
when the file is malformed, and detects full checkpoints passed by
mistake, pointing users to mto.restore().

Closes #1041


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Added strict validation for model state files to surface format errors
with clear messages.
* Malformed or invalid state files now fail fast instead of being
returned silently.
* Improved detection to prevent accidental loading of full checkpoints
when only state dicts are expected.

* **Tests**
* New unit tests covering validation and loading behavior for various
malformed and valid state files.

[![Review Change
Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/NVIDIA/Model-Optimizer/pull/1471)
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Keval Morabia <kevalmorabia97@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Feature: Add validation for loaded modelopt state files

2 participants