2626
2727
2828def main (config ):
29+ # 从注册表拿模型并实例化
30+ # 动态分配模型
2931 model = MODEL_REGISTRY [config .model .type ](config )
3032
33+ # 打印模型和tokenizer
3134 logger .info (f'model: { model } ' )
3235 logger .info (f'tokenizer: { model .get_tokenizer ()} ' )
3336
37+ # 获得需要的评测种类
3438 eval_list = get_eval_list (model , config )
39+ # 真正执行评测
3540 eval_model (model , None , eval_list , eval_pos = 'pretrain' )
3641
3742 blockwise_opts = []
43+ # 取出处理模态
3844 modalities , modality_configs = get_modality (config )
3945
4046 for modality , modality_config in zip (modalities , modality_configs ):
4147 model .set_modality (modality )
4248 if not config .get ('calib' , False ):
49+ # 不需要校准数据 直接构造算法对象
4350 blockwise_opt = ALGO_REGISTRY [modality_config .method ](
4451 model ,
4552 modality_config ,
@@ -51,30 +58,54 @@ def main(config):
5158 blockwise_opts .append (blockwise_opt )
5259 dist .barrier ()
5360 else :
61+ # 需要校准数据
5462 dataset = BaseDataset (
5563 model .get_tokenizer (), config .calib , model .batch_process
5664 )
5765 calib_data , padding_mask = dataset .get_calib_dataset ()
66+ # 收集第一层block输入 为后续blockwise算法需要的输入缓存下来
5867 model .collect_first_block_input (calib_data , padding_mask )
5968 del calib_data
6069 gc .collect ()
6170 torch .cuda .empty_cache ()
71+ # 构造算法对象
6272 blockwise_opt = ALGO_REGISTRY [modality_config .method ](
6373 model ,
6474 modality_config ,
6575 model .get_first_block_input (),
6676 model .get_padding_mask (),
6777 config ,
6878 )
79+ # Run the optimization block by block
6980 blockwise_opt .run_block_loop ()
7081 blockwise_opts .append (blockwise_opt )
7182 dist .barrier ()
7283
84+ # Evaluate the transformed floating-point model
7385 eval_model (model , blockwise_opts , eval_list , eval_pos = 'transformed' )
86+ # 只有rank 0继续做保存和导出
7487 if int (os .environ ['RANK' ]) == 0 :
88+ if 'save' in config and config .save .get ('save_calib_json' , False ):
89+ # 收集各个模态/量化器导出的校准结果。
90+ calib_json_list = [
91+ blockwise_opt .collect_calib_json ()
92+ for blockwise_opt in blockwise_opts
93+ if hasattr (blockwise_opt , 'collect_calib_json' )
94+ ]
95+ # 单模态时保持扁平结构,兼容 LightLLM 的校准文件格式。
96+ calib_json_payload = (
97+ calib_json_list [0 ] if len (calib_json_list ) == 1 else calib_json_list
98+ )
99+ # 将最终的校准 JSON 写入配置指定的输出路径。
100+ with open (save_calib_json_path , 'w' ) as file :
101+ json .dump (calib_json_payload , file , ensure_ascii = False , indent = 4 )
102+ logger .info (f'save calib json done -- { save_calib_json_path } ' )
103+
104+ # 保存变换后的浮点模型
75105 if 'save' in config and config .save .get ('save_trans' , False ):
76106 blockwise_opt .save_model (save_trans_path )
77107
108+ # 保存TensorRT-LLM格式并构建engine
78109 if 'save' in config and config .save .get ('save_trtllm' , False ):
79110 blockwise_opt .save_model (save_trtllm_trans_path )
80111 from llmc .utils .export_trtllm import cvt_trtllm_engine
@@ -88,22 +119,28 @@ def main(config):
88119 eval_model (model , blockwise_opts , eval_list , eval_pos = 'fake_quant' )
89120 eval_model (model , blockwise_opts , eval_list , eval_pos = 'fake_quant_wo_kv' )
90121
122+ # 切换到fake quant部署模式再保存
91123 if 'save' in config and config .save .get ('save_fake' , False ):
92124 deploy_all_modality (blockwise_opts , 'fake_quant' )
93125 blockwise_opt .save_model (save_fake_path )
94126
95127 if 'save' in config :
128+ # 导出真实量化模型给推理后端
96129 if (
130+ # Any of these backend save flags triggers the per-modality checks below
97131 config .save .get ('save_vllm' , False )
98132 or config .save .get ('save_sgl' , False )
99133 or config .save .get ('save_lightllm' , False )
100134 ):
101135 for modality_config in modality_configs :
102136 w , a = modality_config .weight , modality_config .get ('act' )
103137
138+ # A string-valued w.bit takes the floating-point format checks; other bit widths are handled in the else branch
104139 if isinstance (w .bit , str ):
140+ # 必须对称量化
105141 assert w .symmetric , 'Only symmetric quant is supported.'
106142 assert w .bit in ['e4m3' , 'e3m4' ], 'Supported quant: w8a16.'
143+ # 有激活量化的话,那激活也要满足对称、bit合法的要求
107144 if a :
108145 assert (
109146 w .symmetric and a .symmetric
@@ -114,6 +151,7 @@ def main(config):
114151 and a .bit in ['e4m3' , 'e5m2' ]
115152 ), 'Only WA FP8 quant is supported'
116153 else :
154+ # 是整数则必须是4 or 8
117155 assert w .symmetric , 'Only symmetric quant is supported.'
118156 assert w .bit in [4 , 8 ], 'Supported quant: w4a16, w8a16, w8a8.'
119157 if a :
@@ -130,12 +168,15 @@ def main(config):
130168 blockwise_opt .save_model (save_quant_path )
131169 update_vllm_quant_config (blockwise_opt .model , config , save_quant_path )
132170
171+ # Export for a specific backend (AutoAWQ)
133172 elif config .save .get ('save_autoawq' , False ):
134173 for modality_config in modality_configs :
174+ # 只能4 bit 仅含有weight 不支持act
135175 assert (
136176 modality_config .weight .bit in [4 ] and 'act' not in modality_config
137177 ), 'AutoAWQ supports only 4-bit weight-only quantization.'
138178 assert (
179+ # 不能对称量化
139180 not modality_config .weight .symmetric
140181 ), 'Only asymmetric quant is supported.'
141182
@@ -161,18 +202,23 @@ def main(config):
161202 blockwise_opt .save_model (save_quant_path )
162203 update_lightx2v_quant_config (save_quant_path )
163204
205+ # 判断是否有opencompass
164206 if 'opencompass' in config :
165207 assert config .save .get ('save_trans' , False )
208+ # 从配置里读取cfg_path, output_path
166209 cfg_path = config ['opencompass' ]['cfg_path' ]
167210 output_path = config ['opencompass' ]['output_path' ]
211+ # 取路径
168212 eval_model_path = os .path .abspath (save_trans_path )
213+ # 拼指令
169214 opencompass_cmd = (
170215 f'opencompass { cfg_path } -w { output_path } '
171216 f'--llmc_cfg { args .config } '
172217 f'--llmc_eval_mode quant '
173218 f'--llmc_model_path { eval_model_path } '
174219 )
175220 logger .info (f'opencompass_cmd : { opencompass_cmd } ' )
221+ # 执行
176222 os .system (opencompass_cmd )
177223 dist .barrier ()
178224
@@ -181,20 +227,25 @@ def main(config):
181227 logger .add (sys .stdout , level = 'INFO' )
182228 llmc_start_time = time .time ()
183229 parser = argparse .ArgumentParser ()
230+ # 解析命令行参数
184231 parser .add_argument ('--config' , type = str , required = True )
185232 parser .add_argument ('--task_id' , type = str , required = True )
186233 args = parser .parse_args ()
187234
188235 with open (args .config , 'r' ) as file :
236+ # 读取配置文件
189237 config = yaml .safe_load (file )
190238 config = EasyDict (config )
191239
192240 init_process_group (backend = 'nccl' )
241+ # 初始化分布式环境 设置GPU
193242 torch .cuda .set_device (int (os .environ ['LOCAL_RANK' ]))
194243
244+ # Remove the logger handlers on every rank except 0 so that only rank 0 emits the log output below
195245 if int (os .environ ['RANK' ]) != 0 :
196246 logger .remove ()
197247
248+ # 检查配置是否合法
198249 check_config (config )
199250
200251 logger .info (f'args: { args } ' )
@@ -209,6 +260,12 @@ def main(config):
209260 # Ensure only the main process creates directories
210261 if int (os .environ ['RANK' ]) == 0 :
211262 if 'save' in config :
263+ if config .save .get ('save_calib_json' , False ):
264+ mkdirs (config .save .save_path )
265+ save_calib_json_path = os .path .join (
266+ config .save .save_path ,
267+ config .save .get ('calib_json_name' , 'calib_scales.json' ),
268+ )
212269 if config .save .get ('save_trans' , False ):
213270 save_trans_path = os .path .join (
214271 config .save .save_path , 'transformed_model'
@@ -266,3 +323,4 @@ def main(config):
266323 llmc_duration_time = llmc_end_time - llmc_start_time
267324 logger .info (f'llmc_duration_time: { llmc_duration_time } s' )
268325 logger .info ('--- llmc finished ---' )
326+
0 commit comments