diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py index 432bd6abc7..bb3fe0ce4a 100644 --- a/modelopt/torch/opt/conversion.py +++ b/modelopt/torch/opt/conversion.py @@ -513,6 +513,47 @@ def save(model: nn.Module, f: str | os.PathLike | BinaryIO, **kwargs) -> None: torch.save(ckpt_dict, f, **kwargs) +def _validate_modelopt_state(state: Any) -> None: + """Validate that ``state`` has the expected modelopt state dict schema. + + Raises ``TypeError`` or ``ValueError`` with a descriptive message if the + object is not a valid modelopt state dictionary. + """ + if not isinstance(state, dict): + raise TypeError( + f"Invalid modelopt state file: expected a dict, got {type(state).__name__}." + ) + + # Detect a full checkpoint (from ``mto.save()``) passed by mistake. + if "modelopt_state" in state and "modelopt_state_dict" not in state: + raise ValueError( + "Invalid modelopt state file: the file appears to be a full " + "checkpoint saved via ``mto.save()`` (contains 'modelopt_state' " + "and 'model_state_dict' keys). Use ``mto.restore()`` to load a " + "full checkpoint, or pass only the inner modelopt state dict." + ) + + required_keys = ("modelopt_state_dict", "modelopt_version") + missing = [k for k in required_keys if k not in state] + if missing: + raise ValueError( + "Invalid modelopt state file: missing required key(s) " + f"{missing}. Expected keys: {list(required_keys)}." + ) + + if not isinstance(state["modelopt_state_dict"], list): + raise TypeError( + "Invalid modelopt state file: 'modelopt_state_dict' must be a " + f"list, got {type(state['modelopt_state_dict']).__name__}." + ) + + if not isinstance(state["modelopt_version"], str): + raise TypeError( + "Invalid modelopt state file: 'modelopt_version' must be a str, " + f"got {type(state['modelopt_version']).__name__}." + ) + + def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dict[str, Any]: """Load the modelopt state from a file. 
@@ -522,9 +563,15 @@ def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dic Returns: A modelopt state dictionary describing the modifications to the model. + + Raises: + TypeError: If the loaded object is not a dict or has fields of the wrong type. + ValueError: If the loaded dict is missing required modelopt state keys. """ kwargs.setdefault("map_location", "cpu") - return safe_load(modelopt_state_path, **kwargs) + state = safe_load(modelopt_state_path, **kwargs) + _validate_modelopt_state(state) + return state def restore_from_modelopt_state( diff --git a/tests/unit/torch/opt/test_load_modelopt_state.py b/tests/unit/torch/opt/test_load_modelopt_state.py new file mode 100644 index 0000000000..f10e74257a --- /dev/null +++ b/tests/unit/torch/opt/test_load_modelopt_state.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Unit tests for the schema validation performed by ``load_modelopt_state``.

import pytest
import torch

from modelopt.torch.opt.conversion import load_modelopt_state


def _write(obj, directory, name="bad.pt"):
    """Serialize ``obj`` with ``torch.save`` into ``directory`` and return the path."""
    path = directory / name
    torch.save(obj, path)
    return path


def test_load_modelopt_state_valid(tmp_path):
    """A well-formed state dict round-trips unchanged."""
    state = {"modelopt_state_dict": [], "modelopt_version": "1.0.0"}
    path = _write(state, tmp_path, "state.pt")
    assert load_modelopt_state(path) == state


def test_load_modelopt_state_invalid_type(tmp_path):
    """A non-dict payload is rejected with ``TypeError``."""
    path = _write([1, 2, 3], tmp_path)
    with pytest.raises(TypeError, match="expected a dict"):
        load_modelopt_state(path)


def test_load_modelopt_state_missing_keys(tmp_path):
    """A dict lacking the required keys is rejected with ``ValueError``."""
    path = _write({"foo": "bar"}, tmp_path)
    with pytest.raises(ValueError, match="missing required key"):
        load_modelopt_state(path)


def test_load_modelopt_state_full_checkpoint(tmp_path):
    """A full checkpoint passed by mistake gets a dedicated error."""
    path = _write({"modelopt_state": {}, "model_state_dict": {}}, tmp_path, "ckpt.pt")
    with pytest.raises(ValueError, match="full checkpoint"):
        load_modelopt_state(path)


@pytest.mark.parametrize(
    ("state", "field"),
    [
        ({"modelopt_state_dict": {}, "modelopt_version": "1.0.0"}, "modelopt_state_dict"),
        ({"modelopt_state_dict": [], "modelopt_version": 1}, "modelopt_version"),
    ],
)
def test_load_modelopt_state_wrong_field_type(tmp_path, state, field):
    """Required keys holding values of the wrong type are rejected with ``TypeError``."""
    path = _write(state, tmp_path)
    with pytest.raises(TypeError, match=f"'{field}' must be"):
        load_modelopt_state(path)