
Commit ce14403

Add low precision attention API from torchao to TorchAoConfig
1 parent c8c8401 commit ce14403

2 files changed

Lines changed: 163 additions & 8 deletions
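For orientation, a minimal usage sketch of the API this commit adds, assuming a diffusers build that exposes `TorchAoConfig`; the model class and checkpoint id below are placeholders, not taken from this commit:

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig

# Weight quantization only (existing behaviour), FP8 attention only (new),
# or both on the same model via the new `attention_backend` argument.
quantization_config = TorchAoConfig(Int8WeightOnlyConfig(), attention_backend="fp8_fa3")

transformer = FluxTransformer2DModel.from_pretrained(
    "<repo-id>",  # placeholder checkpoint id
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```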

src/diffusers/quantizers/quantization_config.py

Lines changed: 43 additions & 4 deletions
@@ -433,14 +433,19 @@ class TorchAoConfig(QuantizationConfigMixin):
     """This is a config class for torchao quantization/sparsity techniques.
 
     Args:
-        quant_type (`AOBaseConfig`):
+        quant_type (`AOBaseConfig` | None):
             An `AOBaseConfig` subclass instance specifying the quantization type. See the [torchao
             documentation](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) for
             available config classes (e.g. `Int4WeightOnlyConfig`, `Int8WeightOnlyConfig`, `Float8WeightOnlyConfig`,
             `Float8DynamicActivationFloat8WeightConfig`, etc.).
+            Pass `None` when only `attention_backend` is used without weight quantization.
         modules_to_not_convert (`list[str]`, *optional*, default to `None`):
             The list of modules to not quantize, useful for quantizing models that explicitly require to have some
             modules left in their original precision.
+        attention_backend (`str`, *optional*, default to `None`):
+            Low-precision attention backend to use. Currently supported: `"fp8_fa3"` (FP8 attention using Flash
+            Attention 3, requires Hopper GPU with SM90+). This is orthogonal to weight quantization — you can use
+            either or both. When used with `torch.compile`, RoPE fusion is automatically enabled.
 
     Example:
         ```python
@@ -454,22 +459,47 @@ class TorchAoConfig(QuantizationConfigMixin):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
+
+        # FP8 attention only (no weight quantization)
+        quantization_config = TorchAoConfig(attention_backend="fp8_fa3")
+
+        # Combined: weight quantization + FP8 attention
+        quantization_config = TorchAoConfig(Int8WeightOnlyConfig(), attention_backend="fp8_fa3")
         ```
     """
 
+    _SUPPORTED_ATTENTION_BACKENDS = {"fp8_fa3"}
+
     def __init__(
         self,
-        quant_type: "AOBaseConfig",  # noqa: F821
+        quant_type: "AOBaseConfig | None" = None,  # noqa: F821
         modules_to_not_convert: list[str] | None = None,
+        attention_backend: str | None = None,
         **kwargs,
     ) -> None:
         self.quant_method = QuantizationMethod.TORCHAO
         self.quant_type = quant_type
         self.modules_to_not_convert = modules_to_not_convert
+        self.attention_backend = attention_backend
 
         self.post_init()
 
     def post_init(self):
+        if self.quant_type is None and self.attention_backend is None:
+            raise ValueError(
+                "At least one of `quant_type` or `attention_backend` must be provided."
+            )
+
+        if self.attention_backend is not None and self.attention_backend not in self._SUPPORTED_ATTENTION_BACKENDS:
+            raise ValueError(
+                f"Unsupported attention_backend: {self.attention_backend!r}. "
+                f"Supported backends: {self._SUPPORTED_ATTENTION_BACKENDS}"
+            )
+
+        # Skip quant_type validation when only attention_backend is used
+        if self.quant_type is None:
+            return
+
         if is_torchao_version("<", "0.15.0"):
             raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")
 
@@ -482,6 +512,12 @@ def to_dict(self):
         """Convert configuration to a dictionary."""
         d = super().to_dict()
 
+        if self.attention_backend is not None:
+            d["attention_backend"] = self.attention_backend
+
+        if self.quant_type is None:
+            return d
+
         # Handle AOBaseConfig serialization
         from torchao.core.config import config_to_dict
 
@@ -498,8 +534,11 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
         if not is_torchao_version(">=", "0.15.0"):
             raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
         config_dict = config_dict.copy()
-        quant_type = config_dict.pop("quant_type")
+        quant_type = config_dict.pop("quant_type", None)
+        attention_backend = config_dict.pop("attention_backend", None)
 
+        if quant_type is None:
+            return cls(quant_type=quant_type, attention_backend=attention_backend, **config_dict)
         # Check if we only have one key which is "default"
         # In the future we may update this
         assert len(quant_type) == 1 and "default" in quant_type, (
@@ -512,7 +551,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
 
         quant_type = config_from_dict(quant_type)
 
-        return cls(quant_type=quant_type, **config_dict)
+        return cls(quant_type=quant_type, attention_backend=attention_backend, **config_dict)
 
     def get_apply_tensor_subclass(self):
         """Create the appropriate quantization method based on configuration."""

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 120 additions & 4 deletions
@@ -20,6 +20,7 @@
 import importlib
 import re
 import types
+from functools import partial
 from typing import TYPE_CHECKING, Any
 
 from packaging import version
@@ -188,8 +189,58 @@ def validate_environment(self, *args, **kwargs):
                 f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
             )
 
+        attention_backend = getattr(self.quantization_config, "attention_backend", None)
+        if attention_backend is not None:
+            self._validate_attention_environment(attention_backend)
+
+    def _validate_attention_environment(self, attention_backend):
+        """Validate that the environment supports the requested attention backend."""
+        # Check torchao.prototype.attention is importable
+        try:
+            importlib.import_module("torchao.prototype.attention")
+        except (ImportError, ModuleNotFoundError):
+            raise ImportError(
+                f"attention_backend={attention_backend!r} requires `torchao.prototype.attention`. "
+                "Please install a version of torchao that includes the prototype attention module."
+            )
+
+        # Check PyTorch >= 2.11.0
+        torch_version_parsed = version.parse(version.parse(importlib.metadata.version("torch")).base_version)
+        if torch_version_parsed < version.parse("2.11.0"):
+            raise RuntimeError(
+                f"attention_backend={attention_backend!r} requires PyTorch >= 2.11.0, "
+                f"but the current version is {torch_version_parsed}."
+            )
+
+        # Check CUDA available with SM90+ (Hopper)
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                f"attention_backend={attention_backend!r} requires CUDA."
+            )
+        major, minor = torch.cuda.get_device_capability()
+        if major < 9:
+            raise RuntimeError(
+                f"attention_backend={attention_backend!r} requires Hopper GPU (SM90+), "
+                f"but the current device has SM{major}{minor}."
+            )
+
+        # Check FA3 availability
+        try:
+            importlib.import_module("flash_attn_interface")
+        except (ImportError, ModuleNotFoundError):
+            raise ImportError(
+                f"attention_backend={attention_backend!r} requires the flash-attn package with FA3 support. "
+                "Please install flash-attn with FA3 support."
+            )
+
     def update_torch_dtype(self, torch_dtype):
-        config_name = self.quantization_config.quant_type.__class__.__name__
+        quant_type = self.quantization_config.quant_type
+        if quant_type is None:
+            if torch_dtype is None:
+                torch_dtype = torch.bfloat16
+            return torch_dtype
+
+        config_name = quant_type.__class__.__name__
         is_int_quant = config_name.startswith("Int") or config_name.startswith("Uint")
         if is_int_quant and torch_dtype is not None and torch_dtype != torch.bfloat16:
             logger.warning(
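The requirements enforced by `_validate_attention_environment` can also be checked up front; a hedged pre-flight sketch mirroring the same conditions, offered as an illustration rather than part of the commit:

```python
import importlib
import importlib.metadata

import torch
from packaging import version


def fp8_fa3_supported() -> bool:
    """Best-effort mirror of the checks above: torchao prototype attention importable,
    PyTorch >= 2.11.0, a Hopper-class GPU (SM90+), and the FA3 flash-attn interface."""
    try:
        importlib.import_module("torchao.prototype.attention")
        importlib.import_module("flash_attn_interface")
    except (ImportError, ModuleNotFoundError):
        return False
    torch_ver = version.parse(version.parse(importlib.metadata.version("torch")).base_version)
    if torch_ver < version.parse("2.11.0"):
        return False
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9
```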
@@ -209,6 +260,10 @@ def update_torch_dtype(self, torch_dtype):
         return torch_dtype
 
     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+        quant_type = self.quantization_config.quant_type
+        if quant_type is None:
+            return target_dtype
+
         from accelerate.utils import CustomDtype
 
         quant_type = self.quantization_config.quant_type
@@ -244,6 +299,9 @@ def check_if_quantized_param(
         state_dict: dict[str, Any],
         **kwargs,
     ) -> bool:
+        if self.quantization_config.quant_type is None:
+            return False
+
         param_device = kwargs.pop("param_device", None)
         # Check if the param_name is not in self.modules_to_not_convert
         if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
@@ -298,6 +356,9 @@ def get_cuda_warm_up_factor(self):
         - Use a division factor of 8 for int4 weights
         - Use a division factor of 4 for int8 weights
         """
+        if self.quantization_config.quant_type is None:
+            return 4
+
         quant_type = self.quantization_config.quant_type
         config_name = quant_type.__class__.__name__
         size_digit = fuzzy_match_size(config_name)
@@ -314,6 +375,13 @@ def _process_model_before_weight_loading(
         keep_in_fp32_modules: list[str] = [],
         **kwargs,
     ):
+        model.config.quantization_config = self.quantization_config
+
+        if self.quantization_config.quant_type is None:
+            # Attention-only mode: no weight quantization setup needed
+            self.modules_to_not_convert = []
+            return
+
         self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
 
         if not isinstance(self.modules_to_not_convert, list):
@@ -332,11 +400,56 @@
         # and tied modules are usually kept in FP32.
         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
-        model.config.quantization_config = self.quantization_config
-
     def _process_model_after_weight_loading(self, model: "ModelMixin"):
+        attention_backend = getattr(self.quantization_config, "attention_backend", None)
+        if attention_backend is not None:
+            self._apply_low_precision_attention(model, attention_backend)
         return model
 
+    def _apply_low_precision_attention(self, model, attention_backend):
+        """Apply low-precision attention by monkey-patching the model's forward.
+
+        Replaces the model's forward method with a wrapper that activates FA3 and
+        swaps F.scaled_dot_product_attention with the FP8 custom op for each forward
+        call.
+
+        Also sets the torch.compile pre-grad fusion pass for RoPE fusion.
+        """
+        import torch._inductor.config as inductor_config
+        import torch.nn.functional as F
+        from torch.nn.attention import activate_flash_attention_impl, restore_flash_attention_impl
+
+        from torchao.prototype.attention.fp8_fa3.attention import _ops
+        from torchao.prototype.attention.shared_utils.fusion_utils import rope_sdpa_fusion_pass
+        from torchao.prototype.attention.shared_utils.wrapper import _make_causal_aware_sdpa
+
+        # Diffusion models don't use causal masks
+        sdpa_patch_fn = _make_causal_aware_sdpa(_ops.fp8_sdpa_op, strip_causal_mask=False)
+
+        # Set the torch.compile fusion pass for RoPE fusion
+        inductor_config.pre_grad_custom_pass = partial(
+            rope_sdpa_fusion_pass,
+            rope_sdpa_op=_ops.rope_sdpa_op,
+            fp8_sdpa_op=_ops.fp8_sdpa_op,
+            backend_name="FA3",
+        )
+
+        original_forward = model.forward
+
+        def _fp8_attention_forward(*args, **kwargs):
+            activate_flash_attention_impl("FA3")
+            try:
+                original_sdpa = F.scaled_dot_product_attention
+                F.scaled_dot_product_attention = sdpa_patch_fn
+                try:
+                    return original_forward(*args, **kwargs)
+                finally:
+                    F.scaled_dot_product_attention = original_sdpa
+            finally:
+                restore_flash_attention_impl()
+
+        model.forward = _fp8_attention_forward
+
     def is_serializable(self, safe_serialization=None):
         # TODO(aryan): needs to be tested
         if safe_serialization:
@@ -371,7 +484,10 @@ def is_serializable(self, safe_serialization=None):
 
     @property
     def is_trainable(self):
-        return self.quantization_config.quant_type.__class__.__name__ in self._TRAINABLE_QUANTIZATION_CONFIGS
+        quant_type = self.quantization_config.quant_type
+        if quant_type is None:
+            return False
+        return quant_type.__class__.__name__ in self._TRAINABLE_QUANTIZATION_CONFIGS
 
     @property
     def is_compileable(self) -> bool:
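End to end, the two files combine into the flow below: the config requests the backend, `validate_environment` checks the machine, and `_process_model_after_weight_loading` wraps `forward` and registers the inductor pre-grad pass so `torch.compile` can fuse RoPE with the FP8 SDPA op. A hedged sketch with a placeholder checkpoint and model class:

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Attention-only mode: no weights are converted; the quantizer wraps `forward`
# after loading and sets the RoPE/SDPA fusion pass for torch.compile.
config = TorchAoConfig(attention_backend="fp8_fa3")
transformer = FluxTransformer2DModel.from_pretrained(
    "<repo-id>",  # placeholder checkpoint id
    subfolder="transformer",
    quantization_config=config,
    torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer)  # the fusion pass registered above applies here
```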
