 import importlib
 import re
 import types
+from functools import partial
 from typing import TYPE_CHECKING, Any

 from packaging import version
@@ -188,8 +189,58 @@ def validate_environment(self, *args, **kwargs):
188189 f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is { torch_version } ."
189190 )
190191
+        attention_backend = getattr(self.quantization_config, "attention_backend", None)
+        if attention_backend is not None:
+            self._validate_attention_environment(attention_backend)
+
+    def _validate_attention_environment(self, attention_backend):
+        """Validate that the environment supports the requested attention backend."""
+        # Check that torchao.prototype.attention is importable
+        try:
+            importlib.import_module("torchao.prototype.attention")
+        except (ImportError, ModuleNotFoundError):
+            raise ImportError(
+                f"attention_backend={attention_backend!r} requires `torchao.prototype.attention`. "
+                "Please install a version of torchao that includes the prototype attention module."
+            )
+
+        # Check PyTorch >= 2.11.0
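+        # base_version strips pre/dev/local suffixes, e.g. "2.11.0.dev0+cu128" -> "2.11.0"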
+        torch_version_parsed = version.parse(version.parse(importlib.metadata.version("torch")).base_version)
+        if torch_version_parsed < version.parse("2.11.0"):
+            raise RuntimeError(
+                f"attention_backend={attention_backend!r} requires PyTorch >= 2.11.0, "
+                f"but the current version is {torch_version_parsed}."
+            )
+
+        # Check that CUDA is available on an SM90+ (Hopper) device
+        if not torch.cuda.is_available():
+            raise RuntimeError(f"attention_backend={attention_backend!r} requires CUDA.")
+        major, minor = torch.cuda.get_device_capability()
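+        # get_device_capability() returns the (major, minor) compute capability, e.g. (9, 0) on H100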
+        if major < 9:
+            raise RuntimeError(
+                f"attention_backend={attention_backend!r} requires a Hopper GPU (SM90+), "
+                f"but the current device has SM{major}{minor}."
+            )
+
+        # Check FA3 availability
+        try:
+            importlib.import_module("flash_attn_interface")
+        except (ImportError, ModuleNotFoundError):
+            raise ImportError(
+                f"attention_backend={attention_backend!r} requires FlashAttention-3 "
+                "(the `flash_attn_interface` module). Please install flash-attn built with FA3 support."
+            )
+
     def update_torch_dtype(self, torch_dtype):
-        config_name = self.quantization_config.quant_type.__class__.__name__
+        quant_type = self.quantization_config.quant_type
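+        # quant_type=None selects attention-only mode: weights are left unquantized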
+        if quant_type is None:
+            if torch_dtype is None:
+                torch_dtype = torch.bfloat16
+            return torch_dtype
+
+        config_name = quant_type.__class__.__name__
         is_int_quant = config_name.startswith("Int") or config_name.startswith("Uint")
         if is_int_quant and torch_dtype is not None and torch_dtype != torch.bfloat16:
             logger.warning(
@@ -209,6 +260,10 @@ def update_torch_dtype(self, torch_dtype):
         return torch_dtype

     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+        quant_type = self.quantization_config.quant_type
+        if quant_type is None:
+            return target_dtype
+
         from accelerate.utils import CustomDtype

         quant_type = self.quantization_config.quant_type
@@ -244,6 +299,9 @@ def check_if_quantized_param(
         state_dict: dict[str, Any],
         **kwargs,
     ) -> bool:
+        if self.quantization_config.quant_type is None:
+            return False
+
         param_device = kwargs.pop("param_device", None)
         # Check if the param_name is not in self.modules_to_not_convert
         if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
@@ -298,6 +356,9 @@ def get_cuda_warm_up_factor(self):
         - Use a division factor of 8 for int4 weights
         - Use a division factor of 4 for int8 weights
         """
+        if self.quantization_config.quant_type is None:
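+            # Attention-only mode: no quantized weights, so fall back to the int8 factor of 4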
+            return 4
+
         quant_type = self.quantization_config.quant_type
         config_name = quant_type.__class__.__name__
         size_digit = fuzzy_match_size(config_name)
@@ -314,6 +375,13 @@ def _process_model_before_weight_loading(
         keep_in_fp32_modules: list[str] = [],
         **kwargs,
     ):
+        model.config.quantization_config = self.quantization_config
+
+        if self.quantization_config.quant_type is None:
+            # Attention-only mode: no weight quantization setup needed
+            self.modules_to_not_convert = []
+            return
+
         self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

         if not isinstance(self.modules_to_not_convert, list):
@@ -332,11 +400,56 @@ def _process_model_before_weight_loading(
         # and tied modules are usually kept in FP32.
         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

-        model.config.quantization_config = self.quantization_config
-
     def _process_model_after_weight_loading(self, model: "ModelMixin"):
+        attention_backend = getattr(self.quantization_config, "attention_backend", None)
+        if attention_backend is not None:
+            self._apply_low_precision_attention(model, attention_backend)
         return model

+    def _apply_low_precision_attention(self, model, attention_backend):
+        """Apply low-precision attention by monkey-patching the model's forward.
+
+        Replaces the model's forward method with a wrapper that, on every call,
+        activates the FA3 backend and swaps F.scaled_dot_product_attention for
+        the FP8 custom op.
+
+        Also registers the torch.compile pre-grad custom pass that fuses RoPE
+        into SDPA.
+        """
+        import torch._inductor.config as inductor_config
+        import torch.nn.functional as F
+        from torch.nn.attention import activate_flash_attention_impl, restore_flash_attention_impl
+
+        from torchao.prototype.attention.fp8_fa3.attention import _ops
+        from torchao.prototype.attention.shared_utils.fusion_utils import rope_sdpa_fusion_pass
+        from torchao.prototype.attention.shared_utils.wrapper import _make_causal_aware_sdpa
+
+        # Diffusion models don't use causal masks, so there is no mask to strip
+        sdpa_patch_fn = _make_causal_aware_sdpa(_ops.fp8_sdpa_op, strip_causal_mask=False)
+
+        # Register the torch.compile pre-grad custom pass that fuses RoPE into SDPA
+        inductor_config.pre_grad_custom_pass = partial(
+            rope_sdpa_fusion_pass,
+            rope_sdpa_op=_ops.rope_sdpa_op,
+            fp8_sdpa_op=_ops.fp8_sdpa_op,
+            backend_name="FA3",
+        )
+
+        original_forward = model.forward
+
+        def _fp8_attention_forward(*args, **kwargs):
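+            # The two patches are undone in reverse order: the inner finally restores
+            # the original SDPA even if forward raises; the outer finally restores the
+            # flash-attention impl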
+            activate_flash_attention_impl("FA3")
+            try:
+                original_sdpa = F.scaled_dot_product_attention
+                F.scaled_dot_product_attention = sdpa_patch_fn
+                try:
+                    return original_forward(*args, **kwargs)
+                finally:
+                    F.scaled_dot_product_attention = original_sdpa
+            finally:
+                restore_flash_attention_impl()
+
+        model.forward = _fp8_attention_forward
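+        # Hypothetical usage sketch (the model class, checkpoint id, and the exact
+        # `attention_backend` value are assumptions, not defined by this diff):
+        # with quant_type=None, weights stay in bf16 and only attention runs
+        # through the FP8 FA3 op.
+        #
+        #     quant_config = TorchAoConfig(quant_type=None, attention_backend="fp8_fa3")
+        #     transformer = SomeTransformer2DModel.from_pretrained(
+        #         "org/checkpoint", quantization_config=quant_config, torch_dtype=torch.bfloat16
+        #     )
+        #     out = transformer(hidden_states)  # SDPA is swapped for the duration of each call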
+
     def is_serializable(self, safe_serialization=None):
         # TODO(aryan): needs to be tested
         if safe_serialization:
@@ -371,7 +484,10 @@ def is_serializable(self, safe_serialization=None):

     @property
     def is_trainable(self):
-        return self.quantization_config.quant_type.__class__.__name__ in self._TRAINABLE_QUANTIZATION_CONFIGS
+        quant_type = self.quantization_config.quant_type
+        if quant_type is None:
+            return False
+        return quant_type.__class__.__name__ in self._TRAINABLE_QUANTIZATION_CONFIGS

     @property
     def is_compileable(self) -> bool: