Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,14 +433,19 @@ class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.

Args:
quant_type (`AOBaseConfig`):
quant_type (`AOBaseConfig` | None):
An `AOBaseConfig` subclass instance specifying the quantization type. See the [torchao
documentation](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) for
available config classes (e.g. `Int4WeightOnlyConfig`, `Int8WeightOnlyConfig`, `Float8WeightOnlyConfig`,
`Float8DynamicActivationFloat8WeightConfig`, etc.).
Pass `None` when only `attention_backend` is used without weight quantization.
modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
attention_backend (`str`, *optional*, defaults to `None`):
Low-precision attention backend to use. Currently supported: `"fp8_fa3"` (FP8 attention using Flash
Attention 3, requires Hopper GPU with SM90+). This is orthogonal to weight quantization — you can use
either or both. When used with `torch.compile`, RoPE fusion is automatically enabled.

Example:
```python
Expand All @@ -454,22 +459,47 @@ class TorchAoConfig(QuantizationConfigMixin):
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)

# FP8 attention only (no weight quantization)
quantization_config = TorchAoConfig(attention_backend="fp8_fa3")

# Combined: weight quantization + FP8 attention
quantization_config = TorchAoConfig(Int8WeightOnlyConfig(), attention_backend="fp8_fa3")
```
"""

_SUPPORTED_ATTENTION_BACKENDS = {"fp8_fa3"}

def __init__(
    self,
    quant_type: "AOBaseConfig | None" = None,  # noqa: F821
    modules_to_not_convert: list[str] | None = None,
    attention_backend: str | None = None,
    **kwargs,
) -> None:
    """Build a torchao quantization config and validate it immediately.

    Args:
        quant_type: An `AOBaseConfig` subclass instance describing the weight
            quantization scheme, or `None` for attention-only usage.
        modules_to_not_convert: Module names to leave in their original precision.
        attention_backend: Optional low-precision attention backend identifier
            (currently only `"fp8_fa3"`); validated in `post_init`.
        **kwargs: Extra keys accepted for forward compatibility; not stored here.
    """
    self.quant_method = QuantizationMethod.TORCHAO
    # Store settings first, then let post_init() enforce the invariants
    # (at least one of quant_type/attention_backend, known backend name, ...).
    self.attention_backend = attention_backend
    self.modules_to_not_convert = modules_to_not_convert
    self.quant_type = quant_type
    self.post_init()

def post_init(self):
if self.quant_type is None and self.attention_backend is None:
raise ValueError(
"At least one of `quant_type` or `attention_backend` must be provided."
)

if self.attention_backend is not None and self.attention_backend not in self._SUPPORTED_ATTENTION_BACKENDS:
raise ValueError(
f"Unsupported attention_backend: {self.attention_backend!r}. "
f"Supported backends: {self._SUPPORTED_ATTENTION_BACKENDS}"
)

# Skip quant_type validation when only attention_backend is used
if self.quant_type is None:
return

if is_torchao_version("<", "0.15.0"):
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")

Expand All @@ -482,6 +512,12 @@ def to_dict(self):
"""Convert configuration to a dictionary."""
d = super().to_dict()

if self.attention_backend is not None:
d["attention_backend"] = self.attention_backend

if self.quant_type is None:
return d

# Handle AOBaseConfig serialization
from torchao.core.config import config_to_dict

Expand All @@ -498,8 +534,11 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
if not is_torchao_version(">=", "0.15.0"):
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")
quant_type = config_dict.pop("quant_type", None)
attention_backend = config_dict.pop("attention_backend", None)

if quant_type is None:
return cls(quant_type=quant_type, attention_backend=attention_backend, **config_dict)
# Check if we only have one key which is "default"
# In the future we may update this
assert len(quant_type) == 1 and "default" in quant_type, (
Expand All @@ -512,7 +551,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):

quant_type = config_from_dict(quant_type)

return cls(quant_type=quant_type, **config_dict)
return cls(quant_type=quant_type, attention_backend=attention_backend, **config_dict)

def get_apply_tensor_subclass(self):
"""Create the appropriate quantization method based on configuration."""
Expand Down
124 changes: 120 additions & 4 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import importlib
import re
import types
from functools import partial
from typing import TYPE_CHECKING, Any

from packaging import version
Expand Down Expand Up @@ -188,8 +189,58 @@ def validate_environment(self, *args, **kwargs):
f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
)

attention_backend = getattr(self.quantization_config, "attention_backend", None)
if attention_backend is not None:
self._validate_attention_environment(attention_backend)

def _validate_attention_environment(self, attention_backend):
"""Validate that the environment supports the requested attention backend."""
# Check torchao.prototype.attention is importable
try:
importlib.import_module("torchao.prototype.attention")
except (ImportError, ModuleNotFoundError):
raise ImportError(
f"attention_backend={attention_backend!r} requires `torchao.prototype.attention`. "
"Please install a version of torchao that includes the prototype attention module."
)

# Check PyTorch >= 2.11.0
torch_version_parsed = version.parse(version.parse(importlib.metadata.version("torch")).base_version)
if torch_version_parsed < version.parse("2.11.0"):
raise RuntimeError(
f"attention_backend={attention_backend!r} requires PyTorch >= 2.11.0, "
f"but the current version is {torch_version_parsed}."
)

# Check CUDA available with SM90+ (Hopper)
if not torch.cuda.is_available():
raise RuntimeError(
f"attention_backend={attention_backend!r} requires CUDA."
)
major, minor = torch.cuda.get_device_capability()
if major < 9:
raise RuntimeError(
f"attention_backend={attention_backend!r} requires Hopper GPU (SM90+), "
f"but the current device has SM{major}{minor}."
)

# Check FA3 availability
try:
importlib.import_module("flash_attn_interface")
except (ImportError, ModuleNotFoundError):
raise ImportError(
f"attention_backend={attention_backend!r} requires the flash-attn package with FA3 support. "
"Please install flash-attn with FA3 support."
)

def update_torch_dtype(self, torch_dtype):
config_name = self.quantization_config.quant_type.__class__.__name__
quant_type = self.quantization_config.quant_type
if quant_type is None:
if torch_dtype is None:
torch_dtype = torch.bfloat16
return torch_dtype

config_name = quant_type.__class__.__name__
is_int_quant = config_name.startswith("Int") or config_name.startswith("Uint")
if is_int_quant and torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
Expand All @@ -209,6 +260,10 @@ def update_torch_dtype(self, torch_dtype):
return torch_dtype

def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
quant_type = self.quantization_config.quant_type
if quant_type is None:
return target_dtype

from accelerate.utils import CustomDtype

quant_type = self.quantization_config.quant_type
Expand Down Expand Up @@ -244,6 +299,9 @@ def check_if_quantized_param(
state_dict: dict[str, Any],
**kwargs,
) -> bool:
if self.quantization_config.quant_type is None:
return False

param_device = kwargs.pop("param_device", None)
# Check if the param_name is not in self.modules_to_not_convert
if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
Expand Down Expand Up @@ -298,6 +356,9 @@ def get_cuda_warm_up_factor(self):
- Use a division factor of 8 for int4 weights
- Use a division factor of 4 for int8 weights
"""
if self.quantization_config.quant_type is None:
return 4

quant_type = self.quantization_config.quant_type
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
Expand All @@ -314,6 +375,13 @@ def _process_model_before_weight_loading(
keep_in_fp32_modules: list[str] = [],
**kwargs,
):
model.config.quantization_config = self.quantization_config

if self.quantization_config.quant_type is None:
# Attention-only mode: no weight quantization setup needed
self.modules_to_not_convert = []
return

self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

if not isinstance(self.modules_to_not_convert, list):
Expand All @@ -332,11 +400,56 @@ def _process_model_before_weight_loading(
# and tied modules are usually kept in FP32.
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

model.config.quantization_config = self.quantization_config

def _process_model_after_weight_loading(self, model: "ModelMixin"):
attention_backend = getattr(self.quantization_config, "attention_backend", None)
if attention_backend is not None:
self._apply_low_precision_attention(model, attention_backend)
return model

def _apply_low_precision_attention(self, model, attention_backend):
    """Apply low-precision attention by monkey-patching the model's forward.

    Replaces the model's forward method with a wrapper that activates FA3 and
    swaps F.scaled_dot_product_attention with the FP8 custom op for each forward
    call, restoring both afterwards even if the forward raises.

    Also sets the torch.compile pre-grad fusion pass for RoPE fusion.

    NOTE(review): both `F.scaled_dot_product_attention` and the FA3 impl are
    patched process-globally for the duration of each forward call — other
    threads running attention concurrently would observe the patch; confirm
    single-threaded inference is assumed here.
    """
    import torch._inductor.config as inductor_config
    import torch.nn.functional as F
    from torch.nn.attention import activate_flash_attention_impl, restore_flash_attention_impl

    from torchao.prototype.attention.fp8_fa3.attention import _ops
    from torchao.prototype.attention.shared_utils.fusion_utils import rope_sdpa_fusion_pass
    from torchao.prototype.attention.shared_utils.wrapper import _make_causal_aware_sdpa

    # Diffusion models don't use causal masks
    # NOTE(review): the comment above suggests causal masks are absent, yet
    # `strip_causal_mask=False` is passed — confirm against the torchao
    # `_make_causal_aware_sdpa` contract whether False is intended here.
    sdpa_patch_fn = _make_causal_aware_sdpa(_ops.fp8_sdpa_op, strip_causal_mask=False)

    # Set the torch.compile fusion pass for RoPE fusion
    # NOTE(review): this mutates global inductor config once, at patch time —
    # it affects every subsequent torch.compile in the process, not just this model.
    inductor_config.pre_grad_custom_pass = partial(
        rope_sdpa_fusion_pass,
        rope_sdpa_op=_ops.rope_sdpa_op,
        fp8_sdpa_op=_ops.fp8_sdpa_op,
        backend_name="FA3",
    )

    # Keep a reference to the unpatched forward; the wrapper below closes over it.
    original_forward = model.forward

    def _fp8_attention_forward(*args, **kwargs):
        # Activate FA3 first; the nested try/finally blocks guarantee the SDPA
        # swap is undone before the FA3 impl is restored, mirroring setup order.
        activate_flash_attention_impl("FA3")
        try:
            original_sdpa = F.scaled_dot_product_attention
            F.scaled_dot_product_attention = sdpa_patch_fn
            try:
                return original_forward(*args, **kwargs)
            finally:
                F.scaled_dot_product_attention = original_sdpa
        finally:
            restore_flash_attention_impl()

    model.forward = _fp8_attention_forward

def is_serializable(self, safe_serialization=None):
# TODO(aryan): needs to be tested
if safe_serialization:
Expand Down Expand Up @@ -371,7 +484,10 @@ def is_serializable(self, safe_serialization=None):

@property
def is_trainable(self):
    """Whether the configured quantization scheme supports training.

    Attention-only configs (no `quant_type`) are never trainable; otherwise
    trainability is determined by the config class name being listed in
    `_TRAINABLE_QUANTIZATION_CONFIGS`.
    """
    config = self.quantization_config.quant_type
    return config is not None and config.__class__.__name__ in self._TRAINABLE_QUANTIZATION_CONFIGS

@property
def is_compileable(self) -> bool:
Expand Down
Loading