From 163a682dec146f7069f8b75eb47376f8b0190c61 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 12 May 2026 16:31:43 +0000 Subject: [PATCH 1/2] feat(opt): validate loaded modelopt state files Add validation to load_modelopt_state() to verify the loaded object is a dict with the expected schema (modelopt_state_dict list and modelopt_version str). Raises TypeError/ValueError with clear messages when the file is malformed, and detects full checkpoints passed by mistake, pointing users to mto.restore(). Closes #1041 Co-authored-by: Keval Morabia Signed-off-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> --- modelopt/torch/opt/conversion.py | 50 +++++++++- .../torch/opt/test_load_modelopt_state.py | 95 +++++++++++++++++++ 2 files changed, 144 insertions(+), 1 deletion(-) create mode 100644 tests/unit/torch/opt/test_load_modelopt_state.py diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py index 432bd6abc7..55aeb43da1 100644 --- a/modelopt/torch/opt/conversion.py +++ b/modelopt/torch/opt/conversion.py @@ -513,6 +513,48 @@ def save(model: nn.Module, f: str | os.PathLike | BinaryIO, **kwargs) -> None: torch.save(ckpt_dict, f, **kwargs) +def _validate_modelopt_state(state: Any) -> None: + """Validate that ``state`` has the expected modelopt state dict schema. + + Raises ``TypeError`` or ``ValueError`` with a descriptive message if the + object is not a valid modelopt state dictionary. + """ + if not isinstance(state, dict): + raise TypeError( + "Invalid modelopt state file: expected a dict, got " + f"{type(state).__name__}." + ) + + # Detect a full checkpoint (from ``mto.save()``) passed by mistake. + if "modelopt_state" in state and "modelopt_state_dict" not in state: + raise ValueError( + "Invalid modelopt state file: the file appears to be a full " + "checkpoint saved via ``mto.save()`` (contains 'modelopt_state' " + "and 'model_state_dict' keys). Use ``mto.restore()`` to load a " + "full checkpoint, or pass only the inner modelopt state dict." + ) + + required_keys = ("modelopt_state_dict", "modelopt_version") + missing = [k for k in required_keys if k not in state] + if missing: + raise ValueError( + "Invalid modelopt state file: missing required key(s) " + f"{missing}. Expected keys: {list(required_keys)}." + ) + + if not isinstance(state["modelopt_state_dict"], list): + raise TypeError( + "Invalid modelopt state file: 'modelopt_state_dict' must be a " + f"list, got {type(state['modelopt_state_dict']).__name__}." + ) + + if not isinstance(state["modelopt_version"], str): + raise TypeError( + "Invalid modelopt state file: 'modelopt_version' must be a str, " + f"got {type(state['modelopt_version']).__name__}." + ) + + def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dict[str, Any]: """Load the modelopt state from a file. @@ -522,9 +564,15 @@ def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dic Returns: A modelopt state dictionary describing the modifications to the model. + + Raises: + TypeError: If the loaded object is not a dict or has fields of the wrong type. + ValueError: If the loaded dict is missing required modelopt state keys. """ kwargs.setdefault("map_location", "cpu") - return safe_load(modelopt_state_path, **kwargs) + state = safe_load(modelopt_state_path, **kwargs) + _validate_modelopt_state(state) + return state def restore_from_modelopt_state( diff --git a/tests/unit/torch/opt/test_load_modelopt_state.py b/tests/unit/torch/opt/test_load_modelopt_state.py new file mode 100644 index 0000000000..33bcf84261 --- /dev/null +++ b/tests/unit/torch/opt/test_load_modelopt_state.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from modelopt.torch.opt.conversion import _validate_modelopt_state, load_modelopt_state + + +def test_validate_accepts_empty_state_dict(): + _validate_modelopt_state({"modelopt_state_dict": [], "modelopt_version": "0.0.0"}) + + +def test_validate_accepts_populated_state_dict(): + _validate_modelopt_state( + { + "modelopt_state_dict": [("some_mode", {"config": {}, "metadata": {}})], + "modelopt_version": "1.2.3", + } + ) + + +@pytest.mark.parametrize("bad", [[], None, 42, "state", (1, 2)]) +def test_validate_rejects_non_dict(bad): + with pytest.raises(TypeError, match="expected a dict"): + _validate_modelopt_state(bad) + + +def test_validate_rejects_full_checkpoint(): + ckpt = {"modelopt_state": {}, "model_state_dict": {}} + with pytest.raises(ValueError, match="full checkpoint"): + _validate_modelopt_state(ckpt) + + +def test_validate_rejects_missing_version(): + with pytest.raises(ValueError, match="missing required key"): + _validate_modelopt_state({"modelopt_state_dict": []}) + + +def test_validate_rejects_missing_state_dict(): + with pytest.raises(ValueError, match="missing required key"): + _validate_modelopt_state({"modelopt_version": "1.0.0"}) + + +def test_validate_rejects_non_list_state_dict(): + with pytest.raises(TypeError, match="'modelopt_state_dict' must be a list"): + _validate_modelopt_state( + {"modelopt_state_dict": {"not": "a list"}, "modelopt_version": "1.0.0"} + ) + + +def test_validate_rejects_non_str_version(): + with pytest.raises(TypeError, match="'modelopt_version' must be a str"): + _validate_modelopt_state({"modelopt_state_dict": [], "modelopt_version": 1.0}) + + +def test_load_modelopt_state_valid(tmp_path): + path = tmp_path / "state.pt" + state = {"modelopt_state_dict": [], "modelopt_version": "1.0.0"} + torch.save(state, path) + loaded = load_modelopt_state(path) + assert loaded == state + + +def test_load_modelopt_state_invalid_type(tmp_path): + path = tmp_path / "bad.pt" + torch.save([1, 2, 3], path) + with pytest.raises(TypeError, match="expected a dict"): + load_modelopt_state(path) + + +def test_load_modelopt_state_missing_keys(tmp_path): + path = tmp_path / "bad.pt" + torch.save({"foo": "bar"}, path) + with pytest.raises(ValueError, match="missing required key"): + load_modelopt_state(path) + + +def test_load_modelopt_state_full_checkpoint(tmp_path): + path = tmp_path / "ckpt.pt" + torch.save({"modelopt_state": {}, "model_state_dict": {}}, path) + with pytest.raises(ValueError, match="full checkpoint"): + load_modelopt_state(path) From bc743fc00fa35a24e1b505800c864e664b4cb037 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Tue, 12 May 2026 10:00:04 -0700 Subject: [PATCH 2/2] cleanup Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- modelopt/torch/opt/conversion.py | 3 +- .../torch/opt/test_load_modelopt_state.py | 49 +------------------ 2 files changed, 2 insertions(+), 50 deletions(-) diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py index 55aeb43da1..bb3fe0ce4a 100644 --- a/modelopt/torch/opt/conversion.py +++ b/modelopt/torch/opt/conversion.py @@ -521,8 +521,7 @@ def _validate_modelopt_state(state: Any) -> None: """ if not isinstance(state, dict): raise TypeError( - "Invalid modelopt state file: expected a dict, got " - f"{type(state).__name__}." + f"Invalid modelopt state file: expected a dict, got {type(state).__name__}." ) # Detect a full checkpoint (from ``mto.save()``) passed by mistake. diff --git a/tests/unit/torch/opt/test_load_modelopt_state.py b/tests/unit/torch/opt/test_load_modelopt_state.py index 33bcf84261..f10e74257a 100644 --- a/tests/unit/torch/opt/test_load_modelopt_state.py +++ b/tests/unit/torch/opt/test_load_modelopt_state.py @@ -16,54 +16,7 @@ import pytest import torch -from modelopt.torch.opt.conversion import _validate_modelopt_state, load_modelopt_state - - -def test_validate_accepts_empty_state_dict(): - _validate_modelopt_state({"modelopt_state_dict": [], "modelopt_version": "0.0.0"}) - - -def test_validate_accepts_populated_state_dict(): - _validate_modelopt_state( - { - "modelopt_state_dict": [("some_mode", {"config": {}, "metadata": {}})], - "modelopt_version": "1.2.3", - } - ) - - -@pytest.mark.parametrize("bad", [[], None, 42, "state", (1, 2)]) -def test_validate_rejects_non_dict(bad): - with pytest.raises(TypeError, match="expected a dict"): - _validate_modelopt_state(bad) - - -def test_validate_rejects_full_checkpoint(): - ckpt = {"modelopt_state": {}, "model_state_dict": {}} - with pytest.raises(ValueError, match="full checkpoint"): - _validate_modelopt_state(ckpt) - - -def test_validate_rejects_missing_version(): - with pytest.raises(ValueError, match="missing required key"): - _validate_modelopt_state({"modelopt_state_dict": []}) - - -def test_validate_rejects_missing_state_dict(): - with pytest.raises(ValueError, match="missing required key"): - _validate_modelopt_state({"modelopt_version": "1.0.0"}) - - -def test_validate_rejects_non_list_state_dict(): - with pytest.raises(TypeError, match="'modelopt_state_dict' must be a list"): - _validate_modelopt_state( - {"modelopt_state_dict": {"not": "a list"}, "modelopt_version": "1.0.0"} - ) - - -def test_validate_rejects_non_str_version(): - with pytest.raises(TypeError, match="'modelopt_version' must be a str"): - _validate_modelopt_state({"modelopt_state_dict": [], "modelopt_version": 1.0}) +from modelopt.torch.opt.conversion import load_modelopt_state def test_load_modelopt_state_valid(tmp_path):