Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion modelopt/torch/opt/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,47 @@ def save(model: nn.Module, f: str | os.PathLike | BinaryIO, **kwargs) -> None:
torch.save(ckpt_dict, f, **kwargs)


def _validate_modelopt_state(state: Any) -> None:
"""Validate that ``state`` has the expected modelopt state dict schema.

Raises ``TypeError`` or ``ValueError`` with a descriptive message if the
object is not a valid modelopt state dictionary.
"""
if not isinstance(state, dict):
raise TypeError(
f"Invalid modelopt state file: expected a dict, got {type(state).__name__}."
)

# Detect a full checkpoint (from ``mto.save()``) passed by mistake.
if "modelopt_state" in state and "modelopt_state_dict" not in state:
raise ValueError(
"Invalid modelopt state file: the file appears to be a full "
"checkpoint saved via ``mto.save()`` (contains 'modelopt_state' "
"and 'model_state_dict' keys). Use ``mto.restore()`` to load a "
"full checkpoint, or pass only the inner modelopt state dict."
)

required_keys = ("modelopt_state_dict", "modelopt_version")
missing = [k for k in required_keys if k not in state]
if missing:
raise ValueError(
"Invalid modelopt state file: missing required key(s) "
f"{missing}. Expected keys: {list(required_keys)}."
)

if not isinstance(state["modelopt_state_dict"], list):
raise TypeError(
"Invalid modelopt state file: 'modelopt_state_dict' must be a "
f"list, got {type(state['modelopt_state_dict']).__name__}."
)

if not isinstance(state["modelopt_version"], str):
raise TypeError(
"Invalid modelopt state file: 'modelopt_version' must be a str, "
f"got {type(state['modelopt_version']).__name__}."
)
Comment thread
kevalmorabia97 marked this conversation as resolved.


def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dict[str, Any]:
"""Load the modelopt state from a file.

Expand All @@ -522,9 +563,15 @@ def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dic

Returns:
A modelopt state dictionary describing the modifications to the model.

Raises:
TypeError: If the loaded object is not a dict or has fields of the wrong type.
ValueError: If the loaded dict is missing required modelopt state keys.
"""
kwargs.setdefault("map_location", "cpu")
return safe_load(modelopt_state_path, **kwargs)
state = safe_load(modelopt_state_path, **kwargs)
_validate_modelopt_state(state)
return state


def restore_from_modelopt_state(
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/torch/opt/test_load_modelopt_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from modelopt.torch.opt.conversion import load_modelopt_state


def test_load_modelopt_state_valid(tmp_path):
path = tmp_path / "state.pt"
state = {"modelopt_state_dict": [], "modelopt_version": "1.0.0"}
torch.save(state, path)
loaded = load_modelopt_state(path)
assert loaded == state


def test_load_modelopt_state_invalid_type(tmp_path):
path = tmp_path / "bad.pt"
torch.save([1, 2, 3], path)
with pytest.raises(TypeError, match="expected a dict"):
load_modelopt_state(path)


def test_load_modelopt_state_missing_keys(tmp_path):
path = tmp_path / "bad.pt"
torch.save({"foo": "bar"}, path)
with pytest.raises(ValueError, match="missing required key"):
load_modelopt_state(path)


def test_load_modelopt_state_full_checkpoint(tmp_path):
path = tmp_path / "ckpt.pt"
torch.save({"modelopt_state": {}, "model_state_dict": {}}, path)
with pytest.raises(ValueError, match="full checkpoint"):
load_modelopt_state(path)
Loading