|
| 1 | +import json |
| 2 | +import math |
| 3 | +import multiprocessing as mp |
| 4 | +import os |
| 5 | +import shutil |
| 6 | +from argparse import ArgumentParser |
| 7 | + |
| 8 | +import torch |
| 9 | +from safetensors.torch import safe_open, save_file |
| 10 | +from tqdm import tqdm |
| 11 | + |
| 12 | +SUFFIX_TO_QUANT = [ |
| 13 | + ".gate_and_up_proj.weight", |
| 14 | + ".gate_proj.weight", |
| 15 | + ".up_proj.weight", |
| 16 | + ".down_proj.weight", |
| 17 | + ".q_a_proj.weight", |
| 18 | + ".q_b_proj.weight", |
| 19 | + ".kv_a_proj_with_mqa.weight", |
| 20 | + ".kv_b_proj.weight", |
| 21 | + ".qkv_proj.weight", |
| 22 | + ".q_proj.weight", |
| 23 | + ".k_proj.weight", |
| 24 | + ".v_proj.weight", |
| 25 | + ".o_proj.weight", |
| 26 | +] |
| 27 | + |
| 28 | + |
| 29 | +def create_quantized_param(param, weight_block_size=(128, 128)): |
| 30 | + """ |
| 31 | + Quantizes weights to FP8 format using Block-wise quantization |
| 32 | + """ |
| 33 | + # Get FP8 min/max values |
| 34 | + fp8_min = torch.finfo(torch.float8_e4m3fn).min |
| 35 | + fp8_max = torch.finfo(torch.float8_e4m3fn).max |
| 36 | + |
| 37 | + block_size_m, block_size_n = weight_block_size |
| 38 | + rows, cols = param.shape[-2:] |
| 39 | + |
| 40 | + # Tensor-wise |
| 41 | + if block_size_m == -1 or block_size_m > rows: |
| 42 | + block_size_m = rows |
| 43 | + if block_size_n == -1 or block_size_n > cols: |
| 44 | + block_size_n = cols |
| 45 | + |
| 46 | + if rows % block_size_m != 0: |
| 47 | + pad = torch.zeros( |
| 48 | + [*param.shape[:-2], block_size_m - rows % block_size_m, cols], |
| 49 | + dtype=param.dtype, |
| 50 | + device=param.device, |
| 51 | + ) |
| 52 | + param = torch.concat([param, pad], dim=-2) |
| 53 | + if cols % block_size_n != 0: |
| 54 | + pad = torch.zeros( |
| 55 | + [*param.shape[:-2], rows, block_size_n - cols % block_size_n], |
| 56 | + dtype=param.dtype, |
| 57 | + device=param.device, |
| 58 | + ) |
| 59 | + param = torch.concat([param, pad], dim=-1) |
| 60 | + param_value_shape = param.shape |
| 61 | + |
| 62 | + param_value = ( |
| 63 | + param.float() |
| 64 | + .reshape( |
| 65 | + -1, |
| 66 | + math.ceil(rows / block_size_m), |
| 67 | + block_size_m, |
| 68 | + math.ceil(cols // block_size_n), |
| 69 | + block_size_n, |
| 70 | + ) |
| 71 | + .permute(0, 1, 3, 2, 4) |
| 72 | + ) |
| 73 | + |
| 74 | + # Calculate scaling factor for each block |
| 75 | + max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2)) |
| 76 | + scale_inv = fp8_max / max_abs |
| 77 | + scale_orig_shape = scale_inv.shape |
| 78 | + scale_inv = scale_inv.unsqueeze(-1).unsqueeze(-1) |
| 79 | + |
| 80 | + # Quantize the weights |
| 81 | + quantized_param = torch.clamp(param_value * scale_inv, min=fp8_min, max=fp8_max).to( |
| 82 | + torch.float8_e4m3fn |
| 83 | + ) |
| 84 | + quantized_param = quantized_param.permute(0, 1, 3, 2, 4) |
| 85 | + quantized_param = quantized_param.reshape(param_value_shape)[..., :rows, :cols] |
| 86 | + |
| 87 | + scale_inv = scale_inv.reshape(scale_orig_shape).squeeze().reciprocal() |
| 88 | + |
| 89 | + return quantized_param.contiguous(), scale_inv.contiguous() |
| 90 | + |
| 91 | + |
| 92 | +def process_safetensor(rank, file_name, input_path, output_path, block_size=(128, 128)): |
| 93 | + state_dict = {} |
| 94 | + index = {} |
| 95 | + count = 0 |
| 96 | + with safe_open( |
| 97 | + os.path.join(input_path, file_name), framework="pt", device=f"cuda:{rank}" |
| 98 | + ) as f: |
| 99 | + print(f"Processing {file_name} with {len(f.keys())} weights") |
| 100 | + for weight_name in f.keys(): |
| 101 | + weight = f.get_tensor(weight_name) |
| 102 | + if any(weight_name.endswith(suffix) for suffix in SUFFIX_TO_QUANT): |
| 103 | + quant_weight, scale = create_quantized_param(weight, block_size) |
| 104 | + state_dict[weight_name] = quant_weight |
| 105 | + index[weight_name] = file_name |
| 106 | + |
| 107 | + # Reference: https://github.com/vllm-project/vllm/blob/v0.10.1/vllm/model_executor/layers/quantization/fp8.py#L295 # noqa: E501 |
| 108 | + if block_size[0] == -1 and block_size[1] == -1: |
| 109 | + # Tensor-wise |
| 110 | + state_dict[f"{weight_name}_scale"] = scale |
| 111 | + index[f"{weight_name}_scale"] = file_name |
| 112 | + else: |
| 113 | + # Block-wise |
| 114 | + state_dict[f"{weight_name}_scale_inv"] = scale |
| 115 | + index[f"{weight_name}_scale_inv"] = file_name |
| 116 | + else: |
| 117 | + state_dict[weight_name] = weight |
| 118 | + index[weight_name] = file_name |
| 119 | + count += 1 |
| 120 | + |
| 121 | + new_safetensor_file = os.path.join(output_path, file_name) |
| 122 | + save_file(state_dict, new_safetensor_file) |
| 123 | + return index |
| 124 | + |
| 125 | + |
| 126 | +def worker(i, file_names, input_path, output_path, block_size, return_dict): |
| 127 | + world_size = torch.cuda.device_count() |
| 128 | + for file_name in tqdm(file_names, desc=f"Worker {i}"): |
| 129 | + index = process_safetensor( |
| 130 | + i % world_size, file_name, input_path, output_path, block_size |
| 131 | + ) |
| 132 | + return_dict[file_name] = index |
| 133 | + |
| 134 | + |
| 135 | +def main(input_path, output_path, block_size): |
| 136 | + os.makedirs(output_path, exist_ok=True) |
| 137 | + model_index_file = os.path.join(input_path, "model.safetensors.index.json") |
| 138 | + with open(model_index_file, "r") as f: |
| 139 | + model_index = json.load(f) |
| 140 | + weight_map = model_index["weight_map"] |
| 141 | + safetensor_files = set(weight_map.values()) |
| 142 | + safetensor_files = list(sorted(safetensor_files)) |
| 143 | + print(f"Found {len(safetensor_files)} safetensor files") |
| 144 | + |
| 145 | + file_subsets = [ |
| 146 | + safetensor_files[i :: args.num_workers] for i in range(args.num_workers) |
| 147 | + ] |
| 148 | + manager = mp.Manager() |
| 149 | + return_dict = manager.dict() |
| 150 | + processes = [] |
| 151 | + for i in range(args.num_workers): |
| 152 | + p = mp.Process( |
| 153 | + target=worker, |
| 154 | + args=(i, file_subsets[i], input_path, output_path, block_size, return_dict), |
| 155 | + ) |
| 156 | + p.start() |
| 157 | + processes.append(p) |
| 158 | + for p in processes: |
| 159 | + p.join() |
| 160 | + |
| 161 | + index = {} |
| 162 | + for result in return_dict.values(): |
| 163 | + index.update(result) |
| 164 | + with open(os.path.join(output_path, "model.safetensors.index.json"), "w") as f: |
| 165 | + json.dump({"metadata": {}, "weight_map": index}, f, indent=2) |
| 166 | + |
| 167 | + # Copy config file |
| 168 | + for file in os.listdir(input_path): |
| 169 | + if ( |
| 170 | + file.endswith(".py") |
| 171 | + or file.endswith(".json") |
| 172 | + or file.endswith(".md") |
| 173 | + or file.endswith(".txt") |
| 174 | + ): |
| 175 | + src_path = os.path.join(input_path, file) |
| 176 | + dst_path = os.path.join(output_path, file) |
| 177 | + if os.path.exists(dst_path): |
| 178 | + continue |
| 179 | + print(f"cp {src_path} {dst_path}") |
| 180 | + shutil.copy2(src_path, dst_path) |
| 181 | + |
| 182 | + # Quantization config |
| 183 | + with open(os.path.join(output_path, "config.json"), "r") as f: |
| 184 | + config = json.load(f) |
| 185 | + config["quantization_config"] = { |
| 186 | + "activation_scheme": "dynamic", |
| 187 | + "fmt": "e4m3", |
| 188 | + "quant_method": "fp8", |
| 189 | + } |
| 190 | + if block_size[0] != -1 and block_size[1] != -1: |
| 191 | + config["quantization_config"]["weight_block_size"] = block_size |
| 192 | + print(f"quant config: {config['quantization_config']}") |
| 193 | + with open(os.path.join(output_path, "config.json"), "w") as f: |
| 194 | + json.dump(config, f, indent=4) |
| 195 | + |
| 196 | + |
| 197 | +if __name__ == "__main__": |
| 198 | + parser = ArgumentParser() |
| 199 | + parser.add_argument("--block_size", type=int, nargs=2, default=(128, 128)) |
| 200 | + parser.add_argument("--num_workers", type=int, default=32) |
| 201 | + parser.add_argument("--input_path", type=str, default="") |
| 202 | + parser.add_argument("--output_path", type=str, default="") |
| 203 | + args = parser.parse_args() |
| 204 | + print(args) |
| 205 | + with open(os.path.join(args.input_path, "config.json"), "r", encoding="utf8") as fp: |
| 206 | + json_data = json.load(fp) |
| 207 | + print(json_data) |
| 208 | + if "quantization_config" in json_data.keys(): |
| 209 | + raise AssertionError("NOT SUPPORT FP8 DS") |
| 210 | + |
| 211 | + main(args.input_path, args.output_path, args.block_size) |
0 commit comments