Skip to content

Commit 71c6aab

Browse files
committed
test: Added tests for recipe save feature in qconfig_save
Signed-off-by: Brandon Groth <brandon.m.groth@gmail.com>
1 parent b55196c commit 71c6aab

4 files changed

Lines changed: 153 additions & 6 deletions

File tree

tests/models/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,36 @@ def config_fp16(request):
798798
qconfig["nbits_w"] = 16
799799
return qconfig
800800

801+
save_list_params = [
802+
["qa_mode", "qw_mode", "nbits_a", "nbits_w", "qskip_layer_name"],
803+
]
804+
@pytest.fixture(scope="session", params=save_list_params)
805+
def save_list(request):
806+
"""
807+
Generate a save list for testing user-requested save config.
808+
809+
Args:
810+
request (list): List of variables to save in a quantized config.
811+
812+
Returns:
813+
list: List of variables to save in a quantized config.
814+
"""
815+
return request.param
816+
817+
bad_recipe_params = ["qat_int7", "pzq_int8"]
818+
819+
@pytest.fixture(scope="session", params=bad_recipe_params)
820+
def bad_recipe(request):
821+
"""
822+
Get a bad recipe json file name in fms_mo/recipes
823+
824+
Args:
825+
request (str): Bad recipe name in fms_mo/recipes
826+
827+
Returns:
828+
str: Bad recipe name
829+
"""
830+
return request.param
801831

802832
# Create QAT/PTQ int8 config fixture.
803833
config_params = ["qat_int8", "ptq_int8"]

tests/models/test_model_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,16 @@ def load_json(file_path: str = "qcfg.json"):
173173
assert json_file is not None, f"JSON at {file_path} was not found"
174174
return json_file
175175

176+
def save_json(data, file_path: str = "qcfg.json"):
177+
"""
178+
Save data object to json file
179+
180+
Args:
181+
data (_type_): _description_
182+
file_path (str, optional): _description_. Defaults to "qcfg.json".
183+
"""
184+
with open(file_path, "w", encoding="utf-8") as outfile:
185+
json.dump(data, outfile, indent=4)
176186

177187
def save_serialized_json(config: dict, file_path: str = "qcfg.json"):
178188
"""
@@ -189,5 +199,4 @@ def save_serialized_json(config: dict, file_path: str = "qcfg.json"):
189199
del config[key]
190200

191201
serialize_config(config) # Only remove stuff necessary to dump
192-
with open(file_path, "w", encoding="utf-8") as outfile:
193-
json.dump(config, outfile, indent=4)
202+
save_json(config, file_path)

tests/models/test_qmodelprep.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
# Local
2626
# fms_mo imports
27-
from fms_mo import qmodel_prep
27+
from fms_mo import qconfig_init, qmodel_prep
2828
from fms_mo.prep import has_quantized_module
2929
from fms_mo.utils.utils import patch_torch_bmm
3030
from tests.models.test_model_utils import count_qmodules, delete_config, qmodule_error
@@ -53,6 +53,18 @@ def test_model_quantized(
5353
with pytest.raises(RuntimeError):
5454
qmodel_prep(model_quantized, sample_input_fp32, config_fp32)
5555

56+
def test_bad_recipe(
57+
bad_recipe: str,
58+
):
59+
"""
60+
Test if giving a bad recipe .json file name results in a ValueError.
61+
62+
Args:
63+
bad_recipe (str): Bad .json file name
64+
"""
65+
with pytest.raises(ValueError):
66+
qconfig_init(recipe=bad_recipe)
67+
5668

5769
def test_double_qmodel_prep_assert(
5870
model_fp32: torch.nn.Module,

tests/models/test_saveconfig.py

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121

2222
# Local
2323
from fms_mo.utils.qconfig_utils import qconfig_load, qconfig_save
24-
from tests.models.test_model_utils import delete_config, load_json, save_serialized_json
24+
from tests.models.test_model_utils import (
25+
delete_config,
26+
load_json,
27+
save_json,
28+
save_serialized_json,
29+
)
2530

2631
#########
2732
# Tests #
@@ -45,7 +50,7 @@ def test_save_config_warn_bad_pair(
4550
# Add bad key,val pair and save ; should generate UserWarning(s) for removing bad pair
4651
config_fp32[key] = val
4752
with pytest.warns(UserWarning):
48-
qconfig_save(config_fp32)
53+
qconfig_save(config_fp32, minimal=False)
4954

5055
# Load saved config and assert the key is not saved
5156
loaded_config = load_json("qcfg.json") # load json as is - do not modify
@@ -71,14 +76,102 @@ def test_save_config_wanted_pairs(
7176
# Delete wanted pair from config and save ; should be reset to default
7277
if key in config_fp32:
7378
del config_fp32[key]
74-
qconfig_save(config_fp32)
79+
qconfig_save(config_fp32, minimal=False)
7580

7681
# Load saved config and check the wanted pair was reset to default
7782
loaded_config = load_json()
7883
assert loaded_config.get(key) == default_val
7984

8085
delete_config()
8186

87+
def test_save_config_with_qcfg_save(
88+
config_fp32: dict,
89+
save_list: list,
90+
):
91+
"""
92+
Test for checking that the "save_list" functionality works from within a quantized config
93+
94+
Args:
95+
config_fp32 (dict): Config for fp32 quantization
96+
save_list (list): List of variables to save in a quantized config.
97+
"""
98+
delete_config()
99+
config_fp32["save"] = save_list
100+
101+
qconfig_save(config_fp32, minimal=False)
102+
103+
loaded_config = load_json()
104+
105+
# Remove pkg_versions and date before processing
106+
del loaded_config["pkg_versions"]
107+
del loaded_config["date"]
108+
109+
assert len(loaded_config) == len(save_list)
110+
111+
# Now ensure that every value in save_list was properly saved
112+
for key in save_list:
113+
assert key in loaded_config
114+
assert loaded_config.get(key) == config_fp32.get(key)
115+
116+
delete_config()
117+
del config_fp32["save"]
118+
119+
def test_save_config_with_recipe_save(
120+
config_fp32: dict,
121+
save_list: list,
122+
):
123+
"""
124+
Test for checking that the "save_list" functionality works from a saved json file
125+
126+
Args:
127+
config_fp32 (dict): Config for fp32 quantization
128+
save_list (list): List of variables to save in a quantized config.
129+
"""
130+
# Delete both qcfg and the save.json before starting
131+
delete_config()
132+
delete_config("save.json")
133+
134+
# Save new "save.json"
135+
save_path = "save_list.json"
136+
save_json(save_list, file_path=save_path)
137+
138+
qconfig_save(config_fp32, recipe="save_list")
139+
140+
# Check that saved qcfg matches
141+
loaded_config = load_json()
142+
143+
# Remove pkg_versions and date before processing
144+
del loaded_config["pkg_versions"]
145+
del loaded_config["date"]
146+
147+
assert len(loaded_config) == len(save_list)
148+
149+
# Now ensure that every value in save_list was properly saved
150+
for key in save_list:
151+
assert key in loaded_config
152+
assert loaded_config.get(key) == config_fp32.get(key)
153+
154+
delete_config()
155+
delete_config("save.json")
156+
157+
def test_save_config_minimal(
158+
config_fp32: dict,
159+
):
160+
delete_config()
161+
162+
qconfig_save(config_fp32, minimal=True)
163+
164+
# Check that saved qcfg matches
165+
loaded_config = load_json()
166+
167+
# Remove pkg_versions and date before processing
168+
del loaded_config["pkg_versions"]
169+
del loaded_config["date"]
170+
171+
# No items should exist - default config should be completely removed
172+
assert len(loaded_config) == 0
173+
174+
delete_config()
82175

83176
def test_load_config_restored_pair(
84177
config_fp32: dict,
@@ -96,13 +189,16 @@ def test_load_config_restored_pair(
96189

97190
if key in config_fp32:
98191
del config_fp32[key]
192+
99193
save_serialized_json(
100194
config_fp32
101195
) # Save config as is, no other edits other than to serialize
102196

103197
loaded_config = qconfig_load("qcfg.json")
104198
assert loaded_config.get(key) == default_val
105199

200+
delete_config()
201+
106202

107203
def test_load_config_required_pair(
108204
config_fp32: dict,

0 commit comments

Comments
 (0)