Skip to content

Commit 4f6157a

Browse files
make the quant only do what it should do, more modular
1 parent 4a1b6ea commit 4f6157a

File tree

3 files changed

+35
-46
lines changed

3 files changed

+35
-46
lines changed

llmc/__main__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from llmc.models import *
2121
from llmc.utils import (check_config, deploy_all_modality, get_modality,
2222
mkdirs, print_important_package_version, seed_all,
23+
collect_lightllm_kv_calib_json,
2324
update_autoawq_quant_config,
2425
update_lightx2v_quant_config, update_vllm_quant_config)
2526
from llmc.utils.registry_factory import ALGO_REGISTRY, MODEL_REGISTRY
@@ -74,9 +75,9 @@ def main(config):
7475
if int(os.environ['RANK']) == 0:
7576
if 'save' in config and config.save.get('save_lightllm_kv_cache_calib', False):
7677
calib_json_list = [
77-
blockwise_opt.collect_calib_json()
78+
collect_lightllm_kv_calib_json(blockwise_opt)
7879
for blockwise_opt in blockwise_opts
79-
if hasattr(blockwise_opt, 'collect_calib_json')
80+
if hasattr(blockwise_opt, 'quant_kvcache')
8081
]
8182
calib_json_payload = (
8283
calib_json_list[0] if len(calib_json_list) == 1 else calib_json_list

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import copy
22
import functools
33
import gc
4-
import json
54
import os
65
import re
76
from collections import defaultdict
@@ -12,7 +11,6 @@
1211
import torch.nn as nn
1312
from loguru import logger
1413

15-
from llmc.utils.export_calib import collect_lightllm_kv_calib_json
1614
from llmc.utils.registry_factory import KV_REGISTRY, TOKEN_REDUCTION_REGISTRY
1715

1816
from ..blockwise_optimization import BlockwiseOpt
@@ -1011,45 +1009,6 @@ def contiguous_params(self):
10111009
if not param.is_contiguous():
10121010
param.data = param.data.contiguous()
10131011

1014-
# Convert tensors and similar objects into Python values that can be
1015-
# directly serialized into JSON.
1016-
def _to_jsonable(self, value):
1017-
if isinstance(value, torch.Tensor):
1018-
return value.detach().cpu().tolist()
1019-
return value
1020-
1021-
# Normalize inputs into CPU tensors so the following range computation
1022-
# and serialization logic can handle them consistently.
1023-
def _to_tensor(self, value, dtype=torch.float32):
1024-
if isinstance(value, torch.Tensor):
1025-
return value.detach().cpu().to(dtype)
1026-
return torch.as_tensor(value, dtype=dtype)
1027-
1028-
# LightLLM expects offline FP8 KV descales. Recover the real-value range
1029-
# from the qparams first, then convert it into per-layer K/V scales that
1030-
# align with torch.float8_e4m3fn.
1031-
def _collect_lightllm_kv_scale(self, scales, zeros, qmin, qmax):
1032-
if isinstance(scales, torch.Tensor) and scales.numel() == 0:
1033-
return None
1034-
1035-
scales_tensor = self._to_tensor(scales)
1036-
zeros_tensor = self._to_tensor(zeros, dtype=scales_tensor.dtype)
1037-
qmin_tensor = self._to_tensor(qmin, dtype=scales_tensor.dtype)
1038-
qmax_tensor = self._to_tensor(qmax, dtype=scales_tensor.dtype)
1039-
min_tensor = (qmin_tensor - zeros_tensor) * scales_tensor
1040-
max_tensor = (qmax_tensor - zeros_tensor) * scales_tensor
1041-
absmax_tensor = torch.maximum(min_tensor.abs(), max_tensor.abs())
1042-
fp8_qmax = torch.tensor(
1043-
torch.finfo(torch.float8_e4m3fn).max, dtype=absmax_tensor.dtype
1044-
)
1045-
return absmax_tensor / fp8_qmax
1046-
1047-
# Export calibration results in the LightLLM kv_cache_calib.json format.
1048-
# At the moment, only the per_tensor and per_head KV formats supported by
1049-
# LightLLM are handled here.
1050-
def collect_calib_json(self):
1051-
return collect_lightllm_kv_calib_json(self)
1052-
10531012
@torch.no_grad()
10541013
def save_model(self, path):
10551014
if int(os.environ['RANK']) != 0:

llmc/utils/export_calib.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,35 @@
11
import torch
22

33

4+
def _to_jsonable(value):
5+
if isinstance(value, torch.Tensor):
6+
return value.detach().cpu().tolist()
7+
return value
8+
9+
10+
def _to_tensor(value, dtype=torch.float32):
11+
if isinstance(value, torch.Tensor):
12+
return value.detach().cpu().to(dtype)
13+
return torch.as_tensor(value, dtype=dtype)
14+
15+
16+
def _collect_lightllm_kv_scale(scales, zeros, qmin, qmax):
17+
if isinstance(scales, torch.Tensor) and scales.numel() == 0:
18+
return None
19+
20+
scales_tensor = _to_tensor(scales)
21+
zeros_tensor = _to_tensor(zeros, dtype=scales_tensor.dtype)
22+
qmin_tensor = _to_tensor(qmin, dtype=scales_tensor.dtype)
23+
qmax_tensor = _to_tensor(qmax, dtype=scales_tensor.dtype)
24+
min_tensor = (qmin_tensor - zeros_tensor) * scales_tensor
25+
max_tensor = (qmax_tensor - zeros_tensor) * scales_tensor
26+
absmax_tensor = torch.maximum(min_tensor.abs(), max_tensor.abs())
27+
fp8_qmax = torch.tensor(
28+
torch.finfo(torch.float8_e4m3fn).max, dtype=absmax_tensor.dtype
29+
)
30+
return absmax_tensor / fp8_qmax
31+
32+
433
def collect_lightllm_kv_calib_json(blockwise_opt):
534
if not getattr(blockwise_opt, 'quant_kvcache', False):
635
raise ValueError(
@@ -24,13 +53,13 @@ def collect_lightllm_kv_calib_json(blockwise_opt):
2453
)
2554
scales = []
2655
for layer_idx in range(num_layers):
27-
key_scale = blockwise_opt._collect_lightllm_kv_scale(
56+
key_scale = _collect_lightllm_kv_scale(
2857
blockwise_opt.kv_module.k_scales_buffer[layer_idx],
2958
blockwise_opt.kv_module.k_zeros_buffer[layer_idx],
3059
blockwise_opt.kv_module.k_qmin_buffer[layer_idx],
3160
blockwise_opt.kv_module.k_qmax_buffer[layer_idx],
3261
)
33-
value_scale = blockwise_opt._collect_lightllm_kv_scale(
62+
value_scale = _collect_lightllm_kv_scale(
3463
blockwise_opt.kv_module.v_scales_buffer[layer_idx],
3564
blockwise_opt.kv_module.v_zeros_buffer[layer_idx],
3665
blockwise_opt.kv_module.v_qmin_buffer[layer_idx],
@@ -65,5 +94,5 @@ def collect_lightllm_kv_calib_json(blockwise_opt):
6594
'num_layers': num_layers,
6695
'num_head': num_head,
6796
'scales_shape': [num_layers, scale_width],
68-
'scales': scales,
97+
'scales': _to_jsonable(scales),
6998
}

0 commit comments

Comments (0)