From 065c898c30740d4f52b35386296ed0c45e5ee5ed Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Tue, 13 Jan 2026 18:50:04 +0800 Subject: [PATCH 1/5] Support DS V3/R1 fp8 cast to channel-wise int8 --- tools/fp8_cast_channel_int8.py | 251 +++++++++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) create mode 100644 tools/fp8_cast_channel_int8.py diff --git a/tools/fp8_cast_channel_int8.py b/tools/fp8_cast_channel_int8.py new file mode 100644 index 00000000..12852e05 --- /dev/null +++ b/tools/fp8_cast_channel_int8.py @@ -0,0 +1,251 @@ +# This file is based on DeepSeek code (MIT License). +# +# Original code: +# Copyright (c) 2023 DeepSeek +# https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py +# https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py +# https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8/blob/main/inference/bf16_cast_channel_int8.py (Meituan fork) +# +# Additional contributions: +# Copyright (c) 2026 Kunlunxin (Beijing) Technology Co., Ltd. (Kunlunxin) +# +# Modifications: +# - Merged implementations +# - Added multi-GPU parallel processing +# +# SPDX-License-Identifier: Apache-2.0 AND MIT + +import os +import json +from argparse import ArgumentParser +from glob import glob + +import torch +import triton +import triton.language as tl +from safetensors.torch import load_file, save_file +from huggingface_hub import snapshot_download + +import torch.multiprocessing as mp + + +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: + """ + Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size). + block_size (int, optional): The block size to use for dequantization. Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. + """ + assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' + assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + + +def process_worker(rank, safetensor_files, fp8_path, int8_path, weight_map, return_dict): + """ + Process worker + """ + torch.cuda.set_device(rank) + quant_count = 0 + new_weight_map = {} + for safetensor_file in safetensor_files: + file_name = os.path.basename(safetensor_file) + print(f"[GPU {rank}] processing {file_name}") + state_dict = load_file(safetensor_file, device=f"cuda:{rank}") + new_state_dict = {} + for weight_name, weight in state_dict.items(): + scale_inv_name = f"{weight_name}_scale_inv" + if scale_inv_name in weight_map: + quant_count += 1 + # 1. fp8 dequant to bf16 + scale_inv = get_tensor(rank, scale_inv_name, weight_map, fp8_path) + weight_bf16 = weight_dequant(weight, scale_inv) + # 2. bf16 quant to int8 + int8_weight, scale_inv = weight_quant(weight_bf16) + new_state_dict[weight_name] = int8_weight + new_scale_name = scale_inv_name.replace("_scale_inv", "_scale") + new_state_dict[new_scale_name] = scale_inv + new_weight_map[weight_name] = file_name + new_weight_map[new_scale_name] = file_name + else: + if weight_name.endswith("_scale_inv"): + continue + new_state_dict[weight_name] = weight + new_weight_map[weight_name] = file_name + + new_safetensor_file = os.path.join(int8_path, file_name) + save_file(new_state_dict, new_safetensor_file) + return_dict[rank] = (quant_count, new_weight_map) + + +# Helper function to get tensor from the correct file +def get_tensor(rank, tensor_name, weight_map, fp8_path): + """ + Retrieves a tensor from the cached safetensor files or loads it from disk if not cached. + + Args: + tensor_name (str): The name of the tensor to retrieve. + + Returns: + torch.Tensor: The retrieved tensor. + + Raises: + KeyError: If the tensor does not exist in the safetensor file. + """ + torch.cuda.set_device(rank) + file_name = weight_map[tensor_name] + loaded_files = {} + file_path = os.path.join(fp8_path, file_name) + loaded_files[file_name] = load_file(file_path, device=f"cuda:{rank}") + return loaded_files[file_name][tensor_name] + + +def weight_quant(tensor: torch.Tensor): + """ + Quantize a 2D tensor row-wise from BF16/FP32 to INT8. + Args: + tensor (torch.Tensor): Input 2D tensor. + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - Quantized INT8 tensor. + - Scale tensor (float32) used for quantization. + """ + assert tensor.dim() == 2 + qmax = 127.0 + abs_max = torch.abs(tensor).max(dim=1, keepdim=True)[0] # [rows, 1] + scale = abs_max / qmax # [rows, 1] + assert scale.shape == (tensor.shape[0], 1) + quantized = torch.round(tensor / scale) + quantized = torch.clamp(quantized, -qmax, qmax) + return quantized.to(torch.int8), scale.to(torch.float32) + + +def main(fp8_path, int8_path, model_name="deepseek-ai/DeepSeek-R1"): + """ + Run the FP8-to-INT8 per-channel quantization pipeline. + + This function: + 1. Downloads model index/config if missing. + 2. Loads FP8 safetensors. + 3. Dequantizes FP8 → BF16, then quantizes BF16 → INT8. + 4. Saves quantized safetensors and updates model index. + + Args: + fp8_path (str): Path to directory containing FP8 safetensors. + int8_path (str): Output directory to save INT8 safetensors. + model_name (str, optional): HuggingFace model repo name. Defaults to DeepSeek-R1. + """ + torch.set_default_dtype(torch.bfloat16) + os.makedirs(int8_path, exist_ok=True) + model_index_file = os.path.join(int8_path, "model.safetensors.index.json") + config_file = os.path.join(int8_path, "config.json") + + if not os.path.exists(model_index_file) or not os.path.exists(config_file): + snapshot_download( + repo_id=model_name, + ignore_patterns=["*.safetensors"], + local_dir=int8_path, + local_dir_use_symlinks=False + ) + print(f"model index file and config file downloaded to {int8_path}") + + # modify config.json and save it + config = json.load(open(config_file)) + # delete quantization_config + config.pop("quantization_config", None) + with open(config_file, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, ensure_ascii=False, sort_keys=True) + print(f"config.json modified and saved to {config_file}") + + with open(model_index_file, "r") as f: + model_index = json.load(f) + weight_map = model_index["weight_map"] + scale_count = len([key for key in weight_map.keys() if key.endswith("_scale_inv")]) + + safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors"))) + safetensor_files.sort() + quant_count = 0 + new_weight_map = {} + + num_gpus = torch.cuda.device_count() + files_per_gpu = [[] for _ in range(num_gpus)] + for idx, sf in enumerate(safetensor_files): + files_per_gpu[idx % num_gpus].append(sf) + + manager = mp.Manager() + return_dict = manager.dict() + processes = [] + for rank in range(num_gpus): + p = mp.Process(target=process_worker, + args=(rank, files_per_gpu[rank], fp8_path, int8_path, weight_map, return_dict)) + p.start() + processes.append(p) + for p in processes: + p.join() + + for rank in range(num_gpus): + qc, wm = return_dict[rank] + quant_count += qc + new_weight_map.update(wm) + assert quant_count == scale_count + print(f"{quant_count} weights are quantized.") + + # modify model.safetensors.index.json + with open(model_index_file, "r") as f: + model_index = json.load(f) + model_index["weight_map"] = new_weight_map + with open(model_index_file, "w", encoding="utf-8") as f: + json.dump(model_index, f, indent=2, ensure_ascii=False, sort_keys=True) + print(f"model.safetensors.index.json modified and saved to {model_index_file}") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--input-fp8-hf-path", type=str, required=True) + parser.add_argument("--output-int8-hf-path", type=str, required=True) + parser.add_argument("--model-name", type=str, default="deepseek-ai/DeepSeek-R1") + + args = parser.parse_args() + main(args.input_fp8_hf_path, args.output_int8_hf_path, args.model_name) + print("done") From 61cc768a9a87f2eac940f3161a91e8df971ef42b Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Tue, 13 Jan 2026 20:06:26 +0800 Subject: [PATCH 2/5] refine --- tools/fp8_cast_channel_int8.py | 187 +++++++++++++++++++++------------ 1 file changed, 122 insertions(+), 65 deletions(-) diff --git a/tools/fp8_cast_channel_int8.py b/tools/fp8_cast_channel_int8.py index 12852e05..51dab98a 100644 --- a/tools/fp8_cast_channel_int8.py +++ b/tools/fp8_cast_channel_int8.py @@ -15,18 +15,17 @@ # # SPDX-License-Identifier: Apache-2.0 AND MIT -import os import json +import os +import shutil from argparse import ArgumentParser from glob import glob import torch +import torch.multiprocessing as mp import triton import triton.language as tl -from safetensors.torch import load_file, save_file -from huggingface_hub import snapshot_download - -import torch.multiprocessing as mp +from safetensors.torch import safe_open, save_file @triton.jit @@ -58,7 +57,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): tl.store(y_ptr + offs, y, mask=mask) -def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: +def weight_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128 +) -> torch.Tensor: """ Dequantizes the given weight tensor using the provided scale tensor. @@ -73,56 +74,72 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t Raises: AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. """ - assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' - assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' + assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" + assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" M, N = x.size() y = torch.empty_like(x, dtype=torch.get_default_dtype()) - grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y -def process_worker(rank, safetensor_files, fp8_path, int8_path, weight_map, return_dict): +def process_worker( + worker_id, safetensor_files, fp8_path, int8_path, weight_map, return_dict +): """ - Process worker + Process worker. + + Each worker process is responsible for a subset of safetensor files: + - FP8 → BF16 dequantization + - BF16 → INT8 quantization + - Generation of the updated weight_map """ + num_gpus = torch.cuda.device_count() + rank = worker_id % num_gpus torch.cuda.set_device(rank) quant_count = 0 new_weight_map = {} for safetensor_file in safetensor_files: file_name = os.path.basename(safetensor_file) - print(f"[GPU {rank}] processing {file_name}") - state_dict = load_file(safetensor_file, device=f"cuda:{rank}") - new_state_dict = {} - for weight_name, weight in state_dict.items(): - scale_inv_name = f"{weight_name}_scale_inv" - if scale_inv_name in weight_map: - quant_count += 1 - # 1. fp8 dequant to bf16 - scale_inv = get_tensor(rank, scale_inv_name, weight_map, fp8_path) - weight_bf16 = weight_dequant(weight, scale_inv) - # 2. bf16 quant to int8 - int8_weight, scale_inv = weight_quant(weight_bf16) - new_state_dict[weight_name] = int8_weight - new_scale_name = scale_inv_name.replace("_scale_inv", "_scale") - new_state_dict[new_scale_name] = scale_inv - new_weight_map[weight_name] = file_name - new_weight_map[new_scale_name] = file_name - else: - if weight_name.endswith("_scale_inv"): - continue - new_state_dict[weight_name] = weight - new_weight_map[weight_name] = file_name + print(f"[Worker {worker_id}][GPU {rank}] processing {file_name}") + with safe_open(safetensor_file, framework="pt", device=f"cuda:{rank}") as f: + new_state_dict = {} + keys = set(f.keys()) + for weight_name in keys: + weight = f.get_tensor(weight_name) + scale_inv_name = f"{weight_name}_scale_inv" + if scale_inv_name in weight_map: + quant_count += 1 + # 1. fp8 dequant to bf16 + scale_inv = get_tensor_from_file( + rank, scale_inv_name, weight_map, fp8_path + ) + weight_bf16 = weight_dequant(weight, scale_inv) + # 2. bf16 quant to int8 + int8_weight, scale_inv = weight_quant(weight_bf16) + new_state_dict[weight_name] = int8_weight + new_scale_name = scale_inv_name.replace("_scale_inv", "_scale") + new_state_dict[new_scale_name] = scale_inv + new_weight_map[weight_name] = file_name + new_weight_map[new_scale_name] = file_name + else: + if weight_name.endswith("_scale_inv"): + continue + new_state_dict[weight_name] = weight + new_weight_map[weight_name] = file_name new_safetensor_file = os.path.join(int8_path, file_name) save_file(new_state_dict, new_safetensor_file) - return_dict[rank] = (quant_count, new_weight_map) + return_dict[worker_id] = (quant_count, new_weight_map) # Helper function to get tensor from the correct file -def get_tensor(rank, tensor_name, weight_map, fp8_path): +def get_tensor_from_file(rank, tensor_name, weight_map, fp8_path): """ - Retrieves a tensor from the cached safetensor files or loads it from disk if not cached. + Retrieves a tensor from mmap safe_tensors Args: tensor_name (str): The name of the tensor to retrieve. @@ -135,10 +152,10 @@ def get_tensor(rank, tensor_name, weight_map, fp8_path): """ torch.cuda.set_device(rank) file_name = weight_map[tensor_name] - loaded_files = {} file_path = os.path.join(fp8_path, file_name) - loaded_files[file_name] = load_file(file_path, device=f"cuda:{rank}") - return loaded_files[file_name][tensor_name] + + with safe_open(file_path, framework="pt", device=f"cuda:{rank}") as f: + return f.get_tensor(tensor_name) def weight_quant(tensor: torch.Tensor): @@ -161,12 +178,12 @@ def weight_quant(tensor: torch.Tensor): return quantized.to(torch.int8), scale.to(torch.float32) -def main(fp8_path, int8_path, model_name="deepseek-ai/DeepSeek-R1"): +def main(fp8_path, int8_path, num_workers): """ Run the FP8-to-INT8 per-channel quantization pipeline. This function: - 1. Downloads model index/config if missing. + 1. Copy the config file 2. Loads FP8 safetensors. 3. Dequantizes FP8 → BF16, then quantizes BF16 → INT8. 4. Saves quantized safetensors and updates model index. @@ -174,58 +191,98 @@ def main(fp8_path, int8_path, model_name="deepseek-ai/DeepSeek-R1"): Args: fp8_path (str): Path to directory containing FP8 safetensors. int8_path (str): Output directory to save INT8 safetensors. - model_name (str, optional): HuggingFace model repo name. Defaults to DeepSeek-R1. + num_workers (int): Number of processing workers """ torch.set_default_dtype(torch.bfloat16) os.makedirs(int8_path, exist_ok=True) model_index_file = os.path.join(int8_path, "model.safetensors.index.json") config_file = os.path.join(int8_path, "config.json") - if not os.path.exists(model_index_file) or not os.path.exists(config_file): - snapshot_download( - repo_id=model_name, - ignore_patterns=["*.safetensors"], - local_dir=int8_path, - local_dir_use_symlinks=False - ) - print(f"model index file and config file downloaded to {int8_path}") + for fname in os.listdir(fp8_path): + if fname.endswith(".safetensors"): + continue + src = os.path.join(fp8_path, fname) + dst = os.path.join(int8_path, fname) + if os.path.isdir(src): + print(f"cp -r {src} {dst}") + shutil.copytree(src, dst, dirs_exist_ok=True) + elif os.path.isfile(src): + print(f"cp {src} {dst}") + shutil.copy2(src, dst) # modify config.json and save it config = json.load(open(config_file)) # delete quantization_config config.pop("quantization_config", None) + config["quantization_config"] = { + "config_groups": { + "group_0": { + "input_activations": { + "actorder": None, + "block_structure": None, + "dynamic": True, + "group_size": None, + "num_bits": 8, + "observer": "memoryless", + "observer_kwargs": {}, + "strategy": "token", + "symmetric": True, + "type": "int", + }, + "output_activations": None, + "weights": { + "actorder": None, + "block_structure": None, + "dynamic": False, + "group_size": None, + "num_bits": 8, + "observer": "minmax", + "observer_kwargs": {}, + "strategy": "channel", + "symmetric": True, + "type": "int", + }, + "targets": ["Linear"], + } + }, + "format": "int-quantized", + "ignore": ["lm_head"], + "kv_cache_scheme": None, + "quant_method": "compressed-tensors", + "quantization_status": "compressed", + } + with open(config_file, "w", encoding="utf-8") as f: - json.dump(config, f, indent=2, ensure_ascii=False, sort_keys=True) + json.dump(config, f, indent=2, ensure_ascii=False, sort_keys=True) print(f"config.json modified and saved to {config_file}") with open(model_index_file, "r") as f: model_index = json.load(f) weight_map = model_index["weight_map"] scale_count = len([key for key in weight_map.keys() if key.endswith("_scale_inv")]) - + safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors"))) safetensor_files.sort() quant_count = 0 new_weight_map = {} - num_gpus = torch.cuda.device_count() - files_per_gpu = [[] for _ in range(num_gpus)] - for idx, sf in enumerate(safetensor_files): - files_per_gpu[idx % num_gpus].append(sf) + file_subsets = [safetensor_files[i::num_workers] for i in range(num_workers)] manager = mp.Manager() return_dict = manager.dict() processes = [] - for rank in range(num_gpus): - p = mp.Process(target=process_worker, - args=(rank, files_per_gpu[rank], fp8_path, int8_path, weight_map, return_dict)) + for i in range(num_workers): + p = mp.Process( + target=process_worker, + args=(i, file_subsets[i], fp8_path, int8_path, weight_map, return_dict), + ) p.start() processes.append(p) for p in processes: p.join() - for rank in range(num_gpus): - qc, wm = return_dict[rank] + for i in range(num_workers): + qc, wm = return_dict[i] quant_count += qc new_weight_map.update(wm) assert quant_count == scale_count @@ -242,10 +299,10 @@ def main(fp8_path, int8_path, model_name="deepseek-ai/DeepSeek-R1"): if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--input-fp8-hf-path", type=str, required=True) - parser.add_argument("--output-int8-hf-path", type=str, required=True) - parser.add_argument("--model-name", type=str, default="deepseek-ai/DeepSeek-R1") + parser.add_argument("g", type=str, required=True) + parser.add_argument("--output-int8-path", type=str, required=True) + parser.add_argument("--num-workers", type=int, default=32) args = parser.parse_args() - main(args.input_fp8_hf_path, args.output_int8_hf_path, args.model_name) + main(args.input_fp8_path, args.output_int8_path, args.num_workers) print("done") From f65017192276c7f1c7572522e7757f9a9bf23d05 Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Tue, 13 Jan 2026 20:19:36 +0800 Subject: [PATCH 3/5] fix code style check --- tools/fp8_cast_channel_int8.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tools/fp8_cast_channel_int8.py b/tools/fp8_cast_channel_int8.py index 51dab98a..b243ca8c 100644 --- a/tools/fp8_cast_channel_int8.py +++ b/tools/fp8_cast_channel_int8.py @@ -4,7 +4,7 @@ # Copyright (c) 2023 DeepSeek # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py -# https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8/blob/main/inference/bf16_cast_channel_int8.py (Meituan fork) +# https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8/blob/main/inference/bf16_cast_channel_int8.py (Meituan fork) # noqa: E501 # # Additional contributions: # Copyright (c) 2026 Kunlunxin (Beijing) Technology Co., Ltd. (Kunlunxin) @@ -66,19 +66,21 @@ def weight_dequant( Args: x (torch.Tensor): The quantized weight tensor of shape (M, N). s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size). - block_size (int, optional): The block size to use for dequantization. Defaults to 128. + block_size (int, optional): The block size to use for dequantization. + Defaults to 128. Returns: torch.Tensor: The dequantized weight tensor of the same shape as `x`. Raises: - AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. + AssertionError: If `x` or `s` are not contiguous or if their dimensions + are not 2. """ assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" M, N = x.size() y = torch.empty_like(x, dtype=torch.get_default_dtype()) - grid = lambda meta: ( + grid = lambda meta: ( # noqa: E731 triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]), ) From 46f5044e4009d7dee130ab41798b06a3a7699e4b Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Tue, 13 Jan 2026 20:30:59 +0800 Subject: [PATCH 4/5] use weight_dequant in angelslim.compressor.quant.core.quant_func --- tools/fp8_cast_channel_int8.py | 65 ++-------------------------------- 1 file changed, 2 insertions(+), 63 deletions(-) diff --git a/tools/fp8_cast_channel_int8.py b/tools/fp8_cast_channel_int8.py index b243ca8c..45543684 100644 --- a/tools/fp8_cast_channel_int8.py +++ b/tools/fp8_cast_channel_int8.py @@ -3,7 +3,6 @@ # Original code: # Copyright (c) 2023 DeepSeek # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py -# https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py # https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8/blob/main/inference/bf16_cast_channel_int8.py (Meituan fork) # noqa: E501 # # Additional contributions: @@ -23,69 +22,9 @@ import torch import torch.multiprocessing as mp -import triton -import triton.language as tl from safetensors.torch import safe_open, save_file - -@triton.jit -def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): - """ - Dequantizes weights using the provided scaling factors and stores the result. - - Args: - x_ptr (tl.pointer): Pointer to the quantized weights. - s_ptr (tl.pointer): Pointer to the scaling factors. - y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. - M (int): Number of rows in the weight matrix. - N (int): Number of columns in the weight matrix. - BLOCK_SIZE (tl.constexpr): Size of the block for tiling. - - Returns: - None - """ - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - n = tl.cdiv(N, BLOCK_SIZE) - offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs = offs_m[:, None] * N + offs_n[None, :] - mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) - x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) - s = tl.load(s_ptr + pid_m * n + pid_n) - y = x * s - tl.store(y_ptr + offs, y, mask=mask) - - -def weight_dequant( - x: torch.Tensor, s: torch.Tensor, block_size: int = 128 -) -> torch.Tensor: - """ - Dequantizes the given weight tensor using the provided scale tensor. - - Args: - x (torch.Tensor): The quantized weight tensor of shape (M, N). - s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size). - block_size (int, optional): The block size to use for dequantization. - Defaults to 128. - - Returns: - torch.Tensor: The dequantized weight tensor of the same shape as `x`. - - Raises: - AssertionError: If `x` or `s` are not contiguous or if their dimensions - are not 2. - """ - assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" - assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" - M, N = x.size() - y = torch.empty_like(x, dtype=torch.get_default_dtype()) - grid = lambda meta: ( # noqa: E731 - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(N, meta["BLOCK_SIZE"]), - ) - weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) - return y +from angelslim.compressor.quant.core.quant_func import weight_dequant def process_worker( @@ -301,7 +240,7 @@ def main(fp8_path, int8_path, num_workers): if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("g", type=str, required=True) + parser.add_argument("--input-fp8-path", type=str, required=True) parser.add_argument("--output-int8-path", type=str, required=True) parser.add_argument("--num-workers", type=int, default=32) From 21adc219c2ddbfcd8430c8e7c884844f9f0df66a Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Tue, 13 Jan 2026 20:32:57 +0800 Subject: [PATCH 5/5] use weight_dequant in angelslim.compressor.quant.core.quant_func --- tools/fp8_cast_channel_int8.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/fp8_cast_channel_int8.py b/tools/fp8_cast_channel_int8.py index 45543684..b2ead32c 100644 --- a/tools/fp8_cast_channel_int8.py +++ b/tools/fp8_cast_channel_int8.py @@ -209,6 +209,7 @@ def main(fp8_path, int8_path, num_workers): file_subsets = [safetensor_files[i::num_workers] for i in range(num_workers)] + mp.set_start_method("spawn", force=True) manager = mp.Manager() return_dict = manager.dict() processes = []