Skip to content

Commit bd0a8bd

Browse files
authored
Add fp8 blockwise quant (#52)
1 parent 62d3e69 commit bd0a8bd

2 files changed

Lines changed: 234 additions & 0 deletions

File tree

docs/source/getting_started/quickstrat.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@ python3 tools/run.py -c configs/qwen3/fp8_static/qwen3-1_7b_fp8_static.yaml
3535
from angelslim import engine
3636
engine.get_supported_compress_method()
3737
```
38+
- FP8 block-wise 量化可以使用多 GPU 并行转换脚本,其中 `block_size` 是权重量化 scale 对应的分块形状,`num_workers` 是并行进程数
39+
40+
```shell
41+
python3 tools/fp8_quant_blockwise.py \
42+
--block_size 128 128 \
43+
--num_workers 32 \
44+
--input_path ${INPUT_PATH} \
45+
--output_path ${OUTPUT_PATH}
46+
```
3847

3948

4049
## 部署

tools/fp8_quant_blockwise.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import math
17+
import multiprocessing as mp
18+
import os
19+
import shutil
20+
from argparse import ArgumentParser
21+
22+
import torch
23+
from safetensors.torch import safe_open, save_file
24+
from tqdm import tqdm
25+
26+
# Weight-name suffixes selected for FP8 quantization. A tensor is quantized
# only when its name ends with one of these suffixes; every other tensor is
# copied through unchanged.
SUFFIX_TO_QUANT = [
    f".{layer_name}.weight"
    for layer_name in (
        # gate/up/down projections
        "gate_and_up_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        # q/kv "a"/"b" projections
        "q_a_proj",
        "q_b_proj",
        "kv_a_proj_with_mqa",
        "kv_b_proj",
        # fused and separate q/k/v/o projections
        "qkv_proj",
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
    )
]
41+
42+
43+
def create_quantized_param(param, weight_block_size=(128, 128)):
44+
"""
45+
Quantizes weights to FP8 format using Block-wise quantization
46+
"""
47+
# Get FP8 min/max values
48+
fp8_min = torch.finfo(torch.float8_e4m3fn).min
49+
fp8_max = torch.finfo(torch.float8_e4m3fn).max
50+
51+
block_size_m, block_size_n = weight_block_size
52+
rows, cols = param.shape[-2:]
53+
54+
# Tensor-wise
55+
if block_size_m == -1 or block_size_m > rows:
56+
block_size_m = rows
57+
if block_size_n == -1 or block_size_n > cols:
58+
block_size_n = cols
59+
60+
if rows % block_size_m != 0:
61+
pad = torch.zeros(
62+
[*param.shape[:-2], block_size_m - rows % block_size_m, cols],
63+
dtype=param.dtype,
64+
device=param.device,
65+
)
66+
param = torch.concat([param, pad], dim=-2)
67+
if cols % block_size_n != 0:
68+
pad = torch.zeros(
69+
[*param.shape[:-2], rows, block_size_n - cols % block_size_n],
70+
dtype=param.dtype,
71+
device=param.device,
72+
)
73+
param = torch.concat([param, pad], dim=-1)
74+
param_value_shape = param.shape
75+
76+
param_value = (
77+
param.float()
78+
.reshape(
79+
-1,
80+
math.ceil(rows / block_size_m),
81+
block_size_m,
82+
math.ceil(cols // block_size_n),
83+
block_size_n,
84+
)
85+
.permute(0, 1, 3, 2, 4)
86+
)
87+
88+
# Calculate scaling factor for each block
89+
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
90+
scale_inv = fp8_max / max_abs
91+
scale_orig_shape = scale_inv.shape
92+
scale_inv = scale_inv.unsqueeze(-1).unsqueeze(-1)
93+
94+
# Quantize the weights
95+
quantized_param = torch.clamp(param_value * scale_inv, min=fp8_min, max=fp8_max).to(
96+
torch.float8_e4m3fn
97+
)
98+
quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
99+
quantized_param = quantized_param.reshape(param_value_shape)[..., :rows, :cols]
100+
101+
scale_inv = scale_inv.reshape(scale_orig_shape).squeeze().reciprocal()
102+
103+
return quantized_param.contiguous(), scale_inv.contiguous()
104+
105+
106+
def process_safetensor(rank, file_name, input_path, output_path, block_size=(128, 128)):
    """
    Quantize one safetensors shard and write the result to ``output_path``.

    Every tensor in ``input_path/file_name`` is loaded onto ``cuda:{rank}``.
    Tensors whose name ends with an entry of ``SUFFIX_TO_QUANT`` are replaced
    by their FP8 version plus a companion scale tensor; all other tensors are
    stored unchanged. The shard is saved under the same file name in
    ``output_path``.

    Returns:
        A dict mapping every tensor name written (including the added scale
        tensors) to ``file_name``.
    """
    source_file = os.path.join(input_path, file_name)
    target_file = os.path.join(output_path, file_name)

    # Reference: https://github.com/vllm-project/vllm/blob/v0.10.1/vllm/model_executor/layers/quantization/fp8.py#L295 # noqa: E501
    # Tensor-wise scales are stored as "<name>_scale", block-wise scales as
    # "<name>_scale_inv".
    is_tensor_wise = block_size[0] == -1 and block_size[1] == -1
    scale_suffix = "_scale" if is_tensor_wise else "_scale_inv"
    quant_suffixes = tuple(SUFFIX_TO_QUANT)

    tensors = {}
    weight_to_file = {}
    with safe_open(source_file, framework="pt", device=f"cuda:{rank}") as f:
        print(f"Processing {file_name} with {len(f.keys())} weights")
        for weight_name in f.keys():
            weight = f.get_tensor(weight_name)
            weight_to_file[weight_name] = file_name

            # Pass-through for tensors that are not selected for quantization.
            if not weight_name.endswith(quant_suffixes):
                tensors[weight_name] = weight
                continue

            quant_weight, scale = create_quantized_param(weight, block_size)
            tensors[weight_name] = quant_weight
            scale_name = f"{weight_name}{scale_suffix}"
            tensors[scale_name] = scale
            weight_to_file[scale_name] = file_name

    save_file(tensors, target_file)
    return weight_to_file
138+
139+
140+
def worker(i, file_names, input_path, output_path, block_size, return_dict):
    """
    Process a subset of safetensors shards inside one worker process.

    Args:
        i: Worker index; the worker uses CUDA device ``i % device_count``.
        file_names: Shard file names assigned to this worker.
        input_path: Directory containing the source shards.
        output_path: Directory the quantized shards are written to.
        block_size: ``(block_rows, block_cols)`` passed on to
            ``process_safetensor``.
        return_dict: Mapping (shared with the parent process) that receives
            ``file_name -> index`` for every processed shard.

    Raises:
        RuntimeError: If there are shards to process but no CUDA device is
            visible.
    """
    # Nothing assigned to this worker: return without touching CUDA.
    if not file_names:
        return

    world_size = torch.cuda.device_count()
    if world_size == 0:
        # Without this check the modulo below fails with an opaque
        # ZeroDivisionError.
        raise RuntimeError(
            "No CUDA device is visible; fp8 blockwise quantization requires "
            "at least one GPU"
        )

    device_rank = i % world_size
    for file_name in tqdm(file_names, desc=f"Worker {i}"):
        return_dict[file_name] = process_safetensor(
            device_rank, file_name, input_path, output_path, block_size
        )
147+
148+
149+
def main(input_path, output_path, block_size, num_workers=32):
    """
    Convert a sharded safetensors checkpoint to FP8 using parallel workers.

    Steps:
      1. Read ``model.safetensors.index.json`` from ``input_path`` to find
         the shard files.
      2. Distribute the shards round-robin over ``num_workers`` processes,
         each of which runs ``worker``.
      3. Write a new ``model.safetensors.index.json`` to ``output_path``.
      4. Copy ``.py`` / ``.json`` / ``.md`` / ``.txt`` files that do not yet
         exist in ``output_path``.
      5. Add a ``quantization_config`` section to the copied ``config.json``.

    Args:
        input_path: Directory of the source checkpoint.
        output_path: Directory for the quantized checkpoint (created if
            missing).
        block_size: ``(block_rows, block_cols)``; ``-1`` entries request
            tensor-wise quantization.
        num_workers: Maximum number of worker processes to spawn.

    Raises:
        ValueError: If ``num_workers`` is smaller than 1.
        RuntimeError: If a worker exits abnormally or a shard was not
            converted.
    """
    if num_workers < 1:
        raise ValueError(f"num_workers must be >= 1, got {num_workers}")

    os.makedirs(output_path, exist_ok=True)
    model_index_file = os.path.join(input_path, "model.safetensors.index.json")
    with open(model_index_file, "r", encoding="utf-8") as f:
        model_index = json.load(f)
    safetensor_files = sorted(set(model_index["weight_map"].values()))
    print(f"Found {len(safetensor_files)} safetensor files")

    # Never spawn more processes than there are shards to convert.
    num_workers = min(num_workers, len(safetensor_files))
    file_subsets = [safetensor_files[i::num_workers] for i in range(num_workers)]

    manager = mp.Manager()
    return_dict = manager.dict()
    processes = []
    for i, file_subset in enumerate(file_subsets):
        p = mp.Process(
            target=worker,
            args=(i, file_subset, input_path, output_path, block_size, return_dict),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    # A crashed worker would otherwise go unnoticed and leave an incomplete
    # checkpoint and index behind.
    failed = [p.name for p in processes if p.exitcode != 0]
    missing = [name for name in safetensor_files if name not in return_dict]
    if failed or missing:
        raise RuntimeError(
            f"Quantization failed: workers exited abnormally: {failed}; "
            f"shards not converted: {missing}"
        )

    # Merge in sorted shard order so the index does not depend on which
    # worker finished first.
    index = {}
    for name in safetensor_files:
        index.update(return_dict[name])
    with open(os.path.join(output_path, "model.safetensors.index.json"), "w") as f:
        json.dump({"metadata": {}, "weight_map": index}, f, indent=2)

    # Copy config file
    # Files already present in output_path (including the index written
    # above) are skipped, not overwritten.
    for entry in os.listdir(input_path):
        if not entry.endswith((".py", ".json", ".md", ".txt")):
            continue
        src_path = os.path.join(input_path, entry)
        dst_path = os.path.join(output_path, entry)
        if os.path.exists(dst_path):
            continue
        print(f"cp {src_path} {dst_path}")
        shutil.copy2(src_path, dst_path)

    # Quantization config
    config_file = os.path.join(output_path, "config.json")
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)
    config["quantization_config"] = {
        "activation_scheme": "dynamic",
        "fmt": "e4m3",
        "quant_method": "fp8",
    }
    # weight_block_size is only recorded when both block dimensions are set.
    if block_size[0] != -1 and block_size[1] != -1:
        config["quantization_config"]["weight_block_size"] = block_size
    print(f"quant config: {config['quantization_config']}")
    with open(config_file, "w") as f:
        json.dump(config, f, indent=4)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--block_size", type=int, nargs=2, default=(128, 128))
    parser.add_argument("--num_workers", type=int, default=32)
    parser.add_argument("--input_path", type=str, default="")
    parser.add_argument("--output_path", type=str, default="")
    args = parser.parse_args()
    print(args)
    with open(os.path.join(args.input_path, "config.json"), "r", encoding="utf8") as fp:
        json_data = json.load(fp)
    print(json_data)
    # Refuse checkpoints whose config already has a quantization_config.
    if "quantization_config" in json_data:
        raise AssertionError("NOT SUPPORT FP8 DS")

    # Pass the worker count explicitly so main() does not read the
    # module-global ``args``.
    main(args.input_path, args.output_path, args.block_size, args.num_workers)

0 commit comments

Comments
 (0)