From 2c2339190e899f73820ff736806c577d54d751e3 Mon Sep 17 00:00:00 2001 From: garygugong Date: Tue, 28 Oct 2025 15:25:27 +0800 Subject: [PATCH 1/2] add weight only fp8 linear for dit quant --- angelslim/compressor/diffusion/README.md | 7 +-- .../diffusion/quant/modules/__init__.py | 4 +- .../diffusion/quant/modules/linear.py | 49 +++++++++++++++++++ angelslim/compressor/diffusion/quant/ptq.py | 10 +++- .../compressor/diffusion/quant/quant_func.py | 28 +++++++++++ .../compressor/diffusion/quant/utils/utils.py | 3 +- .../source/features/diffusion/quantization.md | 5 +- 7 files changed, 96 insertions(+), 10 deletions(-) diff --git a/angelslim/compressor/diffusion/README.md b/angelslim/compressor/diffusion/README.md index 3b41dd52..d610be05 100644 --- a/angelslim/compressor/diffusion/README.md +++ b/angelslim/compressor/diffusion/README.md @@ -48,7 +48,7 @@ from angelslim.compressor.diffusion import DynamicDiTQuantizer # Load DiT pipeline with bfloat16 to reduce memory usage pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) -# Supported quantization types: "fp8-per-tensor", "fp8-per-block", "fp8-per-token" +# Supported quantization types: "fp8-per-tensor", "fp8-per-block", "fp8-per-token", "fp8-per-tensor-weight-only" # If you want to use "fp8-per-block" + DeepGEMM on NVIDIA Hopper (SM90+) devices, # please refer to https://github.com/deepseek-ai/DeepGEMM for installation instructions. 
quantizer = DynamicDiTQuantizer(quant_type="fp8-per-tensor") @@ -86,9 +86,10 @@ quantizer.export_quantized_weight(pipe.transformer, save_path="/path/to/save/qua ## Supported Quantization Types -AngelSlim supports three FP8 quantization strategies: +AngelSlim supports four FP8 quantization strategies: - **`fp8-per-tensor`**: Per-tensor quantization for both weights and activations (recommended for most use cases) +- **`fp8-per-tensor-weight-only`**: Weight-only quantization with per-tensor scaling (weights: FP8, activations: BF16/FP16) - **`fp8-per-block`**: Per-block quantization with DeepGEMM support for NVIDIA Hopper (SM90+) devices - **`fp8-per-token`**: Per-token quantization for fine-grained control @@ -132,7 +133,7 @@ The main quantizer class for DiT models. #### Constructor Parameters -- `quant_type` (str): Quantization type - "fp8-per-tensor", "fp8-per-block", or "fp8-per-token" +- `quant_type` (str): Quantization type - "fp8-per-tensor", "fp8-per-tensor-weight-only", "fp8-per-block", or "fp8-per-token" - `layer_filter` (Callable, optional): Custom function to determine which layers to quantize - `include_patterns` (List[str|re.Pattern], optional): Patterns for layers to include - `exclude_patterns` (List[str|re.Pattern], optional): Patterns for layers to exclude diff --git a/angelslim/compressor/diffusion/quant/modules/__init__.py b/angelslim/compressor/diffusion/quant/modules/__init__.py index 4cb954b5..311476ea 100644 --- a/angelslim/compressor/diffusion/quant/modules/__init__.py +++ b/angelslim/compressor/diffusion/quant/modules/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .linear import FP8DynamicLinear +from .linear import FP8DynamicLinear, FP8WeightOnlyLinear -__all__ = ["FP8DynamicLinear"] +__all__ = ["FP8DynamicLinear", "FP8WeightOnlyLinear"] diff --git a/angelslim/compressor/diffusion/quant/modules/linear.py b/angelslim/compressor/diffusion/quant/modules/linear.py index 2cd642f0..da289dd5 100644 --- a/angelslim/compressor/diffusion/quant/modules/linear.py +++ b/angelslim/compressor/diffusion/quant/modules/linear.py @@ -16,6 +16,7 @@ from ..quant_func import ( fp8_gemm, + fp8_weight_only_gemm, fp8_per_block_quant, fp8_per_tensor_quant, fp8_per_token_group_quant, @@ -86,3 +87,51 @@ def forward(self, x): output = output.unsqueeze(0) return output + + +class FP8WeightOnlyLinear(torch.nn.Module): + """ + FP8 Weight-Only Quantized Linear Layer. + + This layer quantizes only the weights to FP8 while keeping activations + in higher precision (bfloat16/float16). This provides a good balance + between memory savings and accuracy. + """ + def __init__( + self, + weight: torch.Tensor, + weight_scale: torch.Tensor, + bias: torch.nn.Parameter, + native_fp8_support: bool = False, # 保留接口但不使用 + quant_type: str = "fp8-per-tensor-weight-only", + ): + super().__init__() + self.weight = torch.nn.Parameter(weight, requires_grad=False) + self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + self.bias = bias + self.native_fp8_support = native_fp8_support # 保留但不使用 + self.quant_type = quant_type + + @torch.compiler.disable(recursive=True) + def forward(self, x): + ori_dtype = x.dtype + assert ori_dtype in [ + torch.float32, + torch.bfloat16, + torch.float16, + ], "x.dtype must be float32, bfloat16, or float16" + + if ori_dtype == torch.float32: + x = x.to(torch.bfloat16) + + # For weight-only quantization, we don't quantize activations + # Just use the original activations with quantized weights + output = fp8_weight_only_gemm( + A=x, # Keep activations in original precision + B=self.weight, + B_scale=self.weight_scale, + 
bias=self.bias, + out_dtype=x.dtype, + ) + + return output diff --git a/angelslim/compressor/diffusion/quant/ptq.py b/angelslim/compressor/diffusion/quant/ptq.py index edf8088e..3f5af951 100644 --- a/angelslim/compressor/diffusion/quant/ptq.py +++ b/angelslim/compressor/diffusion/quant/ptq.py @@ -20,7 +20,7 @@ import torch import tqdm -from .modules import FP8DynamicLinear +from .modules import FP8DynamicLinear, FP8WeightOnlyLinear from .quant_func import ( fp8_per_block_quant, fp8_per_tensor_quant, @@ -90,7 +90,10 @@ def __init__( def _set_quantize_linear_module(self) -> torch.nn.Module: if "fp8" in self.quant_type: - return FP8DynamicLinear + if self.quant_type == QuantType.FP8_PER_TENSOR_WEIGHT_ONLY: + return FP8WeightOnlyLinear + else: + return FP8DynamicLinear raise ValueError(f"Invalid quant_type: {self.quant_type}") def _quantize_linear_weight( @@ -107,6 +110,9 @@ def _quantize_linear_weight( if self.native_fp8_support: _ensure_deep_gemm() quant_weight, weight_scale = fp8_per_block_quant(linear.weight) + elif self.quant_type == QuantType.FP8_PER_TENSOR_WEIGHT_ONLY: + # For weight-only quantization, we can use per-tensor quantization for weights + quant_weight, weight_scale = fp8_per_tensor_quant(linear.weight) else: raise ValueError(f"Invalid quant_type: {self.quant_type}") return quant_weight, weight_scale diff --git a/angelslim/compressor/diffusion/quant/quant_func.py b/angelslim/compressor/diffusion/quant/quant_func.py index 4d94ec39..aebc5b62 100644 --- a/angelslim/compressor/diffusion/quant/quant_func.py +++ b/angelslim/compressor/diffusion/quant/quant_func.py @@ -232,6 +232,34 @@ def fp8_gemm_torch_tensor_token( return output +def fp8_weight_only_gemm(A, B, B_scale, bias, out_dtype): + """Perform FP8 GEMM operation with fallback to standard linear. + + Args: + A: Input tensor A. + B: Input tensor B. + B_scale: Scale factor for tensor B. + bias: Optional bias tensor. + out_dtype: Output data type. + native_fp8_support: Whether to use native FP8 support. 
+ quant_type: Quantization type. + origin_shape: Original shape for reshaping. + + Returns: + torch.Tensor: Result of the GEMM operation. + """ + if A.numel() == 0: + return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device) + + output = torch.nn.functional.linear( + A.to(out_dtype), + B.to(out_dtype) * B_scale.to(out_dtype), + bias=bias, + ) + + return output + + def fp8_gemm( A: torch.Tensor, A_scale: torch.Tensor, diff --git a/angelslim/compressor/diffusion/quant/utils/utils.py b/angelslim/compressor/diffusion/quant/utils/utils.py index ef65e6d5..e2b5224a 100644 --- a/angelslim/compressor/diffusion/quant/utils/utils.py +++ b/angelslim/compressor/diffusion/quant/utils/utils.py @@ -33,7 +33,8 @@ class QuantType: FP8_PER_TENSOR = "fp8-per-tensor" FP8_PER_TOKEN = "fp8-per-token" FP8_PER_BLOCK = "fp8-per-block" - VALID_TYPES = [FP8_PER_TENSOR, FP8_PER_TOKEN, FP8_PER_BLOCK] + FP8_PER_TENSOR_WEIGHT_ONLY = "fp8-per-tensor-weight-only" + VALID_TYPES = [FP8_PER_TENSOR, FP8_PER_TOKEN, FP8_PER_BLOCK, FP8_PER_TENSOR_WEIGHT_ONLY] @classmethod def validate(cls, quant_type: str): diff --git a/docs/source/features/diffusion/quantization.md b/docs/source/features/diffusion/quantization.md index 521589aa..fbfe0de5 100644 --- a/docs/source/features/diffusion/quantization.md +++ b/docs/source/features/diffusion/quantization.md @@ -31,7 +31,7 @@ image = pipe("A cat holding a sign that says hello world", image.save("flux-schnell_fp8_per_tensor.png") ``` -### 方式2:模型动态量化 +### 方式2:模型动态量化 & 仅权重量化 ```python import torch @@ -66,6 +66,7 @@ quantizer.export_quantized_weight(pipe.transformer, save_path="/path/to/save/qua ## 支持的FP8量化类型 - **fp8-per-tensor**:全局per-tensor量化(推荐) +- **fp8-per-tensor-weight-only**:权重量化(权重:FP8,激活:BF16/FP16) - **fp8-per-block**:per-block量化,支持NVIDIA Hopper (SM90+) DeepGEMM - **fp8-per-token**:per-token量化,粒度更细 @@ -105,7 +106,7 @@ quantizer = DynamicDiTQuantizer( ### DynamicDiTQuantizer -- 
`quant_type`:量化类型("fp8-per-tensor"/"fp8-per-block"/"fp8-per-token") +- `quant_type`:量化类型("fp8-per-tensor"/"fp8-per-tensor-weight-only"/"fp8-per-block"/"fp8-per-token") - `layer_filter`:自定义筛选函数(可选) - `include_patterns`/`exclude_patterns`:包含/排除哪些层(字符串或正则,支持混用) - `native_fp8_support`:是否使用原生FP8支持(如支持自动检测) From 4f5a9810829805914ee56389a58bbac4b6eaa436 Mon Sep 17 00:00:00 2001 From: garygugong Date: Tue, 28 Oct 2025 15:32:19 +0800 Subject: [PATCH 2/2] fix lint typo --- .../compressor/diffusion/quant/modules/linear.py | 13 +++++++------ angelslim/compressor/diffusion/quant/ptq.py | 1 - angelslim/compressor/diffusion/quant/utils/utils.py | 7 ++++++- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/angelslim/compressor/diffusion/quant/modules/linear.py b/angelslim/compressor/diffusion/quant/modules/linear.py index da289dd5..7383b856 100644 --- a/angelslim/compressor/diffusion/quant/modules/linear.py +++ b/angelslim/compressor/diffusion/quant/modules/linear.py @@ -16,10 +16,10 @@ from ..quant_func import ( fp8_gemm, - fp8_weight_only_gemm, fp8_per_block_quant, fp8_per_tensor_quant, fp8_per_token_group_quant, + fp8_weight_only_gemm, ) @@ -92,24 +92,25 @@ def forward(self, x): class FP8WeightOnlyLinear(torch.nn.Module): """ FP8 Weight-Only Quantized Linear Layer. - - This layer quantizes only the weights to FP8 while keeping activations - in higher precision (bfloat16/float16). This provides a good balance + + This layer quantizes only the weights to FP8 while keeping activations + in higher precision (bfloat16/float16). This provides a good balance between memory savings and accuracy. 
""" + def __init__( self, weight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.nn.Parameter, - native_fp8_support: bool = False, # 保留接口但不使用 + native_fp8_support: bool = False, # not used quant_type: str = "fp8-per-tensor-weight-only", ): super().__init__() self.weight = torch.nn.Parameter(weight, requires_grad=False) self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) self.bias = bias - self.native_fp8_support = native_fp8_support # 保留但不使用 + self.native_fp8_support = native_fp8_support # not used self.quant_type = quant_type @torch.compiler.disable(recursive=True) diff --git a/angelslim/compressor/diffusion/quant/ptq.py b/angelslim/compressor/diffusion/quant/ptq.py index 3f5af951..1076dbf5 100644 --- a/angelslim/compressor/diffusion/quant/ptq.py +++ b/angelslim/compressor/diffusion/quant/ptq.py @@ -111,7 +111,6 @@ def _quantize_linear_weight( _ensure_deep_gemm() quant_weight, weight_scale = fp8_per_block_quant(linear.weight) elif self.quant_type == QuantType.FP8_PER_TENSOR_WEIGHT_ONLY: - # For weight-only quantization, we can use per-tensor quantization for weights quant_weight, weight_scale = fp8_per_tensor_quant(linear.weight) else: raise ValueError(f"Invalid quant_type: {self.quant_type}") diff --git a/angelslim/compressor/diffusion/quant/utils/utils.py b/angelslim/compressor/diffusion/quant/utils/utils.py index e2b5224a..8df6e18d 100644 --- a/angelslim/compressor/diffusion/quant/utils/utils.py +++ b/angelslim/compressor/diffusion/quant/utils/utils.py @@ -34,7 +34,12 @@ class QuantType: FP8_PER_TOKEN = "fp8-per-token" FP8_PER_BLOCK = "fp8-per-block" FP8_PER_TENSOR_WEIGHT_ONLY = "fp8-per-tensor-weight-only" - VALID_TYPES = [FP8_PER_TENSOR, FP8_PER_TOKEN, FP8_PER_BLOCK, FP8_PER_TENSOR_WEIGHT_ONLY] + VALID_TYPES = [ + FP8_PER_TENSOR, + FP8_PER_TOKEN, + FP8_PER_BLOCK, + FP8_PER_TENSOR_WEIGHT_ONLY, + ] @classmethod def validate(cls, quant_type: str):