
Commit 125a7a8

Add low precision attention API from torchao to TorchAoConfig
1 parent 0c01a4b commit 125a7a8

2 files changed

Lines changed: 159 additions & 9 deletions

src/diffusers/quantizers/quantization_config.py

Lines changed: 43 additions & 6 deletions
@@ -446,7 +446,7 @@ class TorchAoConfig(QuantizationConfigMixin):
     """This is a config class for torchao quantization/sparsity techniques.

     Args:
-        quant_type (`str` | AOBaseConfig):
+        quant_type (`str` | AOBaseConfig | None):
             The type of quantization we want to use, currently supporting:
                 - **Integer quantization:**
                     - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
@@ -469,9 +469,14 @@ class TorchAoConfig(QuantizationConfigMixin):
                     - Full function names: `uintx_weight_only`
                     - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
                 - An AOBaseConfig instance: for more advanced configuration options.
+                - `None`: when only `attention_backend` is used without weight quantization.
         modules_to_not_convert (`list[str]`, *optional*, default to `None`):
             The list of modules to not quantize, useful for quantizing models that explicitly require to have some
             modules left in their original precision.
+        attention_backend (`str`, *optional*, default to `None`):
+            Low-precision attention backend to use. Currently supported: `"fp8_fa3"` (FP8 attention using Flash
+            Attention 3, requires Hopper GPU with SM90+). This is orthogonal to weight quantization; you can use
+            either or both. When used with `torch.compile`, RoPE fusion is automatically enabled.
         kwargs (`dict[str, Any]`, *optional*):
             The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization
             supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and
@@ -495,18 +500,28 @@ class TorchAoConfig(QuantizationConfigMixin):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
+
+        # FP8 attention only (no weight quantization)
+        quantization_config = TorchAoConfig(attention_backend="fp8_fa3")
+
+        # Combined: weight quantization + FP8 attention
+        quantization_config = TorchAoConfig("int8wo", attention_backend="fp8_fa3")
         ```
     """

+    _SUPPORTED_ATTENTION_BACKENDS = {"fp8_fa3"}
+
     def __init__(
         self,
-        quant_type: str | "AOBaseConfig",  # noqa: F821
+        quant_type: str | "AOBaseConfig" | None = None,  # noqa: F821
         modules_to_not_convert: list[str] | None = None,
+        attention_backend: str | None = None,
         **kwargs,
     ) -> None:
         self.quant_method = QuantizationMethod.TORCHAO
         self.quant_type = quant_type
         self.modules_to_not_convert = modules_to_not_convert
+        self.attention_backend = attention_backend

         # When we load from serialized config, "quant_type_kwargs" will be the key
         if "quant_type_kwargs" in kwargs:
@@ -517,6 +532,21 @@ def __init__(
         self.post_init()

     def post_init(self):
+        if self.quant_type is None and self.attention_backend is None:
+            raise ValueError(
+                "At least one of `quant_type` or `attention_backend` must be provided."
+            )
+
+        if self.attention_backend is not None and self.attention_backend not in self._SUPPORTED_ATTENTION_BACKENDS:
+            raise ValueError(
+                f"Unsupported attention_backend: {self.attention_backend!r}. "
+                f"Supported backends: {self._SUPPORTED_ATTENTION_BACKENDS}"
+            )
+
+        # Skip quant_type validation when only attention_backend is used
+        if self.quant_type is None:
+            return
+
         if not isinstance(self.quant_type, str):
             if is_torchao_version("<=", "0.9.0"):
                 raise ValueError(
@@ -570,6 +600,12 @@ def to_dict(self):
         """Convert configuration to a dictionary."""
         d = super().to_dict()

+        if self.attention_backend is not None:
+            d["attention_backend"] = self.attention_backend
+
+        if self.quant_type is None:
+            return d
+
         if isinstance(self.quant_type, str):
             # Handle layout serialization if present
             if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
@@ -600,10 +636,11 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
         if not is_torchao_version(">", "0.9.0"):
             raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
         config_dict = config_dict.copy()
-        quant_type = config_dict.pop("quant_type")
+        quant_type = config_dict.pop("quant_type", None)
+        attention_backend = config_dict.pop("attention_backend", None)

-        if isinstance(quant_type, str):
-            return cls(quant_type=quant_type, **config_dict)
+        if quant_type is None or isinstance(quant_type, str):
+            return cls(quant_type=quant_type, attention_backend=attention_backend, **config_dict)
         # Check if we only have one key which is "default"
         # In the future we may update this
         assert len(quant_type) == 1 and "default" in quant_type, (
@@ -616,7 +653,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):

         quant_type = config_from_dict(quant_type)

-        return cls(quant_type=quant_type, **config_dict)
+        return cls(quant_type=quant_type, attention_backend=attention_backend, **config_dict)

     @classmethod
     def _get_torchao_quant_type_to_method(cls):
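As context for the config changes above, a minimal usage sketch. The Flux checkpoint and the `torch.compile` call are illustrative assumptions, not part of this commit; only the `attention_backend` argument is new here.

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# FP8 attention plus int8 weight-only quantization; pass attention_backend alone
# to skip weight quantization entirely.
quantization_config = TorchAoConfig("int8wo", attention_backend="fp8_fa3")

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# Optional: with torch.compile, the pre-grad pass set by the quantizer can fuse
# RoPE with the FP8 SDPA op.
transformer = torch.compile(transformer)
```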

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 116 additions & 3 deletions
@@ -21,6 +21,7 @@
 import re
 import types
 from fnmatch import fnmatch
+from functools import partial
 from typing import TYPE_CHECKING, Any

 from packaging import version
@@ -198,8 +199,57 @@ def validate_environment(self, *args, **kwargs):
                 f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
             )

+        attention_backend = getattr(self.quantization_config, "attention_backend", None)
+        if attention_backend is not None:
+            self._validate_attention_environment(attention_backend)
+
+    def _validate_attention_environment(self, attention_backend):
+        """Validate that the environment supports the requested attention backend."""
+        # Check torchao.prototype.attention is importable
+        try:
+            importlib.import_module("torchao.prototype.attention")
+        except (ImportError, ModuleNotFoundError):
+            raise ImportError(
+                f"attention_backend={attention_backend!r} requires `torchao.prototype.attention`. "
+                "Please install a version of torchao that includes the prototype attention module."
+            )
+
+        # Check PyTorch >= 2.11.0
+        torch_version_parsed = version.parse(importlib.metadata.version("torch"))
+        if torch_version_parsed < version.parse("2.11.0"):
+            raise RuntimeError(
+                f"attention_backend={attention_backend!r} requires PyTorch >= 2.11.0, "
+                f"but the current version is {torch_version_parsed}."
+            )
+
+        # Check CUDA available with SM90+ (Hopper)
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                f"attention_backend={attention_backend!r} requires CUDA."
+            )
+        major, minor = torch.cuda.get_device_capability()
+        if major < 9:
+            raise RuntimeError(
+                f"attention_backend={attention_backend!r} requires Hopper GPU (SM90+), "
+                f"but the current device has SM{major}{minor}."
+            )
+
+        # Check FA3 availability
+        try:
+            importlib.import_module("flash_attn_interface")
+        except (ImportError, ModuleNotFoundError):
+            raise ImportError(
+                f"attention_backend={attention_backend!r} requires the flash-attn package with FA3 support. "
+                "Please install flash-attn with FA3 support."
+            )
+
     def update_torch_dtype(self, torch_dtype):
         quant_type = self.quantization_config.quant_type
+        if quant_type is None:
+            if torch_dtype is None:
+                torch_dtype = torch.bfloat16
+            return torch_dtype
+
         if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
             if torch_dtype is not None and torch_dtype != torch.bfloat16:
                 logger.warning(
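A standalone pre-flight check mirroring the validation in the hunk above can be handy before enabling the backend. This is a sketch: the requirements are taken from this diff, but the helper itself is not part of the commit.

```python
import importlib
import importlib.metadata

import torch
from packaging import version


def can_use_fp8_fa3() -> tuple[bool, str]:
    """Return (ok, reason) for the same checks _validate_attention_environment performs."""
    try:
        importlib.import_module("torchao.prototype.attention")
    except (ImportError, ModuleNotFoundError):
        return False, "torchao.prototype.attention is not importable"
    if version.parse(importlib.metadata.version("torch")) < version.parse("2.11.0"):
        return False, "PyTorch >= 2.11.0 is required"
    if not torch.cuda.is_available():
        return False, "CUDA is not available"
    major, _ = torch.cuda.get_device_capability()
    if major < 9:
        return False, "a Hopper GPU (SM90+) is required"
    try:
        importlib.import_module("flash_attn_interface")
    except (ImportError, ModuleNotFoundError):
        return False, "flash-attn with FA3 support is not installed"
    return True, "ok"


print(can_use_fp8_fa3())
```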
@@ -220,6 +270,9 @@ def update_torch_dtype(self, torch_dtype):

     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
         quant_type = self.quantization_config.quant_type
+        if quant_type is None:
+            return target_dtype
+
         from accelerate.utils import CustomDtype

         if isinstance(quant_type, str):
@@ -283,6 +336,9 @@ def check_if_quantized_param(
         state_dict: dict[str, Any],
         **kwargs,
     ) -> bool:
+        if self.quantization_config.quant_type is None:
+            return False
+
         param_device = kwargs.pop("param_device", None)
         # Check if the param_name is not in self.modules_to_not_convert
         if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
@@ -337,6 +393,9 @@ def get_cuda_warm_up_factor(self):
         - Use a division factor of 8 for int4 weights
         - Use a division factor of 4 for int8 weights
         """
+        if self.quantization_config.quant_type is None:
+            return 4
+
         # Original mapping for non-AOBaseConfig types
         # For the uint types, this is a best guess. Once these types become more used
         # we can look into their nuances.
@@ -368,6 +427,13 @@ def _process_model_before_weight_loading(
         keep_in_fp32_modules: list[str] = [],
         **kwargs,
     ):
+        model.config.quantization_config = self.quantization_config
+
+        if self.quantization_config.quant_type is None:
+            # Attention-only mode: no weight quantization setup needed
+            self.modules_to_not_convert = []
+            return
+
         self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

         if not isinstance(self.modules_to_not_convert, list):
@@ -386,11 +452,53 @@
         # and tied modules are usually kept in FP32.
         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

-        model.config.quantization_config = self.quantization_config
-
     def _process_model_after_weight_loading(self, model: "ModelMixin"):
+        attention_backend = getattr(self.quantization_config, "attention_backend", None)
+        if attention_backend is not None:
+            self._apply_low_precision_attention(model, attention_backend)
         return model

+    def _apply_low_precision_attention(self, model, attention_backend):
+        """Apply low-precision attention via forward hooks.
+
+        Uses forward pre/post hooks to monkey-patch F.scaled_dot_product_attention with
+        the FP8 custom op during model forward, and sets the torch.compile pre-grad
+        fusion pass for RoPE fusion.
+        """
+        import torch._inductor.config as inductor_config
+        import torch.nn.functional as F
+        from torch.nn.attention import activate_flash_attention_impl, restore_flash_attention_impl
+
+        from torchao.prototype.attention.fp8_fa3.attention import _ops
+        from torchao.prototype.attention.shared_utils.fusion_utils import rope_sdpa_fusion_pass
+        from torchao.prototype.attention.shared_utils.wrapper import _make_causal_aware_sdpa
+
+        # Diffusion models don't use causal masks
+        sdpa_patch_fn = _make_causal_aware_sdpa(_ops.fp8_sdpa_op, strip_causal_mask=False)
+
+        # Set the torch.compile fusion pass for RoPE fusion
+        inductor_config.pre_grad_custom_pass = partial(
+            rope_sdpa_fusion_pass,
+            rope_sdpa_op=_ops.rope_sdpa_op,
+            fp8_sdpa_op=_ops.fp8_sdpa_op,
+            backend_name="FA3",
+        )
+
+        flash_impl_name = "FA3"
+
+        def _pre_hook(module, args, kwargs=None):
+            activate_flash_attention_impl(flash_impl_name)
+            module._original_sdpa = F.scaled_dot_product_attention
+            F.scaled_dot_product_attention = sdpa_patch_fn
+
+        def _post_hook(module, args, output, kwargs=None):
+            F.scaled_dot_product_attention = module._original_sdpa
+            del module._original_sdpa
+            restore_flash_attention_impl()
+
+        model.register_forward_pre_hook(_pre_hook, with_kwargs=True)
+        model.register_forward_hook(_post_hook, with_kwargs=True)
+
     def is_serializable(self, safe_serialization=None):
         # TODO(aryan): needs to be tested
         if safe_serialization:
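The hook mechanism in the hunk above is independent of the FP8 kernel itself. Below is a minimal, self-contained sketch of the same pattern with a stand-in op instead of torchao's FP8 SDPA; all names in it are illustrative and not part of this commit.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

_stock_sdpa = F.scaled_dot_product_attention  # keep a handle on the original op


def fake_low_precision_sdpa(query, key, value, *args, **kwargs):
    # Stand-in for torchao's FP8 SDPA op; here it just delegates to the stock kernel.
    return _stock_sdpa(query, key, value, *args, **kwargs)


def install_sdpa_patch(model: nn.Module, patched_sdpa) -> None:
    """Swap F.scaled_dot_product_attention only for the duration of model.forward."""

    def _pre_hook(module, args, kwargs):
        module._original_sdpa = F.scaled_dot_product_attention
        F.scaled_dot_product_attention = patched_sdpa

    def _post_hook(module, args, kwargs, output):
        F.scaled_dot_product_attention = module._original_sdpa
        del module._original_sdpa

    model.register_forward_pre_hook(_pre_hook, with_kwargs=True)
    model.register_forward_hook(_post_hook, with_kwargs=True)


class TinyAttention(nn.Module):
    def forward(self, q, k, v):
        # Resolved at call time, so the pre-hook's patch is what actually runs.
        return F.scaled_dot_product_attention(q, k, v)


attn = TinyAttention()
install_sdpa_patch(attn, fake_low_precision_sdpa)
q = k = v = torch.randn(1, 4, 16, 8)  # (batch, heads, seq, head_dim)
out = attn(q, k, v)
assert F.scaled_dot_product_attention is _stock_sdpa  # restored after forward
```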
@@ -417,7 +525,12 @@ def is_serializable(self, safe_serialization=None):

     @property
     def is_trainable(self):
-        return self.quantization_config.quant_type.startswith("int8")
+        quant_type = self.quantization_config.quant_type
+        if quant_type is None:
+            return False
+        if isinstance(quant_type, str):
+            return quant_type.startswith("int8")
+        return False

     @property
     def is_compileable(self) -> bool:
