From 6af1848a6c6569baafd03e056f9295b25d3b4dcf Mon Sep 17 00:00:00 2001 From: StromNoNo Date: Thu, 11 Sep 2025 16:25:37 +0800 Subject: [PATCH 1/3] add nvfp4 --- angelslim/compressor/quant/core/config.py | 26 ++ angelslim/compressor/quant/core/quant_func.py | 38 +++ angelslim/compressor/quant/core/save.py | 17 + .../compressor/quant/modules/__init__.py | 1 + .../compressor/quant/modules/helper_layer.py | 304 ++++++++++++++++++ .../quant/modules/nvfp4/__init__.py | 13 + .../compressor/quant/modules/nvfp4/nvfp4.py | 126 ++++++++ angelslim/compressor/quant/ptq.py | 19 +- angelslim/models/base_model.py | 1 + angelslim/utils/config_parser.py | 1 + configs/qwen3/nvfp4/qwen3-0_6b_nvfp4.yaml | 35 ++ configs/qwen3/nvfp4/qwen3-14b_nvfp4.yaml | 35 ++ configs/qwen3/nvfp4/qwen3-1_7b_nvfp4.yaml | 35 ++ configs/qwen3/nvfp4/qwen3-32b_nvfp4.yaml | 35 ++ configs/qwen3/nvfp4/qwen3-4b_nvfp4.yaml | 35 ++ configs/qwen3/nvfp4/qwen3-8b_nvfp4.yaml | 35 ++ 16 files changed, 753 insertions(+), 3 deletions(-) create mode 100644 angelslim/compressor/quant/modules/nvfp4/__init__.py create mode 100644 angelslim/compressor/quant/modules/nvfp4/nvfp4.py create mode 100644 configs/qwen3/nvfp4/qwen3-0_6b_nvfp4.yaml create mode 100644 configs/qwen3/nvfp4/qwen3-14b_nvfp4.yaml create mode 100644 configs/qwen3/nvfp4/qwen3-1_7b_nvfp4.yaml create mode 100644 configs/qwen3/nvfp4/qwen3-32b_nvfp4.yaml create mode 100644 configs/qwen3/nvfp4/qwen3-4b_nvfp4.yaml create mode 100644 configs/qwen3/nvfp4/qwen3-8b_nvfp4.yaml diff --git a/angelslim/compressor/quant/core/config.py b/angelslim/compressor/quant/core/config.py index 5ea979e0..a8b20be8 100644 --- a/angelslim/compressor/quant/core/config.py +++ b/angelslim/compressor/quant/core/config.py @@ -148,6 +148,32 @@ def __init__(self, config, global_config=None): "group_size": group_size, "ignore_layers": quantization_args.ignore_layers, } + elif "nvfp4" in self.quant_algo: + is_dynamic = "dynamic" if "dynamic" in self.quant_algo else "static" + assert ( + is_dynamic 
def reduce_block_padding(input: torch.Tensor, block_sizes: dict, pad_value: float = 0):
    """Right-pad ``input`` so every listed dimension is a multiple of its block.

    Args:
        input (torch.Tensor): The tensor to pad. The name shadows the builtin
            but is kept because callers may pass it by keyword.
        block_sizes (dict): Maps a dimension index (negative indices allowed)
            to the block size that dimension must be divisible by.
            Example: ``{-1: 128, -2: 128}`` pads the last two dimensions.
        pad_value (float): Value written into the padded positions.

    Returns:
        torch.Tensor: ``input`` itself when no dimension needs padding,
        otherwise a new tensor padded on the right (high-index) side.

    Raises:
        ValueError: If a block size is not strictly positive.
        IndexError: If a dimension index is out of range for ``input``.
    """
    with torch.no_grad():
        padded_tensor = input
        num_dims = padded_tensor.dim()
        # Dimensions are processed one after another, in dict order.
        for dim, block in block_sizes.items():
            if block <= 0:
                # A zero block would divide by zero; a negative one would make
                # the pad amount negative, which F.pad treats as cropping.
                raise ValueError(
                    f"Block size for dim {dim} must be positive, got {block}."
                )

            # Convert negative dimension to positive index
            pos_dim = dim if dim >= 0 else num_dims + dim
            if not 0 <= pos_dim < num_dims:
                raise IndexError(
                    f"Dimension {dim} is out of range for a {num_dims}-D tensor."
                )

            # Number of elements missing to reach the next multiple of block.
            pad_amt = (-padded_tensor.size(pos_dim)) % block
            if pad_amt == 0:
                continue

            # F.pad takes (left, right) pairs starting from the LAST dimension,
            # so skip one zero pair per dimension that follows pos_dim.
            pad = [0, 0] * (num_dims - 1 - pos_dim) + [0, pad_amt]
            padded_tensor = torch.nn.functional.pad(
                padded_tensor, pad, value=pad_value
            )

        return padded_tensor
+++ b/angelslim/compressor/quant/modules/__init__.py @@ -25,4 +25,5 @@ from .helper_layer import SmoothHelpModule # noqa: F401 from .helper_layer import WQLinearGEMM # noqa: F401 from .int8.int8 import INT8 # noqa: F401 +from .nvfp4.nvfp4 import NVFP4 # noqa: F401 from .smooth.smooth import SmoothQuant # noqa: F401 diff --git a/angelslim/compressor/quant/modules/helper_layer.py b/angelslim/compressor/quant/modules/helper_layer.py index 044e9714..cae1f00d 100644 --- a/angelslim/compressor/quant/modules/helper_layer.py +++ b/angelslim/compressor/quant/modules/helper_layer.py @@ -31,10 +31,17 @@ quantize_activation_per_tensor_fp8, quantize_weight_int, quantize_weight_per_tensor_fp8, + reduce_block_padding, tensor_quant_dequant_fp8, tensor_quant_dequant_int, ) +# Define conversion tables +e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5]) +e2m1_values = torch.tensor( + [0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6] +) + def flush(): gc.collect() @@ -718,3 +725,300 @@ def forward(self, x): raise ValueError(f"Unsupported quantization algorithm: {self.quant_algo}") return output + + +class NVFP4QDQModule(torch.nn.Module): + def __init__( + self, + weight: torch.nn.Parameter, + weight_scale: torch.nn.Parameter, + weight_scale_2: torch.nn.Parameter, + bias: torch.nn.Parameter, + block_size: int = 16, + input_scale: Optional[torch.nn.Parameter] = None, + ): + super().__init__() + self.e2m1_values_on_device = {} + self.shape = weight.shape + self.dtype = weight.dtype + self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + self.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) + self.bias = bias + self.block_size = block_size + if input_scale is not None: + self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False) + else: + self.input_scale = None + + quant_weight = self.to_quantized_weight( + weight, + weight_scale, + weight_scale_2, + block_size, + ) + self.weight = 
torch.nn.Parameter(quant_weight, requires_grad=False) + + def get_e2m1_values(self, device): + """Returns the e2m1 values on the device.""" + if device not in self.e2m1_values_on_device: + self.e2m1_values_on_device[device] = e2m1_values.to(device) + return self.e2m1_values_on_device[device] + + def _cast_fp4(self, weight: torch.Tensor): + """Converts tensor to uint4.""" + # Get device + device = weight.device + + # Define mask to perform rounding + mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device) + mask_shape = list(weight.shape) + mask = mask.expand([*mask_shape, 7]) + + sign_bit = (weight < 0).to(torch.uint8) + + weight_abs = weight.abs_() + # Calculate the ordinal value based on the bounds + ord = torch.searchsorted(e2m1_bounds.to(device), weight_abs, out_int32=True).to( + torch.uint8 + ) + # All values equal to e2m1_bounds at odd indices are rounded up + # and even indices are rounded down + round = torch.any( + (weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) * mask, dim=-1 + ) + fp4_val = (sign_bit * 0b1000 + ord + round).to(torch.uint8) + return fp4_val + + def quantize( + self, + weight: torch.Tensor, + block_size: int, + weights_scaling_factor: torch.Tensor | None = None, + weights_scaling_factor_2: torch.Tensor | None = None, + keep_high_precision: bool = False, + ): + """Converting a tensor to a quantized format based on NVFP4 quantization. + + Args: + weight (torch.Tensor): The weight tensor to be quantized. + block_size (int): The size of each block for quantization. + weights_scaling_factor (torch.Tensor): The scaling factor for the weights. + weights_scaling_factor_2 (torch.Tensor): The scaling factor for the weights. + keep_high_precision (bool): Whether to keep output scales at high precision. + + Returns: + Quantized data. 
+ """ + # pad the weight if needed + weight = reduce_block_padding(weight, block_sizes={-1: block_size}) + + # Reshape the weight and scale factors + weight = weight.view((*tuple(weight.shape[:-1]), -1, block_size)) + + # Scale weights + scaled_weight = weight / ( + ( + weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2 + ).unsqueeze(-1) + ) + + # Reshape weights to original + scaled_weight = scaled_weight.view((*tuple(scaled_weight.shape[:-2]), -1)) + + if keep_high_precision: + return scaled_weight + # Cast weights to fp4 + q_weight = self._cast_fp4(scaled_weight) + # Pack weights + packed_weight = (q_weight[..., 1::2] << 4) | q_weight[..., 0::2] + + return packed_weight + + def to_quantized_weight( + self, + weight: torch.Tensor, + weights_scaling_factor: torch.Tensor, + weights_scaling_factor2: torch.Tensor | None = None, + block_size: int | None = None, + ): + """Converts the weight to the quantized (packed) format.""" + if weights_scaling_factor is not None: + weights_scaling_factor = weights_scaling_factor.to(weight.device) + + if weights_scaling_factor2 is not None: + weights_scaling_factor2 = weights_scaling_factor2.to(weight.device) + + assert ( + block_size is not None + ), "Block size not passed. Unable to quantize to NVFP4 format." + assert ( + weights_scaling_factor2 is not None + ), "Weights scaling factor 2 not passed. 
Unable to quantize to NVFP4 format" + # If MoE reshape weights_scaling_factor2 to enable quantize operations + return self.quantize( + weight, + block_size, + weights_scaling_factor, + ( + weights_scaling_factor2.view(-1, 1, 1) + if weights_scaling_factor2.dim() != 0 + else weights_scaling_factor2 + ), + ) + + def get_input_scaling_factor( + self, + inputs: torch.Tensor, + block_size: int, + inputs_scaling_factor_2: torch.Tensor | None = None, + keep_high_precision: bool = False, + ): + """Returns quantized per block input scaling factor.""" + # Get per_block amax + [n, k] = inputs.shape[-2:] + assert ( + block_size != 0 + ), "Block size is zero. Cannot return per_block amax for given input." + + assert ( + k % block_size == 0 + ), "input shape is not divisible for block size for block quantiation." + + inputs = inputs.reshape( + (*tuple(inputs.shape[:-2]), n, k // block_size, block_size) + ) + # Get per block amax + per_block_amax = inputs.abs().amax(dim=-1).float() + # Get per-block-scale + per_block_scale = per_block_amax / 6.0 + # Quantize per_block_scale to FP8 + q_per_block_scale = per_block_scale / inputs_scaling_factor_2 + # Set all zero values in scale to 1.0 + q_per_block_scale[per_block_scale == 0] = 1.0 + # Convert to torch.float8_e4m3fn + if not keep_high_precision: + finfo = torch.finfo(torch.float8_e4m3fn) + q_per_block_scale = q_per_block_scale.clamp(min=finfo.min, max=finfo.max) + q_per_block_scale = q_per_block_scale.to(torch.float8_e4m3fn) + return q_per_block_scale + + def quantize_input( + self, + inputs: torch.Tensor, + block_size: int, + inputs_scaling_factor: torch.Tensor | None = None, + inputs_scaling_factor_2: torch.Tensor | None = None, + keep_high_precision: bool = False, + ): + """Converting a tensor to a quantized format based on NVFP4 quantization. + + Args: + weight (torch.Tensor): The weight tensor to be quantized. + block_size (int): The size of each block for quantization. 
+ weights_scaling_factor (torch.Tensor): The scaling factor for the weights. + weights_scaling_factor_2 (torch.Tensor): The scaling factor for the weights. + keep_high_precision (bool): Whether to keep output scales at high precision. + + Returns: + tuple: Contains quantized data, quantized per block scaling factor, + and per tensor scaling factor. + """ + # pad the weight if needed + inputs = reduce_block_padding(inputs, block_sizes={-1: block_size}) + + # Reshape the weight and scale factors + inputs = inputs.view((*tuple(inputs.shape[:-1]), -1, block_size)) + + # Scale weights + scaled_inputs = inputs / ( + ( + inputs_scaling_factor.to(torch.float32) * inputs_scaling_factor_2 + ).unsqueeze(-1) + ) + + # Reshape weights to original + scaled_inputs = scaled_inputs.view((*tuple(scaled_inputs.shape[:-2]), -1)) + + if keep_high_precision: + return scaled_inputs + # Cast weights to fp4 + cast_inputs = self._cast_fp4(scaled_inputs) + qinputs = self.get_e2m1_values(cast_inputs.device)[cast_inputs.long()] + + return qinputs + + def forward(self, x): + qdqweight = self.dequantize( + self.weight, self.block_size, self.weight_scale, self.weight_scale_2 + ) + + if self.input_scale is None: + input_amax = x.abs().amax() + input_scale_2 = input_amax.float() / 6.0 / 448.0 + else: + input_scale_2 = self.input_scale + + input_scale = self.get_input_scaling_factor( + inputs=x.detach(), + inputs_scaling_factor_2=input_scale_2, + block_size=self.block_size, + ) + + qinput = self.quantize_input( + x, + self.block_size, + input_scale, + input_scale_2.view(-1, 1, 1) if input_scale_2.dim() != 0 else input_scale_2, + ) + + qdqinput = qinput.view( + qinput.shape[0], qinput.shape[1], qinput.shape[2] // self.block_size, -1 + ) * (input_scale.to(torch.float32) * input_scale_2).unsqueeze(-1) + qdqinput = qdqinput.view(-1)[: np.prod(x.shape)].reshape(x.shape).to(x.dtype) + + output = torch.nn.functional.linear( + qdqinput.to(self.dtype), + qdqweight, + bias=self.bias, + ) + + return output + + 
def dequantize( + self, + weight: torch.Tensor, + block_size: int, + weights_scaling_factor: torch.Tensor | None = None, + weights_scaling_factor_2: torch.Tensor | None = None, + ): + """Dequantze NVFP4 packed tensor to a target dtype.""" + dtype = self.dtype + + def _unpack_tensor(input: torch.Tensor): + # Initalize storage for unpacked tensor + unpacked = torch.empty( + [input.shape[0], input.shape[1] * 2], dtype=dtype, device=input.device + ) + unpacked_shape = unpacked.shape + + unpacked[..., 1::2] = input >> 4 + unpacked[..., 0::2] = input & 0x0F + + unpacked = unpacked.reshape(-1) + unpacked = self.get_e2m1_values(input.device)[unpacked.long()] + + return unpacked.reshape(unpacked_shape) + + q_per_block_scale = weights_scaling_factor.to(torch.float32) + per_block_quant_scale = weights_scaling_factor_2 + + # Dequantize scales + per_block_scale = q_per_block_scale * per_block_quant_scale + + # Unpack and unscale weights + deq_data = _unpack_tensor(weight) + + deq_data = deq_data.view( + deq_data.shape[0], deq_data.shape[1] // block_size, -1 + ) * per_block_scale.unsqueeze(-1) + return deq_data.view(-1)[: np.prod(self.shape)].reshape(self.shape).to(dtype) diff --git a/angelslim/compressor/quant/modules/nvfp4/__init__.py b/angelslim/compressor/quant/modules/nvfp4/__init__.py new file mode 100644 index 00000000..eca77b00 --- /dev/null +++ b/angelslim/compressor/quant/modules/nvfp4/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/angelslim/compressor/quant/modules/nvfp4/nvfp4.py b/angelslim/compressor/quant/modules/nvfp4/nvfp4.py new file mode 100644 index 00000000..4d0b00e4 --- /dev/null +++ b/angelslim/compressor/quant/modules/nvfp4/nvfp4.py @@ -0,0 +1,126 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from .....utils import print_info +from ...modules.helper_layer import NVFP4QDQModule + +__all__ = ["NVFP4"] + + +class NVFP4: + def __init__( + self, + model, + ): + """ + Args: + model(nn.Module, required): The model to be smoothed. + """ + super(NVFP4, self).__init__() + self.model = model + self.block_size = self.model.quant_config.quant_algo_info["block_size"] + + @torch.no_grad() + def run(self, dataloader): + print_info("Use NVFP4 fast forward") + self.model.model_forward(dataloader) + + def get_activation_scaling_factor(self, input_observer_amax): + """Returns the activation scaling factor for export.""" + activation_scaling_factor = input_observer_amax.float() / 6.0 / 448.0 + + assert torch.all( + activation_scaling_factor > 0 + ), f" activation scaling factor {activation_scaling_factor} not positive." 
+ + return activation_scaling_factor + + def get_weights_scaling_factor_2(self, weight_observer_amax): + """Returns per tensor weight scaling factor.""" + return weight_observer_amax.float() / 6.0 / 448.0 + + def get_weights_scaling_factor( + self, + weight: torch.Tensor, + block_size: int, + weights_scaling_factor_2: torch.Tensor | None = None, + keep_high_precision: bool = False, + ): + """Returns quantized per block weight scaling factor.""" + # Get per_block amax + [n, k] = weight.shape[-2:] + assert ( + block_size != 0 + ), "Block size is zero. Cannot return per_block amax for given weight." + + assert ( + k % block_size == 0 + ), "Weight shape is not divisible for block size for block quantiation." + + weight = weight.reshape( + (*tuple(weight.shape[:-2]), n, k // block_size, block_size) + ) + # Get per block amax + per_block_amax = weight.abs().amax(dim=-1).float() + # Get per-block-scale + per_block_scale = per_block_amax / 6.0 + # Quantize per_block_scale to FP8 + q_per_block_scale = per_block_scale / weights_scaling_factor_2 + # Set all zero values in scale to 1.0 + q_per_block_scale[per_block_scale == 0] = 1.0 + # Convert to torch.float8_e4m3fn + if not keep_high_precision: + finfo = torch.finfo(torch.float8_e4m3fn) + q_per_block_scale = q_per_block_scale.clamp(min=finfo.min, max=finfo.max) + q_per_block_scale = q_per_block_scale.to(torch.float8_e4m3fn) + return q_per_block_scale + + def post_process(self, sub_layer, name): + weight_observer_amax = self.model.weight_scales_dict[name] + weight_scale_2 = self.get_weights_scaling_factor_2(weight_observer_amax) + self.model.weight_scales_dict_2[name] = weight_scale_2 + + weight_scale = self.get_weights_scaling_factor( + weight=sub_layer.weight.detach(), + weights_scaling_factor_2=weight_scale_2, + block_size=self.block_size, + ) + self.model.weight_scales_dict[name] = weight_scale + + input_observer_amax = self.model.act_scales_dict[name] + input_scale = 
self.get_activation_scaling_factor(input_observer_amax) + self.model.act_scales_dict[name] = input_scale + + def get_qdq_module(self, sub_layer, name): + act_scale, weight_scale, weight_scale_2 = None, None, None + if name in self.model.act_scales_dict: + act_scale = self.model.act_scales_dict[name] + if name in self.model.weight_scales_dict: + weight_scale = self.model.weight_scales_dict[name] + if name in self.model.weight_scales_dict_2: + weight_scale_2 = self.model.weight_scales_dict_2[name] + + q_linear = NVFP4QDQModule( + weight=sub_layer.weight, + weight_scale=weight_scale, + weight_scale_2=weight_scale_2, + bias=sub_layer.bias, + block_size=self.block_size, + input_scale=act_scale, + ) + + return q_linear diff --git a/angelslim/compressor/quant/ptq.py b/angelslim/compressor/quant/ptq.py index 5529a426..822bb8b5 100644 --- a/angelslim/compressor/quant/ptq.py +++ b/angelslim/compressor/quant/ptq.py @@ -18,7 +18,7 @@ from ...utils import find_parent_layer_and_sub_name, print_info from ..compressor_factory import CompressorFactory from .core import PTQHook -from .modules import AWQ, FP8, GPTQ, INT8, LeptoFP8, SmoothQuant +from .modules import AWQ, FP8, GPTQ, INT8, NVFP4, LeptoFP8, SmoothQuant __all__ = ["PTQ"] @@ -38,7 +38,11 @@ def __init__(self, model, slim_config=None): self.quant_model.init_ptq(slim_config) self.quant_algo = self.quant_model.quant_config.quant_algo self.quant_helpers = self.quant_model.quant_config.quant_helpers - if "fp8" in self.quant_algo or "int8" in self.quant_algo: + if ( + "fp8" in self.quant_algo + or "int8" in self.quant_algo + or "nvfp4" in self.quant_algo + ): # Add ptq observer hook self.ptq_hook = PTQHook(self.quant_model) self.ptq_hook.apply_hook() @@ -94,6 +98,8 @@ def __init__(self, model, slim_config=None): model_arch_type=model_arch_type, low_memory=self.quant_model.quant_config.low_memory, ) + elif "nvfp4" in self.quant_algo: + self.nvfp4 = NVFP4(self.quant_model) else: raise NotImplementedError( f"[AngelSlim Error] algo 
{self.quant_algo} is not support" @@ -115,6 +121,8 @@ def calibrate(self, dataloader): self.fp8.run(dataloader) elif "int8" in self.quant_algo: self.int8.run(dataloader) + elif "nvfp4" in self.quant_algo: + self.nvfp4.run(dataloader) else: raise AssertionError( f"[AngelSlim Error] algo {self.quant_algo} is not support calibrate" @@ -216,7 +224,12 @@ def _convert(self): quant_convert_module, name ) - qdq_module = self.quant_model.get_qdq_module(sub_layer, name) + if "nvfp4" in self.quant_algo: + self.nvfp4.post_process(sub_layer, name) + qdq_module = self.nvfp4.get_qdq_module(sub_layer, name) + else: + qdq_module = self.quant_model.get_qdq_module(sub_layer, name) + if qdq_module is not sub_layer: setattr(parent_layer, sub_name, qdq_module) self.quant_model.quantized = True diff --git a/angelslim/models/base_model.py b/angelslim/models/base_model.py index 3bd0edd3..f63e68cd 100644 --- a/angelslim/models/base_model.py +++ b/angelslim/models/base_model.py @@ -88,6 +88,7 @@ def init_ptq(self, slim_config): self.quant_config = quant_config self.act_scales_dict = {} self.weight_scales_dict = {} + self.weight_scales_dict_2 = {} self.kv_cache_scales_dict = {} if hasattr(self.quant_config, "weight_observer"): self.quant_algo_dict = self.get_quant_config() diff --git a/angelslim/utils/config_parser.py b/angelslim/utils/config_parser.py index 0dbf8c30..240449f4 100644 --- a/angelslim/utils/config_parser.py +++ b/angelslim/utils/config_parser.py @@ -42,6 +42,7 @@ class QuantizationMethod(str, Enum): INT8_DYNAMIC = "int8_dynamic" W4A8_FP8 = "w4a8_fp8" INT4_GPTAQ = "int4_gptaq" + NVFP4 = "nvfp4" @dataclass diff --git a/configs/qwen3/nvfp4/qwen3-0_6b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-0_6b_nvfp4.yaml new file mode 100644 index 00000000..810fdfc2 --- /dev/null +++ b/configs/qwen3/nvfp4/qwen3-0_6b_nvfp4.yaml @@ -0,0 +1,35 @@ +# Global configuration of pipeline +global: + save_path: ./output + +# Simplified Configuration for LLM compression +model: + name: Qwen + model_path: 
Qwen/Qwen3-0.6B + trust_remote_code: true + low_cpu_mem_usage: true + use_cache: false + torch_dtype: auto + device_map: auto + +# Compression configuration +compression: + name: PTQ + quantization: + name: nvfp4 + bits: 4 + quant_method: + weight: "per-tensor" + activation: "per-tensor" + group_size: 16 + ignore_layers: # Skip quantization for these layers + - "lm_head" + - "model.embed_tokens" + +# Dataset for calibration +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl + max_seq_length: 4096 + num_samples: 256 + batch_size: 1 diff --git a/configs/qwen3/nvfp4/qwen3-14b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-14b_nvfp4.yaml new file mode 100644 index 00000000..08b421a8 --- /dev/null +++ b/configs/qwen3/nvfp4/qwen3-14b_nvfp4.yaml @@ -0,0 +1,35 @@ +# Global configuration of pipeline +global: + save_path: ./output + +# Simplified Configuration for LLM compression +model: + name: Qwen + model_path: Qwen/Qwen3-14B + trust_remote_code: true + low_cpu_mem_usage: true + use_cache: false + torch_dtype: auto + device_map: auto + +# Compression configuration +compression: + name: PTQ + quantization: + name: nvfp4 + bits: 4 + quant_method: + weight: "per-tensor" + activation: "per-tensor" + group_size: 16 + ignore_layers: # Skip quantization for these layers + - "lm_head" + - "model.embed_tokens" + +# Dataset for calibration +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl + max_seq_length: 4096 + num_samples: 256 + batch_size: 1 diff --git a/configs/qwen3/nvfp4/qwen3-1_7b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-1_7b_nvfp4.yaml new file mode 100644 index 00000000..35b1ec9e --- /dev/null +++ b/configs/qwen3/nvfp4/qwen3-1_7b_nvfp4.yaml @@ -0,0 +1,35 @@ +# Global configuration of pipeline +global: + save_path: ./output + +# Simplified Configuration for LLM compression +model: + name: Qwen + model_path: Qwen/Qwen3-1.7B + trust_remote_code: true + 
low_cpu_mem_usage: true + use_cache: false + torch_dtype: auto + device_map: auto + +# Compression configuration +compression: + name: PTQ + quantization: + name: nvfp4 + bits: 4 + quant_method: + weight: "per-tensor" + activation: "per-tensor" + group_size: 16 + ignore_layers: # Skip quantization for these layers + - "lm_head" + - "model.embed_tokens" + +# Dataset for calibration +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl + max_seq_length: 4096 + num_samples: 256 + batch_size: 1 diff --git a/configs/qwen3/nvfp4/qwen3-32b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-32b_nvfp4.yaml new file mode 100644 index 00000000..f9cda058 --- /dev/null +++ b/configs/qwen3/nvfp4/qwen3-32b_nvfp4.yaml @@ -0,0 +1,35 @@ +# Global configuration of pipeline +global: + save_path: ./output + +# Simplified Configuration for LLM compression +model: + name: Qwen + model_path: Qwen/Qwen3-32B + trust_remote_code: true + low_cpu_mem_usage: true + use_cache: false + torch_dtype: auto + device_map: auto + +# Compression configuration +compression: + name: PTQ + quantization: + name: nvfp4 + bits: 4 + quant_method: + weight: "per-tensor" + activation: "per-tensor" + group_size: 16 + ignore_layers: # Skip quantization for these layers + - "lm_head" + - "model.embed_tokens" + +# Dataset for calibration +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl + max_seq_length: 4096 + num_samples: 256 + batch_size: 1 diff --git a/configs/qwen3/nvfp4/qwen3-4b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-4b_nvfp4.yaml new file mode 100644 index 00000000..e0cdc287 --- /dev/null +++ b/configs/qwen3/nvfp4/qwen3-4b_nvfp4.yaml @@ -0,0 +1,35 @@ +# Global configuration of pipeline +global: + save_path: ./output + +# Simplified Configuration for LLM compression +model: + name: Qwen + model_path: Qwen/Qwen3-4B + trust_remote_code: true + low_cpu_mem_usage: true + use_cache: false + torch_dtype: auto + 
device_map: auto + +# Compression configuration +compression: + name: PTQ + quantization: + name: nvfp4 + bits: 4 + quant_method: + weight: "per-tensor" + activation: "per-tensor" + group_size: 16 + ignore_layers: # Skip quantization for these layers + - "lm_head" + - "model.embed_tokens" + +# Dataset for calibration +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl + max_seq_length: 4096 + num_samples: 256 + batch_size: 1 diff --git a/configs/qwen3/nvfp4/qwen3-8b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-8b_nvfp4.yaml new file mode 100644 index 00000000..acc2c737 --- /dev/null +++ b/configs/qwen3/nvfp4/qwen3-8b_nvfp4.yaml @@ -0,0 +1,35 @@ +# Global configuration of pipeline +global: + save_path: ./output + +# Simplified Configuration for LLM compression +model: + name: Qwen + model_path: Qwen/Qwen3-8B + trust_remote_code: true + low_cpu_mem_usage: true + use_cache: false + torch_dtype: auto + device_map: auto + +# Compression configuration +compression: + name: PTQ + quantization: + name: nvfp4 + bits: 4 + quant_method: + weight: "per-tensor" + activation: "per-tensor" + group_size: 16 + ignore_layers: # Skip quantization for these layers + - "lm_head" + - "model.embed_tokens" + +# Dataset for calibration +dataset: + name: TextDataset + data_path: ./dataset/sharegpt_gpt4_qwen/sharegpt_gpt4-qwen3_a22B_output.jsonl + max_seq_length: 4096 + num_samples: 256 + batch_size: 1 From 6f7e149a22fc22810093f8cd8f7d82a49ef71bfb Mon Sep 17 00:00:00 2001 From: StromNoNo Date: Mon, 15 Sep 2025 15:46:03 +0800 Subject: [PATCH 2/3] fix nvfp4 --- angelslim/compressor/quant/core/config.py | 6 ++---- .../compressor/quant/modules/helper_layer.py | 21 +++++++++---------- configs/qwen3/nvfp4/qwen3-0_6b_nvfp4.yaml | 4 ++-- configs/qwen3/nvfp4/qwen3-14b_nvfp4.yaml | 4 ++-- configs/qwen3/nvfp4/qwen3-1_7b_nvfp4.yaml | 4 ++-- configs/qwen3/nvfp4/qwen3-32b_nvfp4.yaml | 4 ++-- configs/qwen3/nvfp4/qwen3-4b_nvfp4.yaml | 4 ++-- 
configs/qwen3/nvfp4/qwen3-8b_nvfp4.yaml | 4 ++-- 8 files changed, 24 insertions(+), 27 deletions(-) diff --git a/angelslim/compressor/quant/core/config.py b/angelslim/compressor/quant/core/config.py index a8b20be8..54b3b624 100644 --- a/angelslim/compressor/quant/core/config.py +++ b/angelslim/compressor/quant/core/config.py @@ -154,11 +154,9 @@ def __init__(self, config, global_config=None): is_dynamic or act_quant_method is not None ), "[Error] nvfp4 need act_quant_method" self.act_observer = ( - ACT_OBSERVERS_CLASS[act_quant_method] - if "static" in is_dynamic - else None + AbsmaxPertensorObserver if "static" in is_dynamic else None ) - self.weight_observer = WEIGHT_OBSERVERS_CLASS[weight_quant_method] + self.weight_observer = AbsmaxPertensorObserver self.kv_cache_observer = None block_size = ( 16 diff --git a/angelslim/compressor/quant/modules/helper_layer.py b/angelslim/compressor/quant/modules/helper_layer.py index cae1f00d..a67a84e8 100644 --- a/angelslim/compressor/quant/modules/helper_layer.py +++ b/angelslim/compressor/quant/modules/helper_layer.py @@ -36,12 +36,6 @@ tensor_quant_dequant_int, ) -# Define conversion tables -e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5]) -e2m1_values = torch.tensor( - [0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6] -) - def flush(): gc.collect() @@ -738,6 +732,11 @@ def __init__( input_scale: Optional[torch.nn.Parameter] = None, ): super().__init__() + # Define conversion tables + self.e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5]) + self.e2m1_values = torch.tensor( + [0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6] + ) self.e2m1_values_on_device = {} self.shape = weight.shape self.dtype = weight.dtype @@ -761,7 +760,7 @@ def __init__( def get_e2m1_values(self, device): """Returns the e2m1 values on the device.""" if device not in self.e2m1_values_on_device: - self.e2m1_values_on_device[device] = e2m1_values.to(device) + self.e2m1_values_on_device[device] 
= self.e2m1_values.to(device) return self.e2m1_values_on_device[device] def _cast_fp4(self, weight: torch.Tensor): @@ -778,13 +777,13 @@ def _cast_fp4(self, weight: torch.Tensor): weight_abs = weight.abs_() # Calculate the ordinal value based on the bounds - ord = torch.searchsorted(e2m1_bounds.to(device), weight_abs, out_int32=True).to( - torch.uint8 - ) + ord = torch.searchsorted( + self.e2m1_bounds.to(device), weight_abs, out_int32=True + ).to(torch.uint8) # All values equal to e2m1_bounds at odd indices are rounded up # and even indices are rounded down round = torch.any( - (weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) * mask, dim=-1 + (weight_abs.unsqueeze(-1) == self.e2m1_bounds.to(device)) * mask, dim=-1 ) fp4_val = (sign_bit * 0b1000 + ord + round).to(torch.uint8) return fp4_val diff --git a/configs/qwen3/nvfp4/qwen3-0_6b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-0_6b_nvfp4.yaml index 810fdfc2..f6ebcbba 100644 --- a/configs/qwen3/nvfp4/qwen3-0_6b_nvfp4.yaml +++ b/configs/qwen3/nvfp4/qwen3-0_6b_nvfp4.yaml @@ -19,8 +19,8 @@ compression: name: nvfp4 bits: 4 quant_method: - weight: "per-tensor" - activation: "per-tensor" + weight: "per-block" + activation: "per-block" group_size: 16 ignore_layers: # Skip quantization for these layers - "lm_head" diff --git a/configs/qwen3/nvfp4/qwen3-14b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-14b_nvfp4.yaml index 08b421a8..d5b2b91e 100644 --- a/configs/qwen3/nvfp4/qwen3-14b_nvfp4.yaml +++ b/configs/qwen3/nvfp4/qwen3-14b_nvfp4.yaml @@ -19,8 +19,8 @@ compression: name: nvfp4 bits: 4 quant_method: - weight: "per-tensor" - activation: "per-tensor" + weight: "per-block" + activation: "per-block" group_size: 16 ignore_layers: # Skip quantization for these layers - "lm_head" diff --git a/configs/qwen3/nvfp4/qwen3-1_7b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-1_7b_nvfp4.yaml index 35b1ec9e..55cee8cc 100644 --- a/configs/qwen3/nvfp4/qwen3-1_7b_nvfp4.yaml +++ b/configs/qwen3/nvfp4/qwen3-1_7b_nvfp4.yaml @@ -19,8 +19,8 @@ compression: 
name: nvfp4 bits: 4 quant_method: - weight: "per-tensor" - activation: "per-tensor" + weight: "per-block" + activation: "per-block" group_size: 16 ignore_layers: # Skip quantization for these layers - "lm_head" diff --git a/configs/qwen3/nvfp4/qwen3-32b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-32b_nvfp4.yaml index f9cda058..321a1bd7 100644 --- a/configs/qwen3/nvfp4/qwen3-32b_nvfp4.yaml +++ b/configs/qwen3/nvfp4/qwen3-32b_nvfp4.yaml @@ -19,8 +19,8 @@ compression: name: nvfp4 bits: 4 quant_method: - weight: "per-tensor" - activation: "per-tensor" + weight: "per-block" + activation: "per-block" group_size: 16 ignore_layers: # Skip quantization for these layers - "lm_head" diff --git a/configs/qwen3/nvfp4/qwen3-4b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-4b_nvfp4.yaml index e0cdc287..addd696b 100644 --- a/configs/qwen3/nvfp4/qwen3-4b_nvfp4.yaml +++ b/configs/qwen3/nvfp4/qwen3-4b_nvfp4.yaml @@ -19,8 +19,8 @@ compression: name: nvfp4 bits: 4 quant_method: - weight: "per-tensor" - activation: "per-tensor" + weight: "per-block" + activation: "per-block" group_size: 16 ignore_layers: # Skip quantization for these layers - "lm_head" diff --git a/configs/qwen3/nvfp4/qwen3-8b_nvfp4.yaml b/configs/qwen3/nvfp4/qwen3-8b_nvfp4.yaml index acc2c737..0c80b721 100644 --- a/configs/qwen3/nvfp4/qwen3-8b_nvfp4.yaml +++ b/configs/qwen3/nvfp4/qwen3-8b_nvfp4.yaml @@ -19,8 +19,8 @@ compression: name: nvfp4 bits: 4 quant_method: - weight: "per-tensor" - activation: "per-tensor" + weight: "per-block" + activation: "per-block" group_size: 16 ignore_layers: # Skip quantization for these layers - "lm_head" From da6dea04348486055200314ea511b59839d2cfd3 Mon Sep 17 00:00:00 2001 From: StromNoNo Date: Mon, 15 Sep 2025 16:09:35 +0800 Subject: [PATCH 3/3] fix get_nvfp4_qdq_module --- .../compressor/quant/modules/__init__.py | 1 + .../compressor/quant/modules/nvfp4/nvfp4.py | 23 +--------------- angelslim/compressor/quant/ptq.py | 2 +- angelslim/models/base_model.py | 27 ++++++++++++++++++- 4 files changed, 
29 insertions(+), 24 deletions(-) diff --git a/angelslim/compressor/quant/modules/__init__.py b/angelslim/compressor/quant/modules/__init__.py index 1997da73..49b41bb2 100644 --- a/angelslim/compressor/quant/modules/__init__.py +++ b/angelslim/compressor/quant/modules/__init__.py @@ -19,6 +19,7 @@ from .gptq.gptq import GPTQ # noqa: F401 from .gptq.gptq_module import GPTQModule # noqa: F401 from .helper_layer import GPTQQuantLinear # noqa: F401 +from .helper_layer import NVFP4QDQModule # noqa: F401 from .helper_layer import QDQModule # noqa: F401 from .helper_layer import QDQSingleModule # noqa: F401 from .helper_layer import QLinear # noqa: F401 diff --git a/angelslim/compressor/quant/modules/nvfp4/nvfp4.py b/angelslim/compressor/quant/modules/nvfp4/nvfp4.py index 4d0b00e4..d662ac95 100644 --- a/angelslim/compressor/quant/modules/nvfp4/nvfp4.py +++ b/angelslim/compressor/quant/modules/nvfp4/nvfp4.py @@ -16,7 +16,6 @@ import torch from .....utils import print_info -from ...modules.helper_layer import NVFP4QDQModule __all__ = ["NVFP4"] @@ -28,7 +27,7 @@ def __init__( ): """ Args: - model(nn.Module, required): The model to be smoothed. + model(nn.Module, required): The model to be quanted. 
""" super(NVFP4, self).__init__() self.model = model @@ -104,23 +103,3 @@ def post_process(self, sub_layer, name): input_observer_amax = self.model.act_scales_dict[name] input_scale = self.get_activation_scaling_factor(input_observer_amax) self.model.act_scales_dict[name] = input_scale - - def get_qdq_module(self, sub_layer, name): - act_scale, weight_scale, weight_scale_2 = None, None, None - if name in self.model.act_scales_dict: - act_scale = self.model.act_scales_dict[name] - if name in self.model.weight_scales_dict: - weight_scale = self.model.weight_scales_dict[name] - if name in self.model.weight_scales_dict_2: - weight_scale_2 = self.model.weight_scales_dict_2[name] - - q_linear = NVFP4QDQModule( - weight=sub_layer.weight, - weight_scale=weight_scale, - weight_scale_2=weight_scale_2, - bias=sub_layer.bias, - block_size=self.block_size, - input_scale=act_scale, - ) - - return q_linear diff --git a/angelslim/compressor/quant/ptq.py b/angelslim/compressor/quant/ptq.py index 822bb8b5..7b85d479 100644 --- a/angelslim/compressor/quant/ptq.py +++ b/angelslim/compressor/quant/ptq.py @@ -226,7 +226,7 @@ def _convert(self): if "nvfp4" in self.quant_algo: self.nvfp4.post_process(sub_layer, name) - qdq_module = self.nvfp4.get_qdq_module(sub_layer, name) + qdq_module = self.quant_model.get_nvfp4_qdq_module(sub_layer, name) else: qdq_module = self.quant_model.get_qdq_module(sub_layer, name) diff --git a/angelslim/models/base_model.py b/angelslim/models/base_model.py index f63e68cd..54c6c46d 100644 --- a/angelslim/models/base_model.py +++ b/angelslim/models/base_model.py @@ -22,7 +22,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from ..compressor.quant.core import QuantConfig -from ..compressor.quant.modules import QDQModule +from ..compressor.quant.modules import NVFP4QDQModule, QDQModule from ..utils import common_prefix, print_info __all__ = ["BaseLLMModel", "BaseDiffusionModel"] @@ -145,6 +145,31 @@ def get_qdq_module(self, sub_layer, name): raise 
NotImplementedError return q_linear + def get_nvfp4_qdq_module(self, sub_layer, name): + act_scale, weight_scale, weight_scale_2 = None, None, None + block_size = self.quant_config.quant_algo_info["block_size"] + if name in self.act_scales_dict: + act_scale = self.act_scales_dict[name] + if name in self.weight_scales_dict: + weight_scale = self.weight_scales_dict[name] + if name in self.weight_scales_dict_2: + weight_scale_2 = self.weight_scales_dict_2[name] + if self.deploy_backend in ["vllm", "huggingface", "trtllm", "tensorrt"]: + q_linear = NVFP4QDQModule( + weight=sub_layer.weight, + weight_scale=weight_scale, + weight_scale_2=weight_scale_2, + bias=sub_layer.bias, + block_size=block_size, + input_scale=act_scale, + ) + else: + print_info( + "current {} deploy_backend not support".format(self.deploy_backend) + ) + raise NotImplementedError + return q_linear + def get_kvcache_observer_layers_names(self, observe_names): names = ["self_attn.k_proj", "self_attn.v_proj"] return [