Skip to content

Commit f673464

Browse files
dynamicheartlinchuanxie
authored andcommitted
[Kunlunxin] Support DS V3/R1 fp8 cast to channel-wise int8 (#208)
1 parent d22ac4b commit f673464

1 file changed

Lines changed: 250 additions & 0 deletions

File tree

tools/fp8_cast_channel_int8.py

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
# This file is based on DeepSeek code (MIT License).
2+
#
3+
# Original code:
4+
# Copyright (c) 2023 DeepSeek
5+
# https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py
6+
# https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8/blob/main/inference/bf16_cast_channel_int8.py (Meituan fork) # noqa: E501
7+
#
8+
# Additional contributions:
9+
# Copyright (c) 2026 Kunlunxin (Beijing) Technology Co., Ltd. (Kunlunxin)
10+
#
11+
# Modifications:
12+
# - Merged implementations
13+
# - Added multi-GPU parallel processing
14+
#
15+
# SPDX-License-Identifier: Apache-2.0 AND MIT
16+
17+
import json
18+
import os
19+
import shutil
20+
from argparse import ArgumentParser
21+
from glob import glob
22+
23+
import torch
24+
import torch.multiprocessing as mp
25+
from safetensors.torch import safe_open, save_file
26+
27+
from angelslim.compressor.quant.core.quant_func import weight_dequant
28+
29+
30+
def process_worker(
31+
worker_id, safetensor_files, fp8_path, int8_path, weight_map, return_dict
32+
):
33+
"""
34+
Process worker.
35+
36+
Each worker process is responsible for a subset of safetensor files:
37+
- FP8 → BF16 dequantization
38+
- BF16 → INT8 quantization
39+
- Generation of the updated weight_map
40+
"""
41+
num_gpus = torch.cuda.device_count()
42+
rank = worker_id % num_gpus
43+
torch.cuda.set_device(rank)
44+
quant_count = 0
45+
new_weight_map = {}
46+
for safetensor_file in safetensor_files:
47+
file_name = os.path.basename(safetensor_file)
48+
print(f"[Worker {worker_id}][GPU {rank}] processing {file_name}")
49+
with safe_open(safetensor_file, framework="pt", device=f"cuda:{rank}") as f:
50+
new_state_dict = {}
51+
keys = set(f.keys())
52+
for weight_name in keys:
53+
weight = f.get_tensor(weight_name)
54+
scale_inv_name = f"{weight_name}_scale_inv"
55+
if scale_inv_name in weight_map:
56+
quant_count += 1
57+
# 1. fp8 dequant to bf16
58+
scale_inv = get_tensor_from_file(
59+
rank, scale_inv_name, weight_map, fp8_path
60+
)
61+
weight_bf16 = weight_dequant(weight, scale_inv)
62+
# 2. bf16 quant to int8
63+
int8_weight, scale_inv = weight_quant(weight_bf16)
64+
new_state_dict[weight_name] = int8_weight
65+
new_scale_name = scale_inv_name.replace("_scale_inv", "_scale")
66+
new_state_dict[new_scale_name] = scale_inv
67+
new_weight_map[weight_name] = file_name
68+
new_weight_map[new_scale_name] = file_name
69+
else:
70+
if weight_name.endswith("_scale_inv"):
71+
continue
72+
new_state_dict[weight_name] = weight
73+
new_weight_map[weight_name] = file_name
74+
75+
new_safetensor_file = os.path.join(int8_path, file_name)
76+
save_file(new_state_dict, new_safetensor_file)
77+
return_dict[worker_id] = (quant_count, new_weight_map)
78+
79+
80+
# Helper function to get tensor from the correct file
81+
def get_tensor_from_file(rank, tensor_name, weight_map, fp8_path):
82+
"""
83+
Retrieves a tensor from mmap safe_tensors
84+
85+
Args:
86+
tensor_name (str): The name of the tensor to retrieve.
87+
88+
Returns:
89+
torch.Tensor: The retrieved tensor.
90+
91+
Raises:
92+
KeyError: If the tensor does not exist in the safetensor file.
93+
"""
94+
torch.cuda.set_device(rank)
95+
file_name = weight_map[tensor_name]
96+
file_path = os.path.join(fp8_path, file_name)
97+
98+
with safe_open(file_path, framework="pt", device=f"cuda:{rank}") as f:
99+
return f.get_tensor(tensor_name)
100+
101+
102+
def weight_quant(tensor: torch.Tensor):
103+
"""
104+
Quantize a 2D tensor row-wise from BF16/FP32 to INT8.
105+
Args:
106+
tensor (torch.Tensor): Input 2D tensor.
107+
Returns:
108+
Tuple[torch.Tensor, torch.Tensor]:
109+
- Quantized INT8 tensor.
110+
- Scale tensor (float32) used for quantization.
111+
"""
112+
assert tensor.dim() == 2
113+
qmax = 127.0
114+
abs_max = torch.abs(tensor).max(dim=1, keepdim=True)[0] # [rows, 1]
115+
scale = abs_max / qmax # [rows, 1]
116+
assert scale.shape == (tensor.shape[0], 1)
117+
quantized = torch.round(tensor / scale)
118+
quantized = torch.clamp(quantized, -qmax, qmax)
119+
return quantized.to(torch.int8), scale.to(torch.float32)
120+
121+
122+
def main(fp8_path, int8_path, num_workers):
123+
"""
124+
Run the FP8-to-INT8 per-channel quantization pipeline.
125+
126+
This function:
127+
1. Copy the config file
128+
2. Loads FP8 safetensors.
129+
3. Dequantizes FP8 → BF16, then quantizes BF16 → INT8.
130+
4. Saves quantized safetensors and updates model index.
131+
132+
Args:
133+
fp8_path (str): Path to directory containing FP8 safetensors.
134+
int8_path (str): Output directory to save INT8 safetensors.
135+
num_workers (int): Number of processing workers
136+
"""
137+
torch.set_default_dtype(torch.bfloat16)
138+
os.makedirs(int8_path, exist_ok=True)
139+
model_index_file = os.path.join(int8_path, "model.safetensors.index.json")
140+
config_file = os.path.join(int8_path, "config.json")
141+
142+
for fname in os.listdir(fp8_path):
143+
if fname.endswith(".safetensors"):
144+
continue
145+
src = os.path.join(fp8_path, fname)
146+
dst = os.path.join(int8_path, fname)
147+
if os.path.isdir(src):
148+
print(f"cp -r {src} {dst}")
149+
shutil.copytree(src, dst, dirs_exist_ok=True)
150+
elif os.path.isfile(src):
151+
print(f"cp {src} {dst}")
152+
shutil.copy2(src, dst)
153+
154+
# modify config.json and save it
155+
config = json.load(open(config_file))
156+
# delete quantization_config
157+
config.pop("quantization_config", None)
158+
config["quantization_config"] = {
159+
"config_groups": {
160+
"group_0": {
161+
"input_activations": {
162+
"actorder": None,
163+
"block_structure": None,
164+
"dynamic": True,
165+
"group_size": None,
166+
"num_bits": 8,
167+
"observer": "memoryless",
168+
"observer_kwargs": {},
169+
"strategy": "token",
170+
"symmetric": True,
171+
"type": "int",
172+
},
173+
"output_activations": None,
174+
"weights": {
175+
"actorder": None,
176+
"block_structure": None,
177+
"dynamic": False,
178+
"group_size": None,
179+
"num_bits": 8,
180+
"observer": "minmax",
181+
"observer_kwargs": {},
182+
"strategy": "channel",
183+
"symmetric": True,
184+
"type": "int",
185+
},
186+
"targets": ["Linear"],
187+
}
188+
},
189+
"format": "int-quantized",
190+
"ignore": ["lm_head"],
191+
"kv_cache_scheme": None,
192+
"quant_method": "compressed-tensors",
193+
"quantization_status": "compressed",
194+
}
195+
196+
with open(config_file, "w", encoding="utf-8") as f:
197+
json.dump(config, f, indent=2, ensure_ascii=False, sort_keys=True)
198+
print(f"config.json modified and saved to {config_file}")
199+
200+
with open(model_index_file, "r") as f:
201+
model_index = json.load(f)
202+
weight_map = model_index["weight_map"]
203+
scale_count = len([key for key in weight_map.keys() if key.endswith("_scale_inv")])
204+
205+
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
206+
safetensor_files.sort()
207+
quant_count = 0
208+
new_weight_map = {}
209+
210+
file_subsets = [safetensor_files[i::num_workers] for i in range(num_workers)]
211+
212+
mp.set_start_method("spawn", force=True)
213+
manager = mp.Manager()
214+
return_dict = manager.dict()
215+
processes = []
216+
for i in range(num_workers):
217+
p = mp.Process(
218+
target=process_worker,
219+
args=(i, file_subsets[i], fp8_path, int8_path, weight_map, return_dict),
220+
)
221+
p.start()
222+
processes.append(p)
223+
for p in processes:
224+
p.join()
225+
226+
for i in range(num_workers):
227+
qc, wm = return_dict[i]
228+
quant_count += qc
229+
new_weight_map.update(wm)
230+
assert quant_count == scale_count
231+
print(f"{quant_count} weights are quantized.")
232+
233+
# modify model.safetensors.index.json
234+
with open(model_index_file, "r") as f:
235+
model_index = json.load(f)
236+
model_index["weight_map"] = new_weight_map
237+
with open(model_index_file, "w", encoding="utf-8") as f:
238+
json.dump(model_index, f, indent=2, ensure_ascii=False, sort_keys=True)
239+
print(f"model.safetensors.index.json modified and saved to {model_index_file}")
240+
241+
242+
if __name__ == "__main__":
243+
parser = ArgumentParser()
244+
parser.add_argument("--input-fp8-path", type=str, required=True)
245+
parser.add_argument("--output-int8-path", type=str, required=True)
246+
parser.add_argument("--num-workers", type=int, default=32)
247+
248+
args = parser.parse_args()
249+
main(args.input_fp8_path, args.output_int8_path, args.num_workers)
250+
print("done")

0 commit comments

Comments
 (0)