Skip to content

Commit 50e9a4b

Browse files
authored
Merge pull request #460 from Michael20070814/Repair_Quant
Add the per-head KV-cache quantization path and fix several related bugs
2 parents ac6c4c2 + 4f6157a commit 50e9a4b

File tree

9 files changed

+206
-18
lines changed

9 files changed

+206
-18
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ save*
2222
.log
2323
*.pid
2424
*.ipynb*
25+
.venv/
26+
*.sh
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
base:
2+
seed: &seed 42
3+
model:
4+
type: model_type
5+
path: model path
6+
torch_dtype: auto
7+
calib:
8+
name: pileval
9+
download: False
10+
path: calib data path
11+
n_samples: 128
12+
bs: 1
13+
seq_len: 2048
14+
preproc: txt_general_preproc
15+
seed: *seed
16+
eval:
17+
eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos
18+
name: wikitext2
19+
type: decode_ppl
20+
download: False
21+
path: eval_data_path
22+
bs: 1
23+
inference_per_block: False
24+
num_samples: 10
25+
# num_eval_tokens: 3
26+
quant:
27+
method: RTN
28+
weight:
29+
bit: 8
30+
symmetric: True
31+
granularity: per_channel
32+
group_size: -1
33+
act:
34+
bit: 8
35+
symmetric: True
36+
granularity: per_tensor
37+
static: True
38+
kvcache:
39+
method: Naive
40+
bit: 8
41+
symmetric: True
42+
granularity: per_head
43+
head_num: kv head num
44+
save:
45+
save_lightllm_kv_cache_calib: True  # key must match config.save.get('save_lightllm_kv_cache_calib') in __main__.py
46+
lightllm_kv_cache_calib_name: kv_cache_calib.json  # key must match config.save.get('lightllm_kv_cache_calib_name') in __main__.py
47+
save_fake: False
48+
save_path: /path/to/save/

configs/quantization/methods/KVQuant/rtn_w_a_pertensor_static_naive_quant_kv.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,7 @@ quant:
4141
symmetric: True
4242
granularity: per_tensor
4343
save:
44+
save_lightllm_kv_cache_calib: True  # key must match config.save.get('save_lightllm_kv_cache_calib') in __main__.py
45+
lightllm_kv_cache_calib_name: kv_cache_calib.json  # key must match config.save.get('lightllm_kv_cache_calib_name') in __main__.py
4446
save_fake: False
45-
save_path: /path/to/save/
47+
save_path: /path/to/save/

llmc/__main__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from llmc.models import *
2121
from llmc.utils import (check_config, deploy_all_modality, get_modality,
2222
mkdirs, print_important_package_version, seed_all,
23+
collect_lightllm_kv_calib_json,
2324
update_autoawq_quant_config,
2425
update_lightx2v_quant_config, update_vllm_quant_config)
2526
from llmc.utils.registry_factory import ALGO_REGISTRY, MODEL_REGISTRY
@@ -72,6 +73,21 @@ def main(config):
7273

7374
eval_model(model, blockwise_opts, eval_list, eval_pos='transformed')
7475
if int(os.environ['RANK']) == 0:
76+
if 'save' in config and config.save.get('save_lightllm_kv_cache_calib', False):
77+
calib_json_list = [
78+
collect_lightllm_kv_calib_json(blockwise_opt)
79+
for blockwise_opt in blockwise_opts
80+
if hasattr(blockwise_opt, 'quant_kvcache')
81+
]
82+
calib_json_payload = (
83+
calib_json_list[0] if len(calib_json_list) == 1 else calib_json_list
84+
)
85+
with open(save_lightllm_kv_cache_calib_path, 'w') as file:
86+
json.dump(calib_json_payload, file, ensure_ascii=False, indent=4)
87+
logger.info(
88+
f'save lightllm kv cache calib done -- {save_lightllm_kv_cache_calib_path}'
89+
)
90+
7591
if 'save' in config and config.save.get('save_trans', False):
7692
blockwise_opt.save_model(save_trans_path)
7793

@@ -209,6 +225,14 @@ def main(config):
209225
# Ensure only the main process creates directories
210226
if int(os.environ['RANK']) == 0:
211227
if 'save' in config:
228+
if config.save.get('save_lightllm_kv_cache_calib', False):
229+
mkdirs(config.save.save_path)
230+
save_lightllm_kv_cache_calib_path = os.path.join(
231+
config.save.save_path,
232+
config.save.get(
233+
'lightllm_kv_cache_calib_name', 'kv_cache_calib.json'
234+
),
235+
)
212236
if config.save.get('save_trans', False):
213237
save_trans_path = os.path.join(
214238
config.save.save_path, 'transformed_model'
@@ -266,3 +290,4 @@ def main(config):
266290
llmc_duration_time = llmc_end_time - llmc_start_time
267291
logger.info(f'llmc_duration_time: {llmc_duration_time} s')
268292
logger.info('--- llmc finished ---')
293+

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import copy
22
import functools
33
import gc
4-
import json
54
import os
65
import re
76
from collections import defaultdict
@@ -175,13 +174,18 @@ def set_quant_config(self):
175174
self.act_quant_module = IntegerQuantizer
176175
elif quant_type == 'float-quant':
177176
self.act_quant_module = FloatQuantizer
178-
self.quant_config['act']['tp'] = self.tp
179-
self.aquantizer = self.act_quant_module(**self.quant_config['act'])
180177
self.act_static = self.quant_config['act'].get('static', False)
181178
if self.act_static:
182179
assert (
183180
self.quant_config['act']['granularity'] == 'per_tensor'
184181
), 'Only support per_tensor static quant'
182+
# Static activation quantization uses the batched calibration
183+
# path, so normalize the default minmax setting to
184+
# static_minmax to match the downstream calibration logic.
185+
if self.quant_config['act'].get('calib_algo', 'minmax') == 'minmax':
186+
self.quant_config['act']['calib_algo'] = 'static_minmax'
187+
self.quant_config['act']['tp'] = self.tp
188+
self.aquantizer = self.act_quant_module(**self.quant_config['act'])
185189
self.quant_attn = self.quant_config['act'].get('quant_attn', False)
186190
if self.quant_attn:
187191
assert self.config['model']['type'] in ['Vit', 'DeepseekV2']
@@ -203,8 +207,10 @@ def set_quant_config(self):
203207
kv_special_cfg = self.quant_config['kvcache'].get('special', {})
204208
act_static_cfg = {}
205209
if self.act_static:
206-
act_static_cfg.update(self.config.calib.n_sample)
207-
act_static_cfg.update(self.config.calib.bs)
210+
# The KV cache constructor expects num_samples / bsz, so map
211+
# the calibration config fields to the parameter names it uses.
212+
act_static_cfg['num_samples'] = self.config.calib.n_samples
213+
act_static_cfg['bsz'] = self.config.calib.bs
208214
kv_quant_type = self.quant_config['kvcache'].get('quant_type', 'int-quant')
209215
self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
210216
kv_quant_type, self.quant_config['kvcache'],

llmc/compression/quantization/kvquant.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import torch
23
from loguru import logger
34
from transformers import DynamicCache
@@ -12,12 +13,20 @@ class NaiveQuantKVCache(DynamicCache):
1213
def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples=128, bsz=1):
1314
super().__init__()
1415

15-
assert kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group']
16+
# Copy the config to avoid mutating the original quantization config in static KV calibration.
17+
kvquant_cfg = copy.deepcopy(kvquant_cfg)
18+
assert kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group', 'per_head']
1619
self.num_hidden_layers, self.num_samples, self.bsz = (
1720
num_hidden_layers,
1821
num_samples,
1922
bsz,
2023
)
24+
if kvquant_cfg.get('static', False) and kvquant_cfg.get(
25+
'calib_algo', 'minmax'
26+
) == 'minmax':
27+
# Static KV calibration uses the batched tensor statistics path, so convert the default
28+
# minmax setting to static_minmax here to avoid a later calibration algo name mismatch.
29+
kvquant_cfg['calib_algo'] = 'static_minmax'
2130
if quant_type == 'int-quant':
2231
self.kvquantizer = IntegerQuantizer(**kvquant_cfg)
2332
elif quant_type == 'float-quant':

llmc/compression/quantization/quant.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -224,27 +224,24 @@ def get_minmax_stats(self, act_tensors):
224224
for tensor in tensors:
225225
tensor = self.reshape_tensor(tensor)
226226
tensor_range = self.get_minmax_range(tensor)
227-
min_val, max_val = tensor_range[0], tensor_range[1]
227+
min_val = tensor_range[0].detach().cpu().to(torch.float32)
228+
max_val = tensor_range[1].detach().cpu().to(torch.float32)
228229

229230
if input_idx not in stats_min_max:
230231
stats_min_max[input_idx] = {}
231-
stats_min_max[input_idx]['min'] = torch.tensor(
232-
[min_val], dtype=torch.float32
233-
)
234-
stats_min_max[input_idx]['max'] = torch.tensor(
235-
[max_val], dtype=torch.float32
236-
)
232+
stats_min_max[input_idx]['min'] = min_val.unsqueeze(0)
233+
stats_min_max[input_idx]['max'] = max_val.unsqueeze(0)
237234
else:
238235
stats_min_max[input_idx]['min'] = torch.cat(
239236
[
240237
stats_min_max[input_idx]['min'],
241-
torch.tensor([min_val], dtype=torch.float32),
238+
min_val.unsqueeze(0),
242239
]
243240
)
244241
stats_min_max[input_idx]['max'] = torch.cat(
245242
[
246243
stats_min_max[input_idx]['max'],
247-
torch.tensor([max_val], dtype=torch.float32),
244+
max_val.unsqueeze(0),
248245
]
249246
)
250247

@@ -255,8 +252,8 @@ def get_static_minmax_range(self, act_tensors):
255252
stats_min_max = self.get_minmax_stats(act_tensors)
256253
min_vals, max_vals = [], []
257254
for input_idx, tensor_range in stats_min_max.items():
258-
min_val = tensor_range['min'].mean()
259-
max_val = tensor_range['max'].mean()
255+
min_val = tensor_range['min'].mean(dim=0)
256+
max_val = tensor_range['max'].mean(dim=0)
260257
min_vals.append(min_val)
261258
max_vals.append(max_val)
262259

llmc/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .export_autoawq import update_autoawq_quant_config
2+
from .export_calib import collect_lightllm_kv_calib_json
23
from .export_lightx2v import update_lightx2v_quant_config
34
from .export_vllm import update_vllm_quant_config
45
from .utils import (check_config, copy_files, deploy_all_modality,

llmc/utils/export_calib.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import torch
2+
3+
4+
def _to_jsonable(value):
5+
if isinstance(value, torch.Tensor):
6+
return value.detach().cpu().tolist()
7+
return value
8+
9+
10+
def _to_tensor(value, dtype=torch.float32):
11+
if isinstance(value, torch.Tensor):
12+
return value.detach().cpu().to(dtype)
13+
return torch.as_tensor(value, dtype=dtype)
14+
15+
16+
def _collect_lightllm_kv_scale(scales, zeros, qmin, qmax):
17+
if isinstance(scales, torch.Tensor) and scales.numel() == 0:
18+
return None
19+
20+
scales_tensor = _to_tensor(scales)
21+
zeros_tensor = _to_tensor(zeros, dtype=scales_tensor.dtype)
22+
qmin_tensor = _to_tensor(qmin, dtype=scales_tensor.dtype)
23+
qmax_tensor = _to_tensor(qmax, dtype=scales_tensor.dtype)
24+
min_tensor = (qmin_tensor - zeros_tensor) * scales_tensor
25+
max_tensor = (qmax_tensor - zeros_tensor) * scales_tensor
26+
absmax_tensor = torch.maximum(min_tensor.abs(), max_tensor.abs())
27+
fp8_qmax = torch.tensor(
28+
torch.finfo(torch.float8_e4m3fn).max, dtype=absmax_tensor.dtype
29+
)
30+
return absmax_tensor / fp8_qmax
31+
32+
33+
def collect_lightllm_kv_calib_json(blockwise_opt):
34+
if not getattr(blockwise_opt, 'quant_kvcache', False):
35+
raise ValueError(
36+
'save_lightllm_kv_cache_calib requires kvcache quantization.'
37+
)
38+
39+
kv_cfg = blockwise_opt.quant_config['kvcache']
40+
granularity = kv_cfg.get('granularity')
41+
if granularity not in ['per_tensor', 'per_head']:
42+
raise ValueError(
43+
f'LightLLM calib export only supports per_tensor/per_head, got {granularity}'
44+
)
45+
46+
num_layers = blockwise_opt.model.model_config.num_hidden_layers
47+
num_head = int(
48+
getattr(
49+
blockwise_opt.model.model_config,
50+
'num_key_value_heads',
51+
blockwise_opt.model.get_num_attention_heads(),
52+
)
53+
)
54+
scales = []
55+
for layer_idx in range(num_layers):
56+
key_scale = _collect_lightllm_kv_scale(
57+
blockwise_opt.kv_module.k_scales_buffer[layer_idx],
58+
blockwise_opt.kv_module.k_zeros_buffer[layer_idx],
59+
blockwise_opt.kv_module.k_qmin_buffer[layer_idx],
60+
blockwise_opt.kv_module.k_qmax_buffer[layer_idx],
61+
)
62+
value_scale = _collect_lightllm_kv_scale(
63+
blockwise_opt.kv_module.v_scales_buffer[layer_idx],
64+
blockwise_opt.kv_module.v_zeros_buffer[layer_idx],
65+
blockwise_opt.kv_module.v_qmin_buffer[layer_idx],
66+
blockwise_opt.kv_module.v_qmax_buffer[layer_idx],
67+
)
68+
if key_scale is None or value_scale is None:
69+
raise ValueError(f'Calibration scale for layer {layer_idx} is empty.')
70+
71+
scale_row = torch.cat([key_scale.reshape(-1), value_scale.reshape(-1)]).tolist()
72+
scales.append(scale_row)
73+
74+
scale_width = len(scales[0]) if scales else 0
75+
if granularity == 'per_tensor' and scale_width != 2:
76+
raise ValueError(f'per_tensor export expects 2 scales per layer, got {scale_width}')
77+
if granularity == 'per_head' and scale_width != num_head * 2:
78+
raise ValueError(
79+
f'per_head export expects {num_head * 2} scales per layer, got {scale_width}'
80+
)
81+
82+
architectures = getattr(blockwise_opt.model.model_config, 'architectures', None)
83+
if isinstance(architectures, list) and len(architectures) > 0:
84+
architectures = architectures[0]
85+
elif architectures is None:
86+
architectures = blockwise_opt.config.model.type
87+
88+
return {
89+
'version': '1.0',
90+
'architectures': architectures,
91+
'quant_type': granularity,
92+
'qmin': float(torch.finfo(torch.float8_e4m3fn).min),
93+
'qmax': float(torch.finfo(torch.float8_e4m3fn).max),
94+
'num_layers': num_layers,
95+
'num_head': num_head,
96+
'scales_shape': [num_layers, scale_width],
97+
'scales': _to_jsonable(scales),
98+
}

0 commit comments

Comments
 (0)