From 0c09f8cd741f6c7596dc24f99ab2a7563a717963 Mon Sep 17 00:00:00 2001 From: RuBing-Yang Date: Thu, 21 Aug 2025 23:53:00 +0800 Subject: [PATCH 1/2] add fp8 blockwise quant --- docs/source/getting_started/quickstrat.md | 9 + tools/fp8_quant_blockwise.py | 211 ++++++++++++++++++++++ 2 files changed, 220 insertions(+) create mode 100644 tools/fp8_quant_blockwise.py diff --git a/docs/source/getting_started/quickstrat.md b/docs/source/getting_started/quickstrat.md index 8275cde6..7340e55f 100644 --- a/docs/source/getting_started/quickstrat.md +++ b/docs/source/getting_started/quickstrat.md @@ -35,6 +35,15 @@ python3 tools/run.py -c configs/qwen3/fp8_static/qwen3-1_7b_fp8_static.yaml from angelslim import engine engine.get_supported_compress_method() ``` +- fp8 block-wise量化可以使用并行GPU转化脚本,其中block_size是权重量化scale对应分块形状,`num_workers`是并行数 + +```shell +python3 tools/fp8_quant_blockwise.py \ + --block_size 128 128 \ + --num_workers 32 \ + --input_path ${INPUT_PATH} \ + --output_path ${OUTPUT_PATH} +``` ## 部署 diff --git a/tools/fp8_quant_blockwise.py b/tools/fp8_quant_blockwise.py new file mode 100644 index 00000000..27237076 --- /dev/null +++ b/tools/fp8_quant_blockwise.py @@ -0,0 +1,211 @@ +import json +import math +import multiprocessing as mp +import os +import shutil +from argparse import ArgumentParser + +import torch +from safetensors.torch import safe_open, save_file +from tqdm import tqdm + +SUFFIX_TO_QUANT = [ + ".gate_and_up_proj.weight", + ".gate_proj.weight", + ".up_proj.weight", + ".down_proj.weight", + ".q_a_proj.weight", + ".q_b_proj.weight", + ".kv_a_proj_with_mqa.weight", + ".kv_b_proj.weight", + ".qkv_proj.weight", + ".q_proj.weight", + ".k_proj.weight", + ".v_proj.weight", + ".o_proj.weight", +] + + +def create_quantized_param(param, weight_block_size=(128, 128)): + """ + Quantizes weights to FP8 format using Block-wise quantization + """ + # Get FP8 min/max values + fp8_min = torch.finfo(torch.float8_e4m3fn).min + fp8_max = torch.finfo(torch.float8_e4m3fn).max + + block_size_m, block_size_n = weight_block_size + rows, cols = param.shape[-2:] + + # Tensor-wise + if block_size_m == -1 or block_size_m > rows: + block_size_m = rows + if block_size_n == -1 or block_size_n > cols: + block_size_n = cols + + if rows % block_size_m != 0: + pad = torch.zeros( + [*param.shape[:-2], block_size_m - rows % block_size_m, cols], + dtype=param.dtype, + device=param.device, + ) + param = torch.concat([param, pad], dim=-2) + if cols % block_size_n != 0: + pad = torch.zeros( + [*param.shape[:-2], rows, block_size_n - cols % block_size_n], + dtype=param.dtype, + device=param.device, + ) + param = torch.concat([param, pad], dim=-1) + param_value_shape = param.shape + + param_value = ( + param.float() + .reshape( + -1, + math.ceil(rows / block_size_m), + block_size_m, + math.ceil(cols // block_size_n), + block_size_n, + ) + .permute(0, 1, 3, 2, 4) + ) + + # Calculate scaling factor for each block + max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2)) + scale_inv = fp8_max / max_abs + scale_orig_shape = scale_inv.shape + scale_inv = scale_inv.unsqueeze(-1).unsqueeze(-1) + + # Quantize the weights + quantized_param = torch.clamp(param_value * scale_inv, min=fp8_min, max=fp8_max).to( + torch.float8_e4m3fn + ) + quantized_param = quantized_param.permute(0, 1, 3, 2, 4) + quantized_param = quantized_param.reshape(param_value_shape)[..., :rows, :cols] + + scale_inv = scale_inv.reshape(scale_orig_shape).squeeze().reciprocal() + + return quantized_param.contiguous(), scale_inv.contiguous() + + +def process_safetensor(rank, file_name, input_path, output_path, block_size=(128, 128)): + state_dict = {} + index = {} + count = 0 + with safe_open( + os.path.join(input_path, file_name), framework="pt", device=f"cuda:{rank}" + ) as f: + print(f"Processing {file_name} with {len(f.keys())} weights") + for weight_name in f.keys(): + weight = f.get_tensor(weight_name) + if any(weight_name.endswith(suffix) for suffix in SUFFIX_TO_QUANT): + quant_weight, scale = create_quantized_param(weight, block_size) + state_dict[weight_name] = quant_weight + index[weight_name] = file_name + + # Reference: https://github.com/vllm-project/vllm/blob/v0.10.1/vllm/model_executor/layers/quantization/fp8.py#L295 # noqa: E501 + if block_size[0] == -1 and block_size[1] == -1: + # Tensor-wise + state_dict[f"{weight_name}_scale"] = scale + index[f"{weight_name}_scale"] = file_name + else: + # Block-wise + state_dict[f"{weight_name}_scale_inv"] = scale + index[f"{weight_name}_scale_inv"] = file_name + else: + state_dict[weight_name] = weight + index[weight_name] = file_name + count += 1 + + new_safetensor_file = os.path.join(output_path, file_name) + save_file(state_dict, new_safetensor_file) + return index + + +def worker(i, file_names, input_path, output_path, block_size, return_dict): + world_size = torch.cuda.device_count() + for file_name in tqdm(file_names, desc=f"Worker {i}"): + index = process_safetensor( + i % world_size, file_name, input_path, output_path, block_size + ) + return_dict[file_name] = index + + +def main(input_path, output_path, block_size): + os.makedirs(output_path, exist_ok=True) + model_index_file = os.path.join(input_path, "model.safetensors.index.json") + with open(model_index_file, "r") as f: + model_index = json.load(f) + weight_map = model_index["weight_map"] + safetensor_files = set(weight_map.values()) + safetensor_files = list(sorted(safetensor_files)) + print(f"Found {len(safetensor_files)} safetensor files") + + file_subsets = [ + safetensor_files[i :: args.num_workers] for i in range(args.num_workers) + ] + manager = mp.Manager() + return_dict = manager.dict() + processes = [] + for i in range(args.num_workers): + p = mp.Process( + target=worker, + args=(i, file_subsets[i], input_path, output_path, block_size, return_dict), + ) + p.start() + processes.append(p) + for p in processes: + p.join() + + index = {} + for result in return_dict.values(): + index.update(result) + with open(os.path.join(output_path, "model.safetensors.index.json"), "w") as f: + json.dump({"metadata": {}, "weight_map": index}, f, indent=2) + + # Copy config file + for file in os.listdir(input_path): + if ( + file.endswith(".py") + or file.endswith(".json") + or file.endswith(".md") + or file.endswith(".txt") + ): + src_path = os.path.join(input_path, file) + dst_path = os.path.join(output_path, file) + if os.path.exists(dst_path): + continue + print(f"cp {src_path} {dst_path}") + shutil.copy2(src_path, dst_path) + + # Quantization config + with open(os.path.join(output_path, "config.json"), "r") as f: + config = json.load(f) + config["quantization_config"] = { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", + } + if block_size[0] != -1 and block_size[1] != -1: + config["quantization_config"]["weight_block_size"] = block_size + print(f"quant config: {config['quantization_config']}") + with open(os.path.join(output_path, "config.json"), "w") as f: + json.dump(config, f, indent=4) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--block_size", type=int, nargs=2, default=(128, 128)) + parser.add_argument("--num_workers", type=int, default=32) + parser.add_argument("--input_path", type=str, default="") + parser.add_argument("--output_path", type=str, default="") + args = parser.parse_args() + print(args) + with open(os.path.join(args.input_path, "config.json"), "r", encoding="utf8") as fp: + json_data = json.load(fp) + print(json_data) + if "quantization_config" in json_data.keys(): + raise AssertionError("NOT SUPPORT FP8 DS") + + main(args.input_path, args.output_path, args.block_size) From a78182a3a714cf09b76da6b8c381c8233f1f6b96 Mon Sep 17 00:00:00 2001 From: RuBing-Yang Date: Fri, 22 Aug 2025 22:35:08 +0800 Subject: [PATCH 2/2] add License --- tools/fp8_quant_blockwise.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tools/fp8_quant_blockwise.py b/tools/fp8_quant_blockwise.py index 27237076..f7c4ea48 100644 --- a/tools/fp8_quant_blockwise.py +++ b/tools/fp8_quant_blockwise.py @@ -1,3 +1,17 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import math import multiprocessing as mp