 # Standard
 from pathlib import Path
 import logging
-import os
-import sys
 
 # Third Party
 from datasets import load_from_disk
@@ -52,6 +50,7 @@
 from fms_mo.utils.dq_inf import (
     check_quantization_setting,
     convert_fp8_vllm_to_fms_mo,
+    load_inference_qconfig_file,
     save_vllm_fp8,
 )
 from fms_mo.utils.dq_utils import config_quantize_smooth_layers
@@ -134,18 +133,6 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         low_cpu_mem_usage=bool(model_args.device_map),
     )
 
-    inference_qconfig = None
-    if hasattr(model, "config"):
-        inference_qconfig = model.config.to_dict().get("quantization_config", None)
-
-    if inference_qconfig:
-        quant_setting = check_quantization_setting(inference_qconfig)
-        if quant_setting:
-            logger.info("Quantization config settings validated")
-            model = convert_fp8_vllm_to_fms_mo(model=model)
-        else:
-            sys.exit("Error: This quantization config is wrong/not supported")
-
     embedding_size = model.get_input_embeddings().weight.shape[0]
     if len(tokenizer) > embedding_size:
         model.resize_token_embeddings(len(tokenizer))
@@ -154,29 +141,17 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Model is at {model.device} after initialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
 
-    if not inference_qconfig:
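+    # Truthy when the checkpoint already carries a supported quantization_config (FP8 inference)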
+    quant_mode = check_quantization_setting(model)
+
+    if not quant_mode:
         logger.info("quantization mode activated, initializing the qcfg file")
         qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
     else:
         logger.info("inference mode activated")
-        if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"):
-            if fms_mo_args.override_fms_args:
-                logger.info(
-                    "qcfg file found and some parameters are being over-written"
-                )
-                qcfg = qconfig_init(
-                    recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args
-                )
-            else:
-                logger.info("qcfg file found, loading the qcfg file")
-                qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg")
-        else:
-            logger.info(
-                "qcfg file not found in {model_args.model_name_or_path},\
-                loading fms_mo_args and recipe"
-            )
-            qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
-            qcfg["fp8_inference"] = True
+        qcfg = load_inference_qconfig_file(model_args, fms_mo_args)
+
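+    # A vLLM FP8 checkpoint is converted into the fms_mo representation before use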
+    if quant_mode:
+        model = convert_fp8_vllm_to_fms_mo(model=model)
 
     model_size = model_size_Wb(model, unit="GB")
     gpu_mem_util_per = model_size / total_gpu_memory
@@ -201,7 +176,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
 
     qcfg["model"] = model_args.model_name_or_path
     # config layers to skip, smooth scale
-    if not inference_qconfig:
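+    # only needed when quantizing from scratch; skipped for pre-quantized (inference) checkpoints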
+    if not quant_mode:
         config_quantize_smooth_layers(qcfg)
 
     use_dynamo = True
@@ -234,7 +209,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )
 
     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    if not inference_qconfig and qcfg["smoothq"]:
+    if not quant_mode and qcfg["smoothq"]:
         scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
         if qcfg.get("act_scale_path", None):
             # user provided a scale file (or a dir)
@@ -272,7 +247,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )
     logger.info(f"Quantized model {model}")
     logger.info("==" * 20)
-    if not inference_qconfig:
+
+    if not quant_mode:
         if qcfg["smoothq"]:
             logger.info("Starting to apply smooth scale")
             dq_llm(model, act_scales, qcfg)