Skip to content

Commit 87e6e8f

Browse files
committed
address review feedback on #1074
Signed-off-by: Pensieve Bot <pensieve-bot@nvidia.com>
1 parent 0b0d65f commit 87e6e8f

1 file changed

Lines changed: 140 additions & 0 deletions

File tree

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
import tempfile
18+
19+
import pytest
20+
import torch
21+
22+
import modelopt.torch.opt as mto
23+
24+
25+
class TestModeloptStateValidation:
26+
"""Test suite for modelopt state validation."""
27+
28+
def test_validate_modelopt_state_valid(self):
29+
"""Test validation of a valid modelopt state."""
30+
valid_state = {
31+
"modelopt_state_dict": [],
32+
"modelopt_version": "0.1.0",
33+
}
34+
# Should not raise any exception
35+
mto.ModeloptStateManager.validate_modelopt_state(valid_state)
36+
37+
def test_validate_modelopt_state_not_dict(self):
38+
"""Test validation fails when state is not a dictionary."""
39+
with pytest.raises(TypeError) as exc_info:
40+
mto.ModeloptStateManager.validate_modelopt_state([1, 2, 3])
41+
assert "Expected loaded modelopt state to be a dictionary" in str(exc_info.value)
42+
43+
def test_validate_modelopt_state_missing_keys(self):
44+
"""Test validation fails when required keys are missing."""
45+
with pytest.raises(ValueError) as exc_info:
46+
mto.ModeloptStateManager.validate_modelopt_state({"modelopt_state_dict": []})
47+
assert "missing required keys" in str(exc_info.value)
48+
assert "modelopt_version" in str(exc_info.value)
49+
50+
def test_validate_modelopt_state_invalid_state_dict_type(self):
51+
"""Test validation fails when modelopt_state_dict is not a list."""
52+
with pytest.raises(TypeError) as exc_info:
53+
mto.ModeloptStateManager.validate_modelopt_state({
54+
"modelopt_state_dict": "not a list",
55+
"modelopt_version": "0.1.0",
56+
})
57+
assert "modelopt_state_dict" in str(exc_info.value)
58+
59+
def test_validate_modelopt_state_invalid_entry_not_tuple(self):
60+
"""Test validation fails when state_dict entry is not a tuple."""
61+
with pytest.raises(ValueError) as exc_info:
62+
mto.ModeloptStateManager.validate_modelopt_state({
63+
"modelopt_state_dict": [{"mode": "quantize"}],
64+
"modelopt_version": "0.1.0",
65+
})
66+
assert "tuple of length 2" in str(exc_info.value)
67+
68+
def test_validate_modelopt_state_invalid_entry_wrong_length(self):
69+
"""Test validation fails when tuple has wrong length."""
70+
with pytest.raises(ValueError) as exc_info:
71+
mto.ModeloptStateManager.validate_modelopt_state({
72+
"modelopt_state_dict": [("quantize",)],
73+
"modelopt_version": "0.1.0",
74+
})
75+
assert "tuple of length 2" in str(exc_info.value)
76+
77+
def test_validate_modelopt_state_invalid_mode_name_type(self):
78+
"""Test validation fails when mode name is not a string."""
79+
with pytest.raises(TypeError) as exc_info:
80+
mto.ModeloptStateManager.validate_modelopt_state({
81+
"modelopt_state_dict": [(123, {})],
82+
"modelopt_version": "0.1.0",
83+
})
84+
assert "mode name" in str(exc_info.value)
85+
assert "string" in str(exc_info.value)
86+
87+
def test_validate_modelopt_state_invalid_mode_state_type(self):
88+
"""Test validation fails when mode state is not a dictionary."""
89+
with pytest.raises(TypeError) as exc_info:
90+
mto.ModeloptStateManager.validate_modelopt_state({
91+
"modelopt_state_dict": [("quantize", "not a dict")],
92+
"modelopt_version": "0.1.0",
93+
})
94+
assert "mode state" in str(exc_info.value)
95+
assert "dictionary" in str(exc_info.value)
96+
97+
def test_load_modelopt_state_valid_file(self):
98+
"""Test loading a valid modelopt state from file."""
99+
valid_state = {
100+
"modelopt_state_dict": [],
101+
"modelopt_version": "0.1.0",
102+
}
103+
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
104+
temp_file = f.name
105+
try:
106+
torch.save(valid_state, temp_file)
107+
loaded_state = mto.load_modelopt_state(temp_file)
108+
assert loaded_state == valid_state
109+
finally:
110+
os.remove(temp_file)
111+
112+
def test_load_modelopt_state_invalid_file(self):
113+
"""Test loading an invalid modelopt state from file."""
114+
invalid_state = [1, 2, 3]
115+
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
116+
temp_file = f.name
117+
try:
118+
torch.save(invalid_state, temp_file)
119+
with pytest.raises(TypeError) as exc_info:
120+
mto.load_modelopt_state(temp_file)
121+
assert "Expected loaded modelopt state to be a dictionary" in str(exc_info.value)
122+
finally:
123+
os.remove(temp_file)
124+
125+
def test_load_modelopt_state_with_valid_entries(self):
126+
"""Test loading modelopt state with valid mode entries."""
127+
valid_state = {
128+
"modelopt_state_dict": [
129+
("quantize", {"config": {}, "metadata": {}}),
130+
],
131+
"modelopt_version": "0.1.0",
132+
}
133+
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
134+
temp_file = f.name
135+
try:
136+
torch.save(valid_state, temp_file)
137+
loaded_state = mto.load_modelopt_state(temp_file)
138+
assert loaded_state == valid_state
139+
finally:
140+
os.remove(temp_file)

0 commit comments

Comments
 (0)