1818from datetime import date
1919from importlib .metadata import version
2020from pathlib import Path
21- from typing import Any
21+ from typing import Any , Union
2222import json
2323import logging
2424import os
@@ -113,6 +113,7 @@ def config_defaults() -> dict:
113113 "qkvsync" : False ,
114114 "extend_act_range" : False ,
115115 "plotsvg" : False ,
116+ "qskip_large_mag_layers" : False ,
116117 # Iterable vars
117118 "qlayer_name_pattern" : [],
118119 "qskip_layer_name" : [],
@@ -141,21 +142,24 @@ def config_defaults() -> dict:
141142 "temp_disable_calib" : False ,
142143 "org_batch_size" : {},
143144 "ptqmod_to_be_optimized" : [],
145+ # SmoothQuant vars
146+ "smoothq" : False ,
147+ "smoothq_scale_layers" : [],
148+ "smoothq_act_scale_path" : None ,
144149 # Other vars
145150 "which2patch_contextmanager" : None ,
146151 "force_stop_if_qbmm_auto_check_failed" : False ,
147152 "world_size" : max (1 , torch .cuda .device_count ()),
148153 "global_rank" : 0 ,
149154 "batch_size" : 2 ,
155+ "keys_to_save" : [],
150156 # items could be obsoleted
151157 "output_attentions" : False ,
152158 "bias_corr" : False ,
153159 "qwav2vec" : False ,
154160 "qvit" : False ,
155161 "numparamsfromloadertomodel" : 1 ,
156162 "gradclip" : 0.0 ,
157- "smoothq" : False ,
158- "keys_to_save" : [],
159163 }
160164
161165 return cfg_defaults
@@ -200,7 +204,7 @@ def find_recipe_json(recipe: str, subdir: str = None) -> Path:
200204 return json_file
201205
202206
203- def get_recipe (recipe : str , subdir : str = None ) -> Any :
207+ def get_recipe (recipe : str , subdir : str = None ) -> Union [ list , dict ] :
204208 """
205209 Get a json recipe.
206210
@@ -218,6 +222,10 @@ def get_recipe(recipe: str, subdir: str = None) -> Any:
218222 temp_data = json .load (openfile )
219223 logger .info (f"Loaded settings from { json_file } ." )
220224
225+ # Any recipe should be a dict (qcfg) or list (keys_to_save)
226+ if not isinstance (temp_data , (dict , list )):
227+ raise ValueError (f"Loaded recipe { json_file } was not a dict or list" )
228+
221229 return temp_data
222230
223231
@@ -376,8 +384,14 @@ def qconfig_init(recipe: str = None, args: Any = None) -> dict:
376384 # this can be used to load a previously saved ckpt as well
377385 if recipe :
378386 # qcfg recipes should reside in fms_mo/recipes
379- temp_cfg = get_recipe (recipe )
387+ temp_cfg = qconfig_load (recipe )
388+
380389 if temp_cfg :
390+ if not isinstance (temp_cfg , dict ):
391+ raise ValueError (
392+ f"Quantized config recipe={ recipe } is not a dictionary"
393+ )
394+
381395 qcfg .update (temp_cfg )
382396 logger .info ("Updated config with recipe values" )
383397 else :
@@ -560,7 +574,12 @@ def qconfig_save(
560574
561575 # Next, check in fms_mo/recipes and merge them into a unique set (in case they differ)
562576 keys_to_save_json = get_recipe (recipe )
577+
563578 if keys_to_save_json :
579+ if not isinstance (keys_to_save_json , list ):
580+ raise ValueError (f"Save recipe={ recipe } is not a list!" )
581+
582+ # Merge keys_to_save lists
564583 keys_to_save = list (set (keys_to_save + keys_to_save_json ))
565584
566585 # If we found keys to save, fetch them from qcfg
@@ -602,9 +621,12 @@ def qconfig_save(
602621
603622def qconfig_load (fname : str = "qcfg.json" ) -> dict :
604623 """Read config in json format, work together with qconfig_save"""
605- if os .path .isfile (fname ):
606- with open (fname , "r" , encoding = "utf-8" ) as openfile :
607- config = json .load (openfile )
624+ config = get_recipe (fname )
625+
626+ if config :
627+ # Check that loaded file is a dict
628+ if not isinstance (config , dict ):
629+ raise ValueError (f"Quantized config={ fname } is not a dictionary" )
608630
609631 # Add back wanted defaults for any missing vars
610632 add_wanted_defaults_to_config (config , minimal = False )
@@ -854,6 +876,8 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
854876 "plotsvg" ,
855877 "ptq_freezecvs" ,
856878 "ptq_qdrop" ,
879+ "qskip_large_mag_layers" ,
880+ "smoothq" ,
857881 ]
858882 for boolean_var_str in boolean_vars_str :
859883 boolean_var = config .get (
@@ -910,6 +934,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
910934 "firstptqmodule" ,
911935 "params2optim" ,
912936 "clip_val_asst_percentile" ,
937+ "smoothq_scale_layers" ,
913938 ]
914939 for iterable_var_str in iterable_vars_str :
915940 iterable_var_default = default_config .get (iterable_var_str )
@@ -988,3 +1013,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
9881013 f"which2patch_contextmanager = { which2patch_contextmanager } is not one of "
9891014 f"the following: { which2patch_contextmanager_settings } "
9901015 )
1016+
1017+ smoothq_act_scale_path = config .get ("smoothq_act_scale_path" , None )
1018+ if smoothq_act_scale_path and not smoothq_act_scale_path .endswith (".pt" ):
1019+ raise ValueError (f"{ smoothq_act_scale_path = } is not a .pt checkpoint" )