1414"""Util functions for qconfig."""
1515
1616# Standard
17+ from copy import deepcopy
18+ from datetime import date
1719from pathlib import Path
1820from typing import Any
1921import json
2022import logging
2123import os
22- import warnings
23- import pkg_resources
2424import sys
25+ import warnings
2526
2627# Third Party
2728from torch import nn
29+ import pkg_resources
2830import torch
2931
3032# Local
3537
3638logger = logging .getLogger (__name__ )
3739
38- def get_pkg_info (other_pkgs :list = None ):
40+
41+ def get_pkg_info (other_pkgs : list = None ):
3942 """
4043 Get the package name and version of important packages currently in use.
4144
@@ -51,24 +54,25 @@ def get_pkg_info(other_pkgs:list = None):
5154 other_pkgs = []
5255
5356 # Get installed packages name:version
54- pkgs .update ( {
55- pkg : pkg_resources .get_distribution (pkg ).version for pkg in [
56- "fms-model-optimizer" ,
57- "torch" ,
58- "transformers" ,
59- "triton" ,
60- ] + other_pkgs
61- })
57+ pkgs .update (
58+ {
59+ pkg : pkg_resources .get_distribution (pkg ).version
60+ for pkg in [
61+ "fms-model-optimizer" ,
62+ "torch" ,
63+ "transformers" ,
64+ "triton" ,
65+ ]
66+ + other_pkgs
67+ }
68+ )
6269
6370 return pkgs
6471
6572
6673def config_defaults ():
6774 """Create defaults for qconfig"""
6875 cfg_defaults = {
69- # Environment
70- "pkg_versions" : get_pkg_info (),
71-
7276 # nbits vars
7377 "nbits_a" : 32 ,
7478 "nbits_w" : 32 ,
@@ -82,7 +86,6 @@ def config_defaults():
8286 "nbits_w_lstm" : None ,
8387 "nbits_i_lstm" : None ,
8488 "nbits_h_lstm" : None ,
85-
8689 # qmodes vars
8790 "qa_mode" : "pact+" ,
8891 "qw_mode" : "sawb+" ,
@@ -93,7 +96,6 @@ def config_defaults():
9396 "bmm2_qm1_mode" : "pact" ,
9497 "bmm2_qm2_mode" : "pact" ,
9598 "qa_mode_lstm" : "pact+" ,
96-
9799 # mode_calib vars
98100 "qa_mode_calib" : "percentile" ,
99101 "qw_mode_calib" : "percentile" ,
@@ -117,11 +119,10 @@ def config_defaults():
117119 "qspecial_layers" : {},
118120 "qsinglesided_name" : [],
119121 "clip_val_asst_percentile" : (0.1 , 99.9 ),
120- "params2optim" :
121- {
122- "W" : [[] for _ in range (torch .cuda .device_count ())],
123- "cvs" : [[] for _ in range (torch .cuda .device_count ())],
124- },
122+ "params2optim" : {
123+ "W" : [[] for _ in range (torch .cuda .device_count ())],
124+ "cvs" : [[] for _ in range (torch .cuda .device_count ())],
125+ },
125126 # PTQ vars
126127 "ptq_nbatch" : 100 ,
127128 "ptq_batchsize" : 12 ,
@@ -153,11 +154,57 @@ def config_defaults():
153154 "qvit" : False ,
154155 "numparamsfromloadertomodel" : 1 ,
155156 "gradclip" : 0.0 ,
157+ "smoothq" : False ,
156158 }
157159
158160 return cfg_defaults
159161
160162
163+ def find_recipe_json (recipe : str , subdir : str = None ):
164+ """
165+ Search for recipe .json file in fms_mo and return the path
166+
167+ Args:
168+ recipe (str): Recipe file name (can be the "name" or "prefix.json").
169+ subdir (str, optional): Alternative subdir path from pkg_root. Defaults to None.
170+
171+ Returns:
172+ Path: Path to recipe .json if found, else None
173+ """
174+ cwd = Path ().resolve ()
175+ pkg_root = Path (__file__ ).parent .parent .resolve ()
176+ file_in_cwd = cwd / recipe
177+ if subdir :
178+ file_in_recipes = pkg_root / subdir / "recipes" / recipe
179+ file_in_recipes2 = pkg_root / subdir / "recipes" / f"{ recipe } .json"
180+ else :
181+ file_in_recipes = pkg_root / "recipes" / recipe
182+ file_in_recipes2 = pkg_root / "recipes" / f"{ recipe } .json"
183+
184+ if not recipe .endswith (".json" ) and file_in_recipes2 .exists ():
185+ json_file = file_in_recipes2
186+ elif file_in_cwd .exists ():
187+ json_file = file_in_cwd
188+ elif file_in_recipes .exists ():
189+ json_file = file_in_recipes
190+ else :
191+ json_file = None
192+
193+ return json_file
194+
195+
196+ def get_recipe (recipe : str , subdir : str = None ):
197+ json_file = find_recipe_json (recipe , subdir )
198+
199+ temp_data = None
200+ if json_file :
201+ with open (json_file , "r" , encoding = "utf-8" ) as openfile :
202+ temp_data = json .load (openfile )
203+ logger .info (f"Loaded settings from { json_file } ." )
204+
205+ return temp_data
206+
207+
161208def qconfig_init (recipe : str = None , args : Any = None ):
162209 """Three possible ways to create qcfg:
163210 1. create a default qcfg
@@ -315,29 +362,13 @@ def qconfig_init(recipe: str = None, args: Any = None):
315362 # 2. load values from json, if specified and exists
316363 # this can be used to load a previously saved ckpt as well
317364 if recipe :
318- cwd = Path ().resolve ()
319- pkg_root = Path (__file__ ).parent .parent .resolve ()
320- file_in_cwd = cwd / recipe
321- file_in_recipes = pkg_root / "recipes" / recipe
322- file_in_recipes2 = pkg_root / "recipes" / f"{ recipe } .json"
323- temp_cfg = None
324-
325- if not recipe .endswith (".json" ) and file_in_recipes2 .exists ():
326- qcfg_json = file_in_recipes2
327- elif file_in_cwd .exists ():
328- qcfg_json = file_in_cwd
329- elif file_in_recipes .exists ():
330- qcfg_json = file_in_recipes
331- else :
332- qcfg_json = None
333-
334- if qcfg_json :
335- with open (qcfg_json , "r" , encoding = "utf-8" ) as openfile :
336- temp_cfg = json .load (openfile )
365+ # qcfg recipes should reside in fms_mo/recipes
366+ temp_cfg = get_recipe (recipe )
367+ if temp_cfg :
337368 qcfg .update (temp_cfg )
338- logger .info (
339- f"Loaded settings from { qcfg_json } and updated the default values."
340- )
369+ logger .info (f"Updated config w/ recipe values" )
370+ else :
371+ raise ValueError ( f"Config recipe { recipe } was not found." )
341372
342373 # 3. parse args, if provided
343374 if hasattr (args , "__dict__" ):
@@ -414,7 +445,7 @@ def serialize_config(config):
414445 return config , dump
415446
416447
417- def remove_unwanted_from_config (config , minimal :bool = True ):
448+ def remove_unwanted_from_config (config , minimal : bool = True ):
418449 """Remove deprecated items or things cannot be saved as text (json)"""
419450 unwanted_items = [
420451 "sweep_cv_percentile" ,
@@ -435,12 +466,10 @@ def remove_unwanted_from_config(config, minimal:bool=True):
435466 if minimal :
436467 default_config = config_defaults ()
437468 for key , val in config .items ():
469+ # If config has a default setting, add to unwanted items
438470 if key in default_config and default_config .get (key ) == val :
439471 unwanted_items .append (key )
440472
441- # Do not remove back pkg_versions
442- unwanted_items .remove ("pkg_versions" )
443-
444473 len_before = len (config )
445474 dump = {k : config .pop (k ) for k in unwanted_items if k in config }
446475 assert (
@@ -484,32 +513,67 @@ def add_required_defaults_to_config(config):
484513 config [key ] = default_val
485514
486515
487- def add_wanted_defaults_to_config (config ):
516+ def add_wanted_defaults_to_config (config , minimal : bool = True ):
488517 """Util function to add basic config defaults that are missing into a config
489518 if a wanted item is not in the config, add it w/ default value
490519 """
491- wanted_items = config_defaults ()
492- for wanted_name , wanted_default_val in wanted_items .items ():
493- if wanted_name not in config :
494- config [wanted_name ] = wanted_default_val
520+ if not minimal :
521+ wanted_items = config_defaults ()
522+ for wanted_name , wanted_default_val in wanted_items .items ():
523+ if wanted_name not in config :
524+ config [wanted_name ] = wanted_default_val
495525
496526
497- def qconfig_save (qcfg , minimal = False , fname = "qcfg.json" ):
527+ def qconfig_save (
528+ qcfg : dict , recipe : str = "save" , minimal : bool = True , fname = "qcfg.json"
529+ ):
498530 """
499531 Try to save qcfg into a JSON file (or use .pt format if something really can't be text-only).
500532 For example, qcfg['mapping'] has some classes as keys and values, json won't work. We will try
501533 to remove unserializable items first.
534+
535+ Args:
536+ qcfg (dict): Quantized config.
537+ recipe (str, optional): String name for a save recipe. Defaults to None.
538+ minimal (bool, optional): Save a minimal quantized config. Defaults to True.
539+ fname (str, optional): File name to save quantized config. Defaults to "qcfg.json".
502540 """
541+ recipe_items = []
542+
543+ # First check in qcfg for added save list
544+ if recipe in qcfg :
545+ recipe_items = qcfg [recipe ]
546+
547+ # Next, check in fms_mo/recipes and merge them into a unique set (in case they differ)
548+ recipe_items_json = get_recipe (recipe + ".json" )
549+ if recipe_items_json :
550+ recipe_items = list (set (recipe_items + recipe_items_json ))
551+
552+ # If we found recipe items to add, fetch them from qcfg
553+ if recipe_items :
554+ temp_qcfg = {}
555+ for key in recipe_items :
556+ if key in qcfg :
557+ temp_qcfg [key ] = qcfg [key ]
558+ else :
559+ raise ValueError (f"Desired save recipe { key = } not in qcfg!" )
503560
504- # Remove deprecated/unwanted key,vals in config
505- temp_qcfg , removed_items = remove_unwanted_from_config (qcfg , minimal )
561+ else :
562+ # We assume a full qcfg is being saved - trim it!
563+ temp_qcfg = deepcopy (qcfg )
506564
507- # Add back wanted defaults for any missing vars
508- if not minimal :
509- add_wanted_defaults_to_config (temp_qcfg )
565+ # Remove deprecated/unwanted key,vals in config
566+ temp_qcfg , _ = remove_unwanted_from_config (temp_qcfg , minimal )
510567
511- # Clean config of any unwanted key,vals not found in unwanted list
512- temp_qcfg , removed_items2 = serialize_config (temp_qcfg )
568+ # Add back wanted defaults for any missing vars
569+ add_wanted_defaults_to_config (temp_qcfg , minimal )
570+
571+ # Clean config of any unwanted key,vals not found in unwanted list
572+ temp_qcfg , _ = serialize_config (temp_qcfg )
573+
574+ # Add in date and system information for archival
575+ temp_qcfg ["date" ] = date .today ().strftime ("%Y-%B-%d" )
576+ temp_qcfg ["pkg_versions" ] = get_pkg_info ()
513577
514578 # Finally, check to ensure all values are valid before saving
515579 check_config (temp_qcfg )
@@ -521,10 +585,6 @@ def qconfig_save(qcfg, minimal=False, fname="qcfg.json"):
521585 with open (fname , "w" , encoding = "utf-8" ) as outfile :
522586 json .dump (temp_qcfg , outfile , indent = 4 )
523587
524- # restore original qcfg
525- qcfg .update (removed_items )
526- qcfg .update (removed_items2 )
527-
528588
529589def qconfig_load (fname = "qcfg.json" ):
530590 """Read config in json format, work together with qconfig_save"""
@@ -533,7 +593,7 @@ def qconfig_load(fname="qcfg.json"):
533593 config = json .load (openfile )
534594
535595 # Add back wanted defaults for any missing vars
536- add_wanted_defaults_to_config (config )
596+ add_wanted_defaults_to_config (config , minimal = False )
537597 add_required_defaults_to_config (config )
538598
539599 # Ensure config has correct values before continuing
@@ -850,8 +910,7 @@ def check_config(config, model_dtype=None):
850910 # clip_val_asst is the percentile to use for calibration. TODO: consider renaming
851911 clip_val_asst_percentile_default = default_config .get ("clip_val_asst_percentile" )
852912 clip_val_asst_percentile = config .get (
853- "clip_val_asst_percentile" ,
854- clip_val_asst_percentile_default
913+ "clip_val_asst_percentile" , clip_val_asst_percentile_default
855914 )
856915 if len (clip_val_asst_percentile ) != 2 :
857916 raise ValueError (
0 commit comments