Skip to content

Commit 0c09f8c

Browse files
committed
add fp8 blockwise quant
1 parent fa7304e commit 0c09f8c

2 files changed

Lines changed: 220 additions & 0 deletions

File tree

docs/source/getting_started/quickstrat.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@ python3 tools/run.py -c configs/qwen3/fp8_static/qwen3-1_7b_fp8_static.yaml
3535
from angelslim import engine
3636
engine.get_supported_compress_method()
3737
```
38+
- fp8 block-wise量化可以使用并行GPU转化脚本,其中block_size是权重量化scale对应分块形状,`num_workers`是并行数
39+
40+
```shell
41+
python3 tools/fp8_quant_blockwise.py \
42+
--block_size 128 128 \
43+
--num_workers 32 \
44+
--input_path ${INPUT_PATH} \
45+
--output_path ${OUTPUT_PATH}
46+
```
3847

3948

4049
## 部署

tools/fp8_quant_blockwise.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
import json
2+
import math
3+
import multiprocessing as mp
4+
import os
5+
import shutil
6+
from argparse import ArgumentParser
7+
8+
import torch
9+
from safetensors.torch import safe_open, save_file
10+
from tqdm import tqdm
11+
12+
SUFFIX_TO_QUANT = [
13+
".gate_and_up_proj.weight",
14+
".gate_proj.weight",
15+
".up_proj.weight",
16+
".down_proj.weight",
17+
".q_a_proj.weight",
18+
".q_b_proj.weight",
19+
".kv_a_proj_with_mqa.weight",
20+
".kv_b_proj.weight",
21+
".qkv_proj.weight",
22+
".q_proj.weight",
23+
".k_proj.weight",
24+
".v_proj.weight",
25+
".o_proj.weight",
26+
]
27+
28+
29+
def create_quantized_param(param, weight_block_size=(128, 128)):
30+
"""
31+
Quantizes weights to FP8 format using Block-wise quantization
32+
"""
33+
# Get FP8 min/max values
34+
fp8_min = torch.finfo(torch.float8_e4m3fn).min
35+
fp8_max = torch.finfo(torch.float8_e4m3fn).max
36+
37+
block_size_m, block_size_n = weight_block_size
38+
rows, cols = param.shape[-2:]
39+
40+
# Tensor-wise
41+
if block_size_m == -1 or block_size_m > rows:
42+
block_size_m = rows
43+
if block_size_n == -1 or block_size_n > cols:
44+
block_size_n = cols
45+
46+
if rows % block_size_m != 0:
47+
pad = torch.zeros(
48+
[*param.shape[:-2], block_size_m - rows % block_size_m, cols],
49+
dtype=param.dtype,
50+
device=param.device,
51+
)
52+
param = torch.concat([param, pad], dim=-2)
53+
if cols % block_size_n != 0:
54+
pad = torch.zeros(
55+
[*param.shape[:-2], rows, block_size_n - cols % block_size_n],
56+
dtype=param.dtype,
57+
device=param.device,
58+
)
59+
param = torch.concat([param, pad], dim=-1)
60+
param_value_shape = param.shape
61+
62+
param_value = (
63+
param.float()
64+
.reshape(
65+
-1,
66+
math.ceil(rows / block_size_m),
67+
block_size_m,
68+
math.ceil(cols // block_size_n),
69+
block_size_n,
70+
)
71+
.permute(0, 1, 3, 2, 4)
72+
)
73+
74+
# Calculate scaling factor for each block
75+
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
76+
scale_inv = fp8_max / max_abs
77+
scale_orig_shape = scale_inv.shape
78+
scale_inv = scale_inv.unsqueeze(-1).unsqueeze(-1)
79+
80+
# Quantize the weights
81+
quantized_param = torch.clamp(param_value * scale_inv, min=fp8_min, max=fp8_max).to(
82+
torch.float8_e4m3fn
83+
)
84+
quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
85+
quantized_param = quantized_param.reshape(param_value_shape)[..., :rows, :cols]
86+
87+
scale_inv = scale_inv.reshape(scale_orig_shape).squeeze().reciprocal()
88+
89+
return quantized_param.contiguous(), scale_inv.contiguous()
90+
91+
92+
def process_safetensor(rank, file_name, input_path, output_path, block_size=(128, 128)):
93+
state_dict = {}
94+
index = {}
95+
count = 0
96+
with safe_open(
97+
os.path.join(input_path, file_name), framework="pt", device=f"cuda:{rank}"
98+
) as f:
99+
print(f"Processing {file_name} with {len(f.keys())} weights")
100+
for weight_name in f.keys():
101+
weight = f.get_tensor(weight_name)
102+
if any(weight_name.endswith(suffix) for suffix in SUFFIX_TO_QUANT):
103+
quant_weight, scale = create_quantized_param(weight, block_size)
104+
state_dict[weight_name] = quant_weight
105+
index[weight_name] = file_name
106+
107+
# Reference: https://github.com/vllm-project/vllm/blob/v0.10.1/vllm/model_executor/layers/quantization/fp8.py#L295 # noqa: E501
108+
if block_size[0] == -1 and block_size[1] == -1:
109+
# Tensor-wise
110+
state_dict[f"{weight_name}_scale"] = scale
111+
index[f"{weight_name}_scale"] = file_name
112+
else:
113+
# Block-wise
114+
state_dict[f"{weight_name}_scale_inv"] = scale
115+
index[f"{weight_name}_scale_inv"] = file_name
116+
else:
117+
state_dict[weight_name] = weight
118+
index[weight_name] = file_name
119+
count += 1
120+
121+
new_safetensor_file = os.path.join(output_path, file_name)
122+
save_file(state_dict, new_safetensor_file)
123+
return index
124+
125+
126+
def worker(i, file_names, input_path, output_path, block_size, return_dict):
127+
world_size = torch.cuda.device_count()
128+
for file_name in tqdm(file_names, desc=f"Worker {i}"):
129+
index = process_safetensor(
130+
i % world_size, file_name, input_path, output_path, block_size
131+
)
132+
return_dict[file_name] = index
133+
134+
135+
def main(input_path, output_path, block_size):
136+
os.makedirs(output_path, exist_ok=True)
137+
model_index_file = os.path.join(input_path, "model.safetensors.index.json")
138+
with open(model_index_file, "r") as f:
139+
model_index = json.load(f)
140+
weight_map = model_index["weight_map"]
141+
safetensor_files = set(weight_map.values())
142+
safetensor_files = list(sorted(safetensor_files))
143+
print(f"Found {len(safetensor_files)} safetensor files")
144+
145+
file_subsets = [
146+
safetensor_files[i :: args.num_workers] for i in range(args.num_workers)
147+
]
148+
manager = mp.Manager()
149+
return_dict = manager.dict()
150+
processes = []
151+
for i in range(args.num_workers):
152+
p = mp.Process(
153+
target=worker,
154+
args=(i, file_subsets[i], input_path, output_path, block_size, return_dict),
155+
)
156+
p.start()
157+
processes.append(p)
158+
for p in processes:
159+
p.join()
160+
161+
index = {}
162+
for result in return_dict.values():
163+
index.update(result)
164+
with open(os.path.join(output_path, "model.safetensors.index.json"), "w") as f:
165+
json.dump({"metadata": {}, "weight_map": index}, f, indent=2)
166+
167+
# Copy config file
168+
for file in os.listdir(input_path):
169+
if (
170+
file.endswith(".py")
171+
or file.endswith(".json")
172+
or file.endswith(".md")
173+
or file.endswith(".txt")
174+
):
175+
src_path = os.path.join(input_path, file)
176+
dst_path = os.path.join(output_path, file)
177+
if os.path.exists(dst_path):
178+
continue
179+
print(f"cp {src_path} {dst_path}")
180+
shutil.copy2(src_path, dst_path)
181+
182+
# Quantization config
183+
with open(os.path.join(output_path, "config.json"), "r") as f:
184+
config = json.load(f)
185+
config["quantization_config"] = {
186+
"activation_scheme": "dynamic",
187+
"fmt": "e4m3",
188+
"quant_method": "fp8",
189+
}
190+
if block_size[0] != -1 and block_size[1] != -1:
191+
config["quantization_config"]["weight_block_size"] = block_size
192+
print(f"quant config: {config['quantization_config']}")
193+
with open(os.path.join(output_path, "config.json"), "w") as f:
194+
json.dump(config, f, indent=4)
195+
196+
197+
if __name__ == "__main__":
198+
parser = ArgumentParser()
199+
parser.add_argument("--block_size", type=int, nargs=2, default=(128, 128))
200+
parser.add_argument("--num_workers", type=int, default=32)
201+
parser.add_argument("--input_path", type=str, default="")
202+
parser.add_argument("--output_path", type=str, default="")
203+
args = parser.parse_args()
204+
print(args)
205+
with open(os.path.join(args.input_path, "config.json"), "r", encoding="utf8") as fp:
206+
json_data = json.load(fp)
207+
print(json_data)
208+
if "quantization_config" in json_data.keys():
209+
raise AssertionError("NOT SUPPORT FP8 DS")
210+
211+
main(args.input_path, args.output_path, args.block_size)

0 commit comments

Comments
 (0)