|
20 | 20 | from llmc.models import * |
21 | 21 | from llmc.utils import (check_config, deploy_all_modality, get_modality, |
22 | 22 | mkdirs, print_important_package_version, seed_all, |
| 23 | + collect_lightllm_kv_calib_json, |
23 | 24 | update_autoawq_quant_config, |
24 | 25 | update_lightx2v_quant_config, update_vllm_quant_config) |
25 | 26 | from llmc.utils.registry_factory import ALGO_REGISTRY, MODEL_REGISTRY |
@@ -72,6 +73,21 @@ def main(config): |
72 | 73 |
|
73 | 74 | eval_model(model, blockwise_opts, eval_list, eval_pos='transformed') |
74 | 75 | if int(os.environ['RANK']) == 0: |
| 76 | + if 'save' in config and config.save.get('save_lightllm_kv_cache_calib', False): |
| 77 | + calib_json_list = [ |
| 78 | + collect_lightllm_kv_calib_json(blockwise_opt) |
| 79 | + for blockwise_opt in blockwise_opts |
| 80 | + if hasattr(blockwise_opt, 'quant_kvcache') |
| 81 | + ] |
| 82 | + calib_json_payload = ( |
| 83 | + calib_json_list[0] if len(calib_json_list) == 1 else calib_json_list |
| 84 | + ) |
| 85 | + with open(save_lightllm_kv_cache_calib_path, 'w') as file: |
| 86 | + json.dump(calib_json_payload, file, ensure_ascii=False, indent=4) |
| 87 | + logger.info( |
| 88 | + f'save lightllm kv cache calib done -- {save_lightllm_kv_cache_calib_path}' |
| 89 | + ) |
| 90 | + |
75 | 91 | if 'save' in config and config.save.get('save_trans', False): |
76 | 92 | blockwise_opt.save_model(save_trans_path) |
77 | 93 |
|
@@ -209,6 +225,14 @@ def main(config): |
209 | 225 | # Ensure only the main process creates directories |
210 | 226 | if int(os.environ['RANK']) == 0: |
211 | 227 | if 'save' in config: |
| 228 | + if config.save.get('save_lightllm_kv_cache_calib', False): |
| 229 | + mkdirs(config.save.save_path) |
| 230 | + save_lightllm_kv_cache_calib_path = os.path.join( |
| 231 | + config.save.save_path, |
| 232 | + config.save.get( |
| 233 | + 'lightllm_kv_cache_calib_name', 'kv_cache_calib.json' |
| 234 | + ), |
| 235 | + ) |
212 | 236 | if config.save.get('save_trans', False): |
213 | 237 | save_trans_path = os.path.join( |
214 | 238 | config.save.save_path, 'transformed_model' |
@@ -266,3 +290,4 @@ def main(config): |
266 | 290 | llmc_duration_time = llmc_end_time - llmc_start_time |
267 | 291 | logger.info(f'llmc_duration_time: {llmc_duration_time} s') |
268 | 292 | logger.info('--- llmc finished ---') |
| 293 | + |
0 commit comments