From 8e28ccfd8104463686f66aa60d500baf47dc301f Mon Sep 17 00:00:00 2001 From: linchuanxie Date: Wed, 21 Jan 2026 21:14:06 +0800 Subject: [PATCH 1/5] support ptpc-int8 --- tools/bf16_cast_channel_int8.py | 250 ++++++++++++++++++++++++++++++++ 1 file changed, 250 insertions(+) create mode 100644 tools/bf16_cast_channel_int8.py diff --git a/tools/bf16_cast_channel_int8.py b/tools/bf16_cast_channel_int8.py new file mode 100644 index 00000000..263fc3bd --- /dev/null +++ b/tools/bf16_cast_channel_int8.py @@ -0,0 +1,250 @@ +# This file is based on DeepSeek code (MIT License). +# +# Original code: +# Copyright (c) 2023 DeepSeek +# https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py +# https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8/blob/main/inference/bf16_cast_channel_int8.py (Meituan fork) # noqa: E501 +# +# Additional contributions: +# Copyright (c) 2026 Kunlunxin (Beijing) Technology Co., Ltd. (Kunlunxin) +# +# Modifications: +# - Merged implementations +# - Added multi-GPU parallel processing +# +# SPDX-License-Identifier: Apache-2.0 AND MIT + +import json +import os +import shutil +from argparse import ArgumentParser +from glob import glob + +import torch +import torch.multiprocessing as mp +from safetensors.torch import safe_open, save_file + +SUFFIX_TO_QUANT = [ + ".gate_proj.weight", + ".down_proj.weight", + ".up_proj.weight", + ".q_a_proj.weight", + ".q_b_proj.weight", + ".kv_b_proj.weight", + ".kv_a_proj_with_mqa.weight", + ".o_proj.weight", + ".indexer.wq_b.weight", + ".indexer.wk.weight", +] + +def process_worker( + worker_id, safetensor_files, bf16_path, int8_path, weight_map, return_dict +): + """ + Process worker. + + Each worker process is responsible for a subset of safetensor files: + - FP8 → BF16 dequantization + - BF16 → INT8 quantization + - Generation of the updated weight_map + """ + num_gpus = torch.cuda.device_count() + rank = worker_id % num_gpus + torch.cuda.set_device(rank) + quant_count = 0 + new_weight_map = {} + for safetensor_file in safetensor_files: + file_name = os.path.basename(safetensor_file) + print(f"[Worker {worker_id}][GPU {rank}] processing {file_name}") + with safe_open(safetensor_file, framework="pt", device=f"cuda:{rank}") as f: + new_state_dict = {} + keys = set(f.keys()) + for weight_name in keys: + weight = f.get_tensor(weight_name) + if any(weight_name.endswith(suffix) for suffix in SUFFIX_TO_QUANT): + quant_count += 1 + + int8_weight, scale_inv = weight_quant(weight) + new_state_dict[weight_name] = int8_weight + new_scale_name = f"{weight_name}_scale" + new_state_dict[new_scale_name] = scale_inv + new_weight_map[weight_name] = file_name + new_weight_map[new_scale_name] = file_name + else: + new_state_dict[weight_name] = weight + new_weight_map[weight_name] = file_name + + new_safetensor_file = os.path.join(int8_path, file_name) + save_file(new_state_dict, new_safetensor_file) + return_dict[worker_id] = (quant_count, new_weight_map) + + +# Helper function to get tensor from the correct file +def get_tensor_from_file(rank, tensor_name, weight_map, bf16_path): + """ + Retrieves a tensor from mmap safe_tensors + + Args: + tensor_name (str): The name of the tensor to retrieve. + + Returns: + torch.Tensor: The retrieved tensor. + + Raises: + KeyError: If the tensor does not exist in the safetensor file. + """ + torch.cuda.set_device(rank) + file_name = weight_map[tensor_name] + file_path = os.path.join(bf16_path, file_name) + + with safe_open(file_path, framework="pt", device=f"cuda:{rank}") as f: + return f.get_tensor(tensor_name) + + +def weight_quant(tensor: torch.Tensor): + """ + Quantize a 2D tensor row-wise from BF16/FP32 to INT8. + Args: + tensor (torch.Tensor): Input 2D tensor. + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - Quantized INT8 tensor. + - Scale tensor (float32) used for quantization. + """ + assert tensor.dim() == 2 + qmax = 127.0 + abs_max = torch.abs(tensor).max(dim=1, keepdim=True)[0] # [rows, 1] + scale = abs_max / qmax # [rows, 1] + assert scale.shape == (tensor.shape[0], 1) + quantized = torch.round(tensor / scale) + quantized = torch.clamp(quantized, -qmax, qmax) + return quantized.to(torch.int8), scale.to(torch.float32) + + +def main(bf16_path, int8_path, num_workers): + """ + Run the FP8-to-INT8 per-channel quantization pipeline. + + This function: + 1. Copy the config file + 2. Loads FP8 safetensors. + 3. Dequantizes FP8 → BF16, then quantizes BF16 → INT8. + 4. Saves quantized safetensors and updates model index. + + Args: + bf16_path (str): Path to directory containing FP8 safetensors. + int8_path (str): Output directory to save INT8 safetensors. + num_workers (int): Number of processing workers + """ + torch.set_default_dtype(torch.bfloat16) + os.makedirs(int8_path, exist_ok=True) + model_index_file = os.path.join(int8_path, "model.safetensors.index.json") + config_file = os.path.join(int8_path, "config.json") + + for fname in os.listdir(bf16_path): + if fname.endswith(".safetensors"): + continue + src = os.path.join(bf16_path, fname) + dst = os.path.join(int8_path, fname) + if os.path.isdir(src): + print(f"cp -r {src} {dst}") + shutil.copytree(src, dst, dirs_exist_ok=True) + elif os.path.isfile(src): + print(f"cp {src} {dst}") + shutil.copy2(src, dst) + + # modify config.json and save it + config = json.load(open(config_file)) + # delete quantization_config + config.pop("quantization_config", None) + config["quantization_config"] = { + "config_groups": { + "group_0": { + "input_activations": { + "actorder": None, + "block_structure": None, + "dynamic": True, + "group_size": None, + "num_bits": 8, + "observer": "memoryless", + "observer_kwargs": {}, + "strategy": "token", + "symmetric": True, + "type": "int", + }, + "output_activations": None, + "weights": { + "actorder": None, + "block_structure": None, + "dynamic": False, + "group_size": None, + "num_bits": 8, + "observer": "minmax", + "observer_kwargs": {}, + "strategy": "channel", + "symmetric": True, + "type": "int", + }, + "targets": ["Linear"], + } + }, + "format": "int-quantized", + "ignore": ["lm_head"], + "kv_cache_scheme": None, + "quant_method": "compressed-tensors", + "quantization_status": "compressed", + } + + with open(config_file, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, ensure_ascii=False, sort_keys=True) + print(f"config.json modified and saved to {config_file}") + + with open(model_index_file, "r") as f: + model_index = json.load(f) + weight_map = model_index["weight_map"] + + safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors"))) + safetensor_files.sort() + quant_count = 0 + new_weight_map = {} + + file_subsets = [safetensor_files[i::num_workers] for i in range(num_workers)] + + mp.set_start_method("spawn", force=True) + manager = mp.Manager() + return_dict = manager.dict() + processes = [] + for i in range(num_workers): + p = mp.Process( + target=process_worker, + args=(i, file_subsets[i], bf16_path, int8_path, weight_map, return_dict), + ) + p.start() + processes.append(p) + for p in processes: + p.join() + + for i in range(num_workers): + qc, wm = return_dict[i] + quant_count += qc + new_weight_map.update(wm) + print(f"{quant_count} weights are quantized.") + + # modify model.safetensors.index.json + with open(model_index_file, "r") as f: + model_index = json.load(f) + model_index["weight_map"] = new_weight_map + with open(model_index_file, "w", encoding="utf-8") as f: + json.dump(model_index, f, indent=2, ensure_ascii=False, sort_keys=True) + print(f"model.safetensors.index.json modified and saved to {model_index_file}") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--input-bf16-path", type=str, required=True) + parser.add_argument("--output-int8-path", type=str, required=True) + parser.add_argument("--num-workers", type=int, default=32) + + args = parser.parse_args() + main(args.input_bf16_path, args.output_int8_path, args.num_workers) + print("done") From 52973e515f72bbe2e979356d641c41b32fc6c41e Mon Sep 17 00:00:00 2001 From: linchuanxie Date: Wed, 21 Jan 2026 21:35:50 +0800 Subject: [PATCH 2/5] fix code style --- tools/bf16_cast_channel_int8.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/bf16_cast_channel_int8.py b/tools/bf16_cast_channel_int8.py index 263fc3bd..fbcdbd25 100644 --- a/tools/bf16_cast_channel_int8.py +++ b/tools/bf16_cast_channel_int8.py @@ -37,6 +37,7 @@ ".indexer.wk.weight", ] + def process_worker( worker_id, safetensor_files, bf16_path, int8_path, weight_map, return_dict ): @@ -61,9 +62,9 @@ def process_worker( keys = set(f.keys()) for weight_name in keys: weight = f.get_tensor(weight_name) - if any(weight_name.endswith(suffix) for suffix in SUFFIX_TO_QUANT): + if any(weight_name.endswith(suffix) for suffix in SUFFIX_TO_QUANT): quant_count += 1 - + int8_weight, scale_inv = weight_quant(weight) new_state_dict[weight_name] = int8_weight new_scale_name = f"{weight_name}_scale" From 8fa2a2b4f86ea1de1a9e0a6eac760580ccb9ed64 Mon Sep 17 00:00:00 2001 From: linchuanxie Date: Thu, 22 Jan 2026 20:05:44 +0800 Subject: [PATCH 3/5] merge fp8 and bf16 codes --- ...t8.py => bf16_or_fp8_cast_channel_int8.py} | 66 +++-- tools/fp8_cast_channel_int8.py | 250 ------------------ 2 files changed, 50 insertions(+), 266 deletions(-) rename tools/{bf16_cast_channel_int8.py => bf16_or_fp8_cast_channel_int8.py} (81%) delete mode 100644 tools/fp8_cast_channel_int8.py diff --git a/tools/bf16_cast_channel_int8.py b/tools/bf16_or_fp8_cast_channel_int8.py similarity index 81% rename from tools/bf16_cast_channel_int8.py rename to tools/bf16_or_fp8_cast_channel_int8.py index fbcdbd25..a1ecc3ab 100644 --- a/tools/bf16_cast_channel_int8.py +++ b/tools/bf16_or_fp8_cast_channel_int8.py @@ -24,14 +24,21 @@ import torch.multiprocessing as mp from safetensors.torch import safe_open, save_file +from angelslim.compressor.quant.core.quant_func import weight_dequant + SUFFIX_TO_QUANT = [ + ".gate_and_up_proj.weight", ".gate_proj.weight", - ".down_proj.weight", ".up_proj.weight", + ".down_proj.weight", ".q_a_proj.weight", ".q_b_proj.weight", - ".kv_b_proj.weight", ".kv_a_proj_with_mqa.weight", + ".kv_b_proj.weight", + ".qkv_proj.weight", + ".q_proj.weight", + ".k_proj.weight", + ".v_proj.weight", ".o_proj.weight", ".indexer.wq_b.weight", ".indexer.wk.weight", @@ -39,7 +46,13 @@ def process_worker( - worker_id, safetensor_files, bf16_path, int8_path, weight_map, return_dict + worker_id, + safetensor_files, + input_path, + int8_path, + weight_map, + return_dict, + input_type="bf16", ): """ Process worker. @@ -64,14 +77,23 @@ def process_worker( weight = f.get_tensor(weight_name) if any(weight_name.endswith(suffix) for suffix in SUFFIX_TO_QUANT): quant_count += 1 - - int8_weight, scale_inv = weight_quant(weight) + if input_type == "fp8": + scale_inv_name = f"{weight_name}_scale_inv" + scale_inv = get_tensor_from_file( + rank, scale_inv_name, weight_map, input_path + ) + weight_bf16 = weight_dequant(weight, scale_inv) + else: + weight_bf16 = weight + int8_weight, scale_inv = weight_quant(weight_bf16) new_state_dict[weight_name] = int8_weight new_scale_name = f"{weight_name}_scale" new_state_dict[new_scale_name] = scale_inv new_weight_map[weight_name] = file_name new_weight_map[new_scale_name] = file_name else: + if weight_name.endswith("_scale_inv"): + continue new_state_dict[weight_name] = weight new_weight_map[weight_name] = file_name @@ -81,7 +103,7 @@ def process_worker( # Helper function to get tensor from the correct file -def get_tensor_from_file(rank, tensor_name, weight_map, bf16_path): +def get_tensor_from_file(rank, tensor_name, weight_map, input_path): """ Retrieves a tensor from mmap safe_tensors @@ -96,7 +118,7 @@ def get_tensor_from_file(rank, tensor_name, weight_map, bf16_path): """ torch.cuda.set_device(rank) file_name = weight_map[tensor_name] - file_path = os.path.join(bf16_path, file_name) + file_path = os.path.join(input_path, file_name) with safe_open(file_path, framework="pt", device=f"cuda:{rank}") as f: return f.get_tensor(tensor_name) @@ -122,7 +144,7 @@ def weight_quant(tensor: torch.Tensor): return quantized.to(torch.int8), scale.to(torch.float32) -def main(bf16_path, int8_path, num_workers): +def main(input_path, int8_path, num_workers): """ Run the FP8-to-INT8 per-channel quantization pipeline. @@ -133,7 +155,7 @@ def main(bf16_path, int8_path, num_workers): 4. Saves quantized safetensors and updates model index. Args: - bf16_path (str): Path to directory containing FP8 safetensors. + input_path (str): Path to directory containing FP8 safetensors. int8_path (str): Output directory to save INT8 safetensors. num_workers (int): Number of processing workers """ @@ -142,10 +164,10 @@ def main(bf16_path, int8_path, num_workers): model_index_file = os.path.join(int8_path, "model.safetensors.index.json") config_file = os.path.join(int8_path, "config.json") - for fname in os.listdir(bf16_path): + for fname in os.listdir(input_path): if fname.endswith(".safetensors"): continue - src = os.path.join(bf16_path, fname) + src = os.path.join(input_path, fname) dst = os.path.join(int8_path, fname) if os.path.isdir(src): print(f"cp -r {src} {dst}") @@ -157,7 +179,11 @@ def main(bf16_path, int8_path, num_workers): # modify config.json and save it config = json.load(open(config_file)) # delete quantization_config - config.pop("quantization_config", None) + quant_config = config.pop("quantization_config", None) + input_type = "bf16" + if quant_config is not None: + input_type = quant_config.get("quant_method", input_type) + print("input_type", input_type) config["quantization_config"] = { "config_groups": { "group_0": { @@ -204,7 +230,7 @@ def main(bf16_path, int8_path, num_workers): model_index = json.load(f) weight_map = model_index["weight_map"] - safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors"))) + safetensor_files = list(glob(os.path.join(input_path, "*.safetensors"))) safetensor_files.sort() quant_count = 0 new_weight_map = {} @@ -218,7 +244,15 @@ def main(bf16_path, int8_path, num_workers): for i in range(num_workers): p = mp.Process( target=process_worker, - args=(i, file_subsets[i], bf16_path, int8_path, weight_map, return_dict), + args=( + i, + file_subsets[i], + input_path, + int8_path, + weight_map, + return_dict, + input_type, + ), ) p.start() processes.append(p) @@ -242,10 +276,10 @@ def main(bf16_path, int8_path, num_workers): if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--input-bf16-path", type=str, required=True) + parser.add_argument("--input-path", type=str, required=True) parser.add_argument("--output-int8-path", type=str, required=True) parser.add_argument("--num-workers", type=int, default=32) args = parser.parse_args() - main(args.input_bf16_path, args.output_int8_path, args.num_workers) + main(args.input_path, args.output_int8_path, args.num_workers) print("done") diff --git a/tools/fp8_cast_channel_int8.py b/tools/fp8_cast_channel_int8.py deleted file mode 100644 index b2ead32c..00000000 --- a/tools/fp8_cast_channel_int8.py +++ /dev/null @@ -1,250 +0,0 @@ -# This file is based on DeepSeek code (MIT License). -# -# Original code: -# Copyright (c) 2023 DeepSeek -# https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py -# https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8/blob/main/inference/bf16_cast_channel_int8.py (Meituan fork) # noqa: E501 -# -# Additional contributions: -# Copyright (c) 2026 Kunlunxin (Beijing) Technology Co., Ltd. (Kunlunxin) -# -# Modifications: -# - Merged implementations -# - Added multi-GPU parallel processing -# -# SPDX-License-Identifier: Apache-2.0 AND MIT - -import json -import os -import shutil -from argparse import ArgumentParser -from glob import glob - -import torch -import torch.multiprocessing as mp -from safetensors.torch import safe_open, save_file - -from angelslim.compressor.quant.core.quant_func import weight_dequant - - -def process_worker( - worker_id, safetensor_files, fp8_path, int8_path, weight_map, return_dict -): - """ - Process worker. - - Each worker process is responsible for a subset of safetensor files: - - FP8 → BF16 dequantization - - BF16 → INT8 quantization - - Generation of the updated weight_map - """ - num_gpus = torch.cuda.device_count() - rank = worker_id % num_gpus - torch.cuda.set_device(rank) - quant_count = 0 - new_weight_map = {} - for safetensor_file in safetensor_files: - file_name = os.path.basename(safetensor_file) - print(f"[Worker {worker_id}][GPU {rank}] processing {file_name}") - with safe_open(safetensor_file, framework="pt", device=f"cuda:{rank}") as f: - new_state_dict = {} - keys = set(f.keys()) - for weight_name in keys: - weight = f.get_tensor(weight_name) - scale_inv_name = f"{weight_name}_scale_inv" - if scale_inv_name in weight_map: - quant_count += 1 - # 1. fp8 dequant to bf16 - scale_inv = get_tensor_from_file( - rank, scale_inv_name, weight_map, fp8_path - ) - weight_bf16 = weight_dequant(weight, scale_inv) - # 2. bf16 quant to int8 - int8_weight, scale_inv = weight_quant(weight_bf16) - new_state_dict[weight_name] = int8_weight - new_scale_name = scale_inv_name.replace("_scale_inv", "_scale") - new_state_dict[new_scale_name] = scale_inv - new_weight_map[weight_name] = file_name - new_weight_map[new_scale_name] = file_name - else: - if weight_name.endswith("_scale_inv"): - continue - new_state_dict[weight_name] = weight - new_weight_map[weight_name] = file_name - - new_safetensor_file = os.path.join(int8_path, file_name) - save_file(new_state_dict, new_safetensor_file) - return_dict[worker_id] = (quant_count, new_weight_map) - - -# Helper function to get tensor from the correct file -def get_tensor_from_file(rank, tensor_name, weight_map, fp8_path): - """ - Retrieves a tensor from mmap safe_tensors - - Args: - tensor_name (str): The name of the tensor to retrieve. - - Returns: - torch.Tensor: The retrieved tensor. - - Raises: - KeyError: If the tensor does not exist in the safetensor file. - """ - torch.cuda.set_device(rank) - file_name = weight_map[tensor_name] - file_path = os.path.join(fp8_path, file_name) - - with safe_open(file_path, framework="pt", device=f"cuda:{rank}") as f: - return f.get_tensor(tensor_name) - - -def weight_quant(tensor: torch.Tensor): - """ - Quantize a 2D tensor row-wise from BF16/FP32 to INT8. - Args: - tensor (torch.Tensor): Input 2D tensor. - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - Quantized INT8 tensor. - - Scale tensor (float32) used for quantization. - """ - assert tensor.dim() == 2 - qmax = 127.0 - abs_max = torch.abs(tensor).max(dim=1, keepdim=True)[0] # [rows, 1] - scale = abs_max / qmax # [rows, 1] - assert scale.shape == (tensor.shape[0], 1) - quantized = torch.round(tensor / scale) - quantized = torch.clamp(quantized, -qmax, qmax) - return quantized.to(torch.int8), scale.to(torch.float32) - - -def main(fp8_path, int8_path, num_workers): - """ - Run the FP8-to-INT8 per-channel quantization pipeline. - - This function: - 1. Copy the config file - 2. Loads FP8 safetensors. - 3. Dequantizes FP8 → BF16, then quantizes BF16 → INT8. - 4. Saves quantized safetensors and updates model index. - - Args: - fp8_path (str): Path to directory containing FP8 safetensors. - int8_path (str): Output directory to save INT8 safetensors. - num_workers (int): Number of processing workers - """ - torch.set_default_dtype(torch.bfloat16) - os.makedirs(int8_path, exist_ok=True) - model_index_file = os.path.join(int8_path, "model.safetensors.index.json") - config_file = os.path.join(int8_path, "config.json") - - for fname in os.listdir(fp8_path): - if fname.endswith(".safetensors"): - continue - src = os.path.join(fp8_path, fname) - dst = os.path.join(int8_path, fname) - if os.path.isdir(src): - print(f"cp -r {src} {dst}") - shutil.copytree(src, dst, dirs_exist_ok=True) - elif os.path.isfile(src): - print(f"cp {src} {dst}") - shutil.copy2(src, dst) - - # modify config.json and save it - config = json.load(open(config_file)) - # delete quantization_config - config.pop("quantization_config", None) - config["quantization_config"] = { - "config_groups": { - "group_0": { - "input_activations": { - "actorder": None, - "block_structure": None, - "dynamic": True, - "group_size": None, - "num_bits": 8, - "observer": "memoryless", - "observer_kwargs": {}, - "strategy": "token", - "symmetric": True, - "type": "int", - }, - "output_activations": None, - "weights": { - "actorder": None, - "block_structure": None, - "dynamic": False, - "group_size": None, - "num_bits": 8, - "observer": "minmax", - "observer_kwargs": {}, - "strategy": "channel", - "symmetric": True, - "type": "int", - }, - "targets": ["Linear"], - } - }, - "format": "int-quantized", - "ignore": ["lm_head"], - "kv_cache_scheme": None, - "quant_method": "compressed-tensors", - "quantization_status": "compressed", - } - - with open(config_file, "w", encoding="utf-8") as f: - json.dump(config, f, indent=2, ensure_ascii=False, sort_keys=True) - print(f"config.json modified and saved to {config_file}") - - with open(model_index_file, "r") as f: - model_index = json.load(f) - weight_map = model_index["weight_map"] - scale_count = len([key for key in weight_map.keys() if key.endswith("_scale_inv")]) - - safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors"))) - safetensor_files.sort() - quant_count = 0 - new_weight_map = {} - - file_subsets = [safetensor_files[i::num_workers] for i in range(num_workers)] - - mp.set_start_method("spawn", force=True) - manager = mp.Manager() - return_dict = manager.dict() - processes = [] - for i in range(num_workers): - p = mp.Process( - target=process_worker, - args=(i, file_subsets[i], fp8_path, int8_path, weight_map, return_dict), - ) - p.start() - processes.append(p) - for p in processes: - p.join() - - for i in range(num_workers): - qc, wm = return_dict[i] - quant_count += qc - new_weight_map.update(wm) - assert quant_count == scale_count - print(f"{quant_count} weights are quantized.") - - # modify model.safetensors.index.json - with open(model_index_file, "r") as f: - model_index = json.load(f) - model_index["weight_map"] = new_weight_map - with open(model_index_file, "w", encoding="utf-8") as f: - json.dump(model_index, f, indent=2, ensure_ascii=False, sort_keys=True) - print(f"model.safetensors.index.json modified and saved to {model_index_file}") - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--input-fp8-path", type=str, required=True) - parser.add_argument("--output-int8-path", type=str, required=True) - parser.add_argument("--num-workers", type=int, default=32) - - args = parser.parse_args() - main(args.input_fp8_path, args.output_int8_path, args.num_workers) - print("done") From b3aa705d7c07210d666113e91578393c12080a09 Mon Sep 17 00:00:00 2001 From: linchuanxie Date: Fri, 23 Jan 2026 15:48:55 +0800 Subject: [PATCH 4/5] change filename --- tools/{bf16_or_fp8_cast_channel_int8.py => int8_channel_quant.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tools/{bf16_or_fp8_cast_channel_int8.py => int8_channel_quant.py} (100%) diff --git a/tools/bf16_or_fp8_cast_channel_int8.py b/tools/int8_channel_quant.py similarity index 100% rename from tools/bf16_or_fp8_cast_channel_int8.py rename to tools/int8_channel_quant.py From 30833ef62533a4e929af89deed8ab5870a3885f8 Mon Sep 17 00:00:00 2001 From: linchuanxie Date: Wed, 28 Jan 2026 20:43:24 +0800 Subject: [PATCH 5/5] update requirements --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 2d6f0d15..db47c703 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,6 @@ torch>=2.6.0 torchvision>=0.21.0 -transformers>=4.57.1 +transformers==4.57.6 safetensors>=0.5.3 numpy tqdm