Skip to content

Commit b55196c

Browse files
committed
feat: Added recipe save functionality to qconfig_save
Signed-off-by: Brandon Groth <brandon.m.groth@gmail.com>
1 parent 18621d7 commit b55196c

1 file changed

Lines changed: 126 additions & 67 deletions

File tree

fms_mo/utils/qconfig_utils.py

Lines changed: 126 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,19 @@
1414
"""Util functions for qconfig."""
1515

1616
# Standard
17+
from copy import deepcopy
18+
from datetime import date
1719
from pathlib import Path
1820
from typing import Any
1921
import json
2022
import logging
2123
import os
22-
import warnings
23-
import pkg_resources
2424
import sys
25+
import warnings
2526

2627
# Third Party
2728
from torch import nn
29+
import pkg_resources
2830
import torch
2931

3032
# Local
@@ -35,7 +37,8 @@
3537

3638
logger = logging.getLogger(__name__)
3739

38-
def get_pkg_info(other_pkgs:list = None):
40+
41+
def get_pkg_info(other_pkgs: list = None):
3942
"""
4043
Get the package name and version of important packages currently in use.
4144
@@ -51,24 +54,25 @@ def get_pkg_info(other_pkgs:list = None):
5154
other_pkgs = []
5255

5356
# Get installed packages name:version
54-
pkgs.update( {
55-
pkg: pkg_resources.get_distribution(pkg).version for pkg in [
56-
"fms-model-optimizer",
57-
"torch",
58-
"transformers",
59-
"triton",
60-
] + other_pkgs
61-
})
57+
pkgs.update(
58+
{
59+
pkg: pkg_resources.get_distribution(pkg).version
60+
for pkg in [
61+
"fms-model-optimizer",
62+
"torch",
63+
"transformers",
64+
"triton",
65+
]
66+
+ other_pkgs
67+
}
68+
)
6269

6370
return pkgs
6471

6572

6673
def config_defaults():
6774
"""Create defaults for qconfig"""
6875
cfg_defaults = {
69-
# Environment
70-
"pkg_versions": get_pkg_info(),
71-
7276
# nbits vars
7377
"nbits_a": 32,
7478
"nbits_w": 32,
@@ -82,7 +86,6 @@ def config_defaults():
8286
"nbits_w_lstm": None,
8387
"nbits_i_lstm": None,
8488
"nbits_h_lstm": None,
85-
8689
# qmodes vars
8790
"qa_mode": "pact+",
8891
"qw_mode": "sawb+",
@@ -93,7 +96,6 @@ def config_defaults():
9396
"bmm2_qm1_mode": "pact",
9497
"bmm2_qm2_mode": "pact",
9598
"qa_mode_lstm": "pact+",
96-
9799
# mode_calib vars
98100
"qa_mode_calib": "percentile",
99101
"qw_mode_calib": "percentile",
@@ -117,11 +119,10 @@ def config_defaults():
117119
"qspecial_layers": {},
118120
"qsinglesided_name": [],
119121
"clip_val_asst_percentile": (0.1, 99.9),
120-
"params2optim":
121-
{
122-
"W": [[] for _ in range(torch.cuda.device_count())],
123-
"cvs": [[] for _ in range(torch.cuda.device_count())],
124-
},
122+
"params2optim": {
123+
"W": [[] for _ in range(torch.cuda.device_count())],
124+
"cvs": [[] for _ in range(torch.cuda.device_count())],
125+
},
125126
# PTQ vars
126127
"ptq_nbatch": 100,
127128
"ptq_batchsize": 12,
@@ -153,11 +154,57 @@ def config_defaults():
153154
"qvit": False,
154155
"numparamsfromloadertomodel": 1,
155156
"gradclip": 0.0,
157+
"smoothq": False,
156158
}
157159

158160
return cfg_defaults
159161

160162

163+
def find_recipe_json(recipe: str, subdir: str = None):
164+
"""
165+
Search for recipe .json file in fms_mo and return the path
166+
167+
Args:
168+
recipe (str): Recipe file name (can be the "name" or "prefix.json").
169+
subdir (str, optional): Alternative subdir path from pkg_root. Defaults to None.
170+
171+
Returns:
172+
Path: Path to recipe .json if found, else None
173+
"""
174+
cwd = Path().resolve()
175+
pkg_root = Path(__file__).parent.parent.resolve()
176+
file_in_cwd = cwd / recipe
177+
if subdir:
178+
file_in_recipes = pkg_root / subdir / "recipes" / recipe
179+
file_in_recipes2 = pkg_root / subdir / "recipes" / f"{recipe}.json"
180+
else:
181+
file_in_recipes = pkg_root / "recipes" / recipe
182+
file_in_recipes2 = pkg_root / "recipes" / f"{recipe}.json"
183+
184+
if not recipe.endswith(".json") and file_in_recipes2.exists():
185+
json_file = file_in_recipes2
186+
elif file_in_cwd.exists():
187+
json_file = file_in_cwd
188+
elif file_in_recipes.exists():
189+
json_file = file_in_recipes
190+
else:
191+
json_file = None
192+
193+
return json_file
194+
195+
196+
def get_recipe(recipe: str, subdir: str = None):
197+
json_file = find_recipe_json(recipe, subdir)
198+
199+
temp_data = None
200+
if json_file:
201+
with open(json_file, "r", encoding="utf-8") as openfile:
202+
temp_data = json.load(openfile)
203+
logger.info(f"Loaded settings from {json_file}.")
204+
205+
return temp_data
206+
207+
161208
def qconfig_init(recipe: str = None, args: Any = None):
162209
"""Three possible ways to create qcfg:
163210
1. create a default qcfg
@@ -315,29 +362,13 @@ def qconfig_init(recipe: str = None, args: Any = None):
315362
# 2. load values from json, if specified and exists
316363
# this can be used to load a previously saved ckpt as well
317364
if recipe:
318-
cwd = Path().resolve()
319-
pkg_root = Path(__file__).parent.parent.resolve()
320-
file_in_cwd = cwd / recipe
321-
file_in_recipes = pkg_root / "recipes" / recipe
322-
file_in_recipes2 = pkg_root / "recipes" / f"{recipe}.json"
323-
temp_cfg = None
324-
325-
if not recipe.endswith(".json") and file_in_recipes2.exists():
326-
qcfg_json = file_in_recipes2
327-
elif file_in_cwd.exists():
328-
qcfg_json = file_in_cwd
329-
elif file_in_recipes.exists():
330-
qcfg_json = file_in_recipes
331-
else:
332-
qcfg_json = None
333-
334-
if qcfg_json:
335-
with open(qcfg_json, "r", encoding="utf-8") as openfile:
336-
temp_cfg = json.load(openfile)
365+
# qcfg recipes should reside in fms_mo/recipes
366+
temp_cfg = get_recipe(recipe)
367+
if temp_cfg:
337368
qcfg.update(temp_cfg)
338-
logger.info(
339-
f"Loaded settings from {qcfg_json} and updated the default values."
340-
)
369+
logger.info(f"Updated config w/ recipe values")
370+
else:
371+
raise ValueError(f"Config recipe {recipe} was not found.")
341372

342373
# 3. parse args, if provided
343374
if hasattr(args, "__dict__"):
@@ -414,7 +445,7 @@ def serialize_config(config):
414445
return config, dump
415446

416447

417-
def remove_unwanted_from_config(config, minimal:bool=True):
448+
def remove_unwanted_from_config(config, minimal: bool = True):
418449
"""Remove deprecated items or things cannot be saved as text (json)"""
419450
unwanted_items = [
420451
"sweep_cv_percentile",
@@ -435,12 +466,10 @@ def remove_unwanted_from_config(config, minimal:bool=True):
435466
if minimal:
436467
default_config = config_defaults()
437468
for key, val in config.items():
469+
# If config has a default setting, add to unwanted items
438470
if key in default_config and default_config.get(key) == val:
439471
unwanted_items.append(key)
440472

441-
# Do not remove back pkg_versions
442-
unwanted_items.remove("pkg_versions")
443-
444473
len_before = len(config)
445474
dump = {k: config.pop(k) for k in unwanted_items if k in config}
446475
assert (
@@ -484,32 +513,67 @@ def add_required_defaults_to_config(config):
484513
config[key] = default_val
485514

486515

487-
def add_wanted_defaults_to_config(config):
516+
def add_wanted_defaults_to_config(config, minimal: bool = True):
488517
"""Util function to add basic config defaults that are missing into a config
489518
if a wanted item is not in the config, add it w/ default value
490519
"""
491-
wanted_items = config_defaults()
492-
for wanted_name, wanted_default_val in wanted_items.items():
493-
if wanted_name not in config:
494-
config[wanted_name] = wanted_default_val
520+
if not minimal:
521+
wanted_items = config_defaults()
522+
for wanted_name, wanted_default_val in wanted_items.items():
523+
if wanted_name not in config:
524+
config[wanted_name] = wanted_default_val
495525

496526

497-
def qconfig_save(qcfg, minimal=False, fname="qcfg.json"):
527+
def qconfig_save(
528+
qcfg: dict, recipe: str = "save", minimal: bool = True, fname="qcfg.json"
529+
):
498530
"""
499531
Try to save qcfg into a JSON file (or use .pt format if something really can't be text-only).
500532
For example, qcfg['mapping'] has some classes as keys and values, json won't work. We will try
501533
to remove unserializable items first.
534+
535+
Args:
536+
qcfg (dict): Quantized config.
537+
recipe (str, optional): String name for a save recipe. Defaults to None.
538+
minimal (bool, optional): Save a minimal quantized config. Defaults to True.
539+
fname (str, optional): File name to save quantized config. Defaults to "qcfg.json".
502540
"""
541+
recipe_items = []
542+
543+
# First check in qcfg for added save list
544+
if recipe in qcfg:
545+
recipe_items = qcfg[recipe]
546+
547+
# Next, check in fms_mo/recipes and merge them into a unique set (in case they differ)
548+
recipe_items_json = get_recipe(recipe + ".json")
549+
if recipe_items_json:
550+
recipe_items = list(set(recipe_items + recipe_items_json))
551+
552+
# If we found recipe items to add, fetch them from qcfg
553+
if recipe_items:
554+
temp_qcfg = {}
555+
for key in recipe_items:
556+
if key in qcfg:
557+
temp_qcfg[key] = qcfg[key]
558+
else:
559+
raise ValueError(f"Desired save recipe {key=} not in qcfg!")
503560

504-
# Remove deprecated/unwanted key,vals in config
505-
temp_qcfg, removed_items = remove_unwanted_from_config(qcfg, minimal)
561+
else:
562+
# We assume a full qcfg is being saved - trim it!
563+
temp_qcfg = deepcopy(qcfg)
506564

507-
# Add back wanted defaults for any missing vars
508-
if not minimal:
509-
add_wanted_defaults_to_config(temp_qcfg)
565+
# Remove deprecated/unwanted key,vals in config
566+
temp_qcfg, _ = remove_unwanted_from_config(temp_qcfg, minimal)
510567

511-
# Clean config of any unwanted key,vals not found in unwanted list
512-
temp_qcfg, removed_items2 = serialize_config(temp_qcfg)
568+
# Add back wanted defaults for any missing vars
569+
add_wanted_defaults_to_config(temp_qcfg, minimal)
570+
571+
# Clean config of any unwanted key,vals not found in unwanted list
572+
temp_qcfg, _ = serialize_config(temp_qcfg)
573+
574+
# Add in date and system information for archival
575+
temp_qcfg["date"] = date.today().strftime("%Y-%B-%d")
576+
temp_qcfg["pkg_versions"] = get_pkg_info()
513577

514578
# Finally, check to ensure all values are valid before saving
515579
check_config(temp_qcfg)
@@ -521,10 +585,6 @@ def qconfig_save(qcfg, minimal=False, fname="qcfg.json"):
521585
with open(fname, "w", encoding="utf-8") as outfile:
522586
json.dump(temp_qcfg, outfile, indent=4)
523587

524-
# restore original qcfg
525-
qcfg.update(removed_items)
526-
qcfg.update(removed_items2)
527-
528588

529589
def qconfig_load(fname="qcfg.json"):
530590
"""Read config in json format, work together with qconfig_save"""
@@ -533,7 +593,7 @@ def qconfig_load(fname="qcfg.json"):
533593
config = json.load(openfile)
534594

535595
# Add back wanted defaults for any missing vars
536-
add_wanted_defaults_to_config(config)
596+
add_wanted_defaults_to_config(config, minimal=False)
537597
add_required_defaults_to_config(config)
538598

539599
# Ensure config has correct values before continuing
@@ -850,8 +910,7 @@ def check_config(config, model_dtype=None):
850910
# clip_val_asst is the percentile to use for calibration. TODO: consider renaming
851911
clip_val_asst_percentile_default = default_config.get("clip_val_asst_percentile")
852912
clip_val_asst_percentile = config.get(
853-
"clip_val_asst_percentile",
854-
clip_val_asst_percentile_default
913+
"clip_val_asst_percentile", clip_val_asst_percentile_default
855914
)
856915
if len(clip_val_asst_percentile) != 2:
857916
raise ValueError(

0 commit comments

Comments
 (0)