import re
import types
from fnmatch import fnmatch
+from functools import partial
from typing import TYPE_CHECKING, Any

from packaging import version
@@ -198,8 +199,57 @@ def validate_environment(self, *args, **kwargs):
198199 f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is { torch_version } ."
199200 )
200201
+        attention_backend = getattr(self.quantization_config, "attention_backend", None)
+        if attention_backend is not None:
+            self._validate_attention_environment(attention_backend)
+
+    def _validate_attention_environment(self, attention_backend):
+        """Validate that the environment supports the requested attention backend."""
+        # Check that torchao.prototype.attention is importable
+        try:
+            importlib.import_module("torchao.prototype.attention")
+        except (ImportError, ModuleNotFoundError):
+            raise ImportError(
+                f"attention_backend={attention_backend!r} requires `torchao.prototype.attention`. "
+                "Please install a version of torchao that includes the prototype attention module."
+            )
+
+        # Check PyTorch >= 2.11.0
+        torch_version_parsed = version.parse(importlib.metadata.version("torch"))
+        if torch_version_parsed < version.parse("2.11.0"):
+            raise RuntimeError(
+                f"attention_backend={attention_backend!r} requires PyTorch >= 2.11.0, "
+                f"but the current version is {torch_version_parsed}."
+            )
+
+        # Check that CUDA is available on an SM90+ (Hopper) device
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                f"attention_backend={attention_backend!r} requires CUDA."
+            )
+        major, minor = torch.cuda.get_device_capability()
+        if major < 9:
+            raise RuntimeError(
+                f"attention_backend={attention_backend!r} requires a Hopper GPU (SM90+), "
+                f"but the current device has SM{major}{minor}."
+            )
+
+        # Check FA3 (FlashAttention-3) availability
+        try:
+            importlib.import_module("flash_attn_interface")
+        except (ImportError, ModuleNotFoundError):
+            raise ImportError(
+                f"attention_backend={attention_backend!r} requires the flash-attn package with FA3 support. "
+                "Please install flash-attn with FA3 support."
+            )
+
    def update_torch_dtype(self, torch_dtype):
        quant_type = self.quantization_config.quant_type
+        if quant_type is None:
+            if torch_dtype is None:
+                torch_dtype = torch.bfloat16
+            return torch_dtype
+
        if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
            if torch_dtype is not None and torch_dtype != torch.bfloat16:
                logger.warning(
@@ -220,6 +270,9 @@ def update_torch_dtype(self, torch_dtype):

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        quant_type = self.quantization_config.quant_type
+        if quant_type is None:
+            return target_dtype
+
        from accelerate.utils import CustomDtype

        if isinstance(quant_type, str):
@@ -283,6 +336,9 @@ def check_if_quantized_param(
        state_dict: dict[str, Any],
        **kwargs,
    ) -> bool:
+        if self.quantization_config.quant_type is None:
+            return False
+
        param_device = kwargs.pop("param_device", None)
        # Check if the param_name is not in self.modules_to_not_convert
        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
@@ -337,6 +393,9 @@ def get_cuda_warm_up_factor(self):
        - Use a division factor of 8 for int4 weights
        - Use a division factor of 4 for int8 weights
        """
+        if self.quantization_config.quant_type is None:
+            return 4
+
        # Original mapping for non-AOBaseConfig types
        # For the uint types, this is a best guess. Once these types become more used
        # we can look into their nuances.
@@ -368,6 +427,13 @@ def _process_model_before_weight_loading(
        keep_in_fp32_modules: list[str] = [],
        **kwargs,
    ):
+        model.config.quantization_config = self.quantization_config
+
+        if self.quantization_config.quant_type is None:
+            # Attention-only mode: no weight quantization setup needed
+            self.modules_to_not_convert = []
+            return
+
        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if not isinstance(self.modules_to_not_convert, list):
@@ -386,11 +452,53 @@ def _process_model_before_weight_loading(
        # and tied modules are usually kept in FP32.
        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

-        model.config.quantization_config = self.quantization_config
-
    def _process_model_after_weight_loading(self, model: "ModelMixin"):
+        attention_backend = getattr(self.quantization_config, "attention_backend", None)
+        if attention_backend is not None:
+            self._apply_low_precision_attention(model, attention_backend)
        return model

+    def _apply_low_precision_attention(self, model, attention_backend):
+        """Apply low-precision attention via forward hooks.
+
+        Uses forward pre/post hooks to monkey-patch F.scaled_dot_product_attention with
+        the FP8 custom op during model forward, and sets the torch.compile pre-grad
+        fusion pass for RoPE fusion.
+        """
+        import torch._inductor.config as inductor_config
+        import torch.nn.functional as F
+        from torch.nn.attention import activate_flash_attention_impl, restore_flash_attention_impl
+
+        from torchao.prototype.attention.fp8_fa3.attention import _ops
+        from torchao.prototype.attention.shared_utils.fusion_utils import rope_sdpa_fusion_pass
+        from torchao.prototype.attention.shared_utils.wrapper import _make_causal_aware_sdpa
+
+        # Diffusion models don't use causal masks
+        sdpa_patch_fn = _make_causal_aware_sdpa(_ops.fp8_sdpa_op, strip_causal_mask=False)
+
+        # Set the torch.compile fusion pass for RoPE fusion
+        inductor_config.pre_grad_custom_pass = partial(
+            rope_sdpa_fusion_pass,
+            rope_sdpa_op=_ops.rope_sdpa_op,
+            fp8_sdpa_op=_ops.fp8_sdpa_op,
+            backend_name="FA3",
+        )
+
+        flash_impl_name = "FA3"
+
+        def _pre_hook(module, args, kwargs=None):
+            activate_flash_attention_impl(flash_impl_name)
+            module._original_sdpa = F.scaled_dot_product_attention
+            F.scaled_dot_product_attention = sdpa_patch_fn
+
+        # With with_kwargs=True, forward hooks are called as hook(module, args, kwargs, output)
+        def _post_hook(module, args, kwargs, output):
+            F.scaled_dot_product_attention = module._original_sdpa
+            del module._original_sdpa
+            restore_flash_attention_impl()
+
+        model.register_forward_pre_hook(_pre_hook, with_kwargs=True)
+        model.register_forward_hook(_post_hook, with_kwargs=True)
+
    def is_serializable(self, safe_serialization=None):
        # TODO(aryan): needs to be tested
        if safe_serialization:
@@ -417,7 +525,12 @@ def is_serializable(self, safe_serialization=None):

    @property
    def is_trainable(self):
-        return self.quantization_config.quant_type.startswith("int8")
+        quant_type = self.quantization_config.quant_type
+        if quant_type is None:
+            return False
+        if isinstance(quant_type, str):
+            return quant_type.startswith("int8")
+        return False

    @property
    def is_compileable(self) -> bool:
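
A rough usage sketch of the attention-only path added above, not part of the diff. It assumes the companion `TorchAoConfig` change (not shown in this file) accepts `quant_type=None` together with an `attention_backend` argument; the `"fp8_fa3"` backend string and the Flux checkpoint are placeholders.

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Hypothetical config: quant_type=None skips weight quantization entirely and
# attention_backend opts into the FP8 FA3 attention path. Both the None value
# and the "fp8_fa3" string are assumptions about the companion config change.
quant_config = TorchAoConfig(quant_type=None, attention_backend="fp8_fa3")

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# The forward pre/post hooks patch F.scaled_dot_product_attention in eager mode;
# the RoPE/SDPA fusion registered via inductor's pre_grad_custom_pass only takes
# effect under torch.compile.
transformer = torch.compile(transformer)
```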