Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions angelslim/compressor/diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ from angelslim.compressor.diffusion import DynamicDiTQuantizer
# Load DiT pipeline with bfloat16 to reduce memory usage
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)

# Supported quantization types: "fp8-per-tensor", "fp8-per-block", "fp8-per-token"
# Supported quantization types: "fp8-per-tensor", "fp8-per-block", "fp8-per-token", "fp8-per-tensor-weight-only"
# If you want to use "fp8-per-block" + DeepGEMM on NVIDIA Hopper (SM90+) devices,
# please refer to https://github.com/deepseek-ai/DeepGEMM for installation instructions.
quantizer = DynamicDiTQuantizer(quant_type="fp8-per-tensor")
Expand Down Expand Up @@ -86,9 +86,10 @@ quantizer.export_quantized_weight(pipe.transformer, save_path="/path/to/save/qua

## Supported Quantization Types

AngelSlim supports three FP8 quantization strategies:
AngelSlim supports four FP8 quantization strategies:

- **`fp8-per-tensor`**: Per-tensor quantization for both weights and activations (recommended for most use cases)
- **`fp8-per-tensor-weight-only`**: Weight-only quantization with per-tensor scaling (weights: FP8, activations: BF16/FP16)
- **`fp8-per-block`**: Per-block quantization with DeepGEMM support for NVIDIA Hopper (SM90+) devices
- **`fp8-per-token`**: Per-token quantization for fine-grained control

Expand Down Expand Up @@ -132,7 +133,7 @@ The main quantizer class for DiT models.

#### Constructor Parameters

- `quant_type` (str): Quantization type - "fp8-per-tensor", "fp8-per-block", or "fp8-per-token"
- `quant_type` (str): Quantization type - "fp8-per-tensor", "fp8-per-tensor-weight-only", "fp8-per-block", or "fp8-per-token"
- `layer_filter` (Callable, optional): Custom function to determine which layers to quantize
- `include_patterns` (List[str|re.Pattern], optional): Patterns for layers to include
- `exclude_patterns` (List[str|re.Pattern], optional): Patterns for layers to exclude
Expand Down
4 changes: 2 additions & 2 deletions angelslim/compressor/diffusion/quant/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .linear import FP8DynamicLinear
from .linear import FP8DynamicLinear, FP8WeightOnlyLinear

__all__ = ["FP8DynamicLinear"]
__all__ = ["FP8DynamicLinear", "FP8WeightOnlyLinear"]
50 changes: 50 additions & 0 deletions angelslim/compressor/diffusion/quant/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
fp8_per_block_quant,
fp8_per_tensor_quant,
fp8_per_token_group_quant,
fp8_weight_only_gemm,
)


Expand Down Expand Up @@ -86,3 +87,52 @@ def forward(self, x):
output = output.unsqueeze(0)

return output


class FP8WeightOnlyLinear(torch.nn.Module):
"""
FP8 Weight-Only Quantized Linear Layer.

This layer quantizes only the weights to FP8 while keeping activations
in higher precision (bfloat16/float16). This provides a good balance
between memory savings and accuracy.
"""

def __init__(
self,
weight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.nn.Parameter,
native_fp8_support: bool = False, # not used
quant_type: str = "fp8-per-tensor-weight-only",
):
super().__init__()
self.weight = torch.nn.Parameter(weight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
self.bias = bias
self.native_fp8_support = native_fp8_support # not used
self.quant_type = quant_type

@torch.compiler.disable(recursive=True)
def forward(self, x):
ori_dtype = x.dtype
assert ori_dtype in [
torch.float32,
torch.bfloat16,
torch.float16,
], "x.dtype must be float32, bfloat16, or float16"

if ori_dtype == torch.float32:
x = x.to(torch.bfloat16)

# For weight-only quantization, we don't quantize activations
# Just use the original activations with quantized weights
output = fp8_weight_only_gemm(
A=x, # Keep activations in original precision
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
out_dtype=x.dtype,
)

return output
9 changes: 7 additions & 2 deletions angelslim/compressor/diffusion/quant/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch
import tqdm

from .modules import FP8DynamicLinear
from .modules import FP8DynamicLinear, FP8WeightOnlyLinear
from .quant_func import (
fp8_per_block_quant,
fp8_per_tensor_quant,
Expand Down Expand Up @@ -90,7 +90,10 @@ def __init__(

def _set_quantize_linear_module(self) -> torch.nn.Module:
    """Pick the replacement linear-layer class for ``self.quant_type``.

    Returns:
        ``FP8WeightOnlyLinear`` for the weight-only per-tensor type,
        ``FP8DynamicLinear`` for every other quant type containing "fp8".

    Raises:
        ValueError: If ``self.quant_type`` does not contain "fp8".
    """
    if "fp8" not in self.quant_type:
        raise ValueError(f"Invalid quant_type: {self.quant_type}")
    # The weight-only check must come before the generic FP8 fallback;
    # an unconditional early return here made this branch unreachable.
    if self.quant_type == QuantType.FP8_PER_TENSOR_WEIGHT_ONLY:
        return FP8WeightOnlyLinear
    return FP8DynamicLinear

def _quantize_linear_weight(
Expand All @@ -107,6 +110,8 @@ def _quantize_linear_weight(
if self.native_fp8_support:
_ensure_deep_gemm()
quant_weight, weight_scale = fp8_per_block_quant(linear.weight)
elif self.quant_type == QuantType.FP8_PER_TENSOR_WEIGHT_ONLY:
quant_weight, weight_scale = fp8_per_tensor_quant(linear.weight)
else:
raise ValueError(f"Invalid quant_type: {self.quant_type}")
return quant_weight, weight_scale
Expand Down
28 changes: 28 additions & 0 deletions angelslim/compressor/diffusion/quant/quant_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,34 @@ def fp8_gemm_torch_tensor_token(
return output


def fp8_weight_only_gemm(A, B, B_scale, bias, out_dtype):
    """Compute a linear layer with dequantized weights and plain activations.

    The weight ``B`` is cast to ``out_dtype`` and multiplied by ``B_scale``,
    then ``torch.nn.functional.linear`` is applied to ``A``. Activations are
    never quantized.

    Args:
        A: Activation tensor; its last dimension is the input-feature axis.
        B: Quantized weight tensor of shape ``(out_features, in_features)``.
        B_scale: Scale multiplied onto ``B`` after the cast (broadcast by
            PyTorch).
        bias: Optional bias tensor, or ``None``. Cast to ``out_dtype`` so a
            bias stored in another precision does not cause a dtype mismatch.
        out_dtype: Dtype used for the computation and for the result.

    Returns:
        torch.Tensor: Result in ``out_dtype`` with shape
        ``(*A.shape[:-1], out_features)``.
    """
    if A.numel() == 0:
        # Preserve leading (batch) dimensions for empty inputs, matching what
        # F.linear would return; 0-/1-D inputs keep the legacy (0, N) shape.
        leading_dims = tuple(A.shape[:-1]) if A.dim() > 1 else (0,)
        return torch.empty(
            size=(*leading_dims, B.shape[0]),
            dtype=out_dtype,
            device=A.device,
        )

    if bias is not None:
        bias = bias.to(out_dtype)

    # Dequantize the weight on the fly: cast first, then apply the scale.
    dequantized_weight = B.to(out_dtype) * B_scale.to(out_dtype)

    return torch.nn.functional.linear(
        A.to(out_dtype),
        dequantized_weight,
        bias=bias,
    )


def fp8_gemm(
A: torch.Tensor,
A_scale: torch.Tensor,
Expand Down
8 changes: 7 additions & 1 deletion angelslim/compressor/diffusion/quant/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ class QuantType:
FP8_PER_TENSOR = "fp8-per-tensor"
FP8_PER_TOKEN = "fp8-per-token"
FP8_PER_BLOCK = "fp8-per-block"
VALID_TYPES = [FP8_PER_TENSOR, FP8_PER_TOKEN, FP8_PER_BLOCK]
FP8_PER_TENSOR_WEIGHT_ONLY = "fp8-per-tensor-weight-only"
VALID_TYPES = [
FP8_PER_TENSOR,
FP8_PER_TOKEN,
FP8_PER_BLOCK,
FP8_PER_TENSOR_WEIGHT_ONLY,
]

@classmethod
def validate(cls, quant_type: str):
Expand Down
5 changes: 3 additions & 2 deletions docs/source/features/diffusion/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ image = pipe("A cat holding a sign that says hello world",
image.save("flux-schnell_fp8_per_tensor.png")
```

### 方式2:模型动态量化
### 方式2:模型动态量化 & 仅权重量化

```python
import torch
Expand Down Expand Up @@ -66,6 +66,7 @@ quantizer.export_quantized_weight(pipe.transformer, save_path="/path/to/save/qua
## 支持的FP8量化类型

- **fp8-per-tensor**:全局per-tensor量化(推荐)
- **fp8-per-tensor-weight-only**:仅权重per-tensor量化(权重:FP8,激活:BF16/FP16)
- **fp8-per-block**:per-block量化,支持NVIDIA Hopper (SM90+) DeepGEMM
- **fp8-per-token**:per-token量化,粒度更细

Expand Down Expand Up @@ -105,7 +106,7 @@ quantizer = DynamicDiTQuantizer(

### DynamicDiTQuantizer

- `quant_type`:量化类型("fp8-per-tensor"/"fp8-per-block"/"fp8-per-token")
- `quant_type`:量化类型("fp8-per-tensor"/"fp8-per-tensor-weight-only"/"fp8-per-block"/"fp8-per-token")
- `layer_filter`:自定义筛选函数(可选)
- `include_patterns`/`exclude_patterns`:包含/排除哪些层(字符串或正则,支持混用)
- `native_fp8_support`:是否使用原生FP8支持(如支持自动检测)
Expand Down