Skip to content

Commit 1f6ac1c

Browse files
sym-bot, claude, and github-actions[bot]
authored
fix: graceful fallback when attention backends fail to import (#13060)
* fix: graceful fallback when attention backends fail to import

## Problem

External attention backends (flash_attn, xformers, sageattention, etc.) may be installed but fail to import at runtime due to ABI mismatches. For example, when `flash_attn` is compiled against PyTorch 2.4 but used with PyTorch 2.8, the import fails with:

```
OSError: .../flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEab
```

The current code uses `importlib.util.find_spec()` to check if packages exist, but this only verifies the package is installed — not that it can actually be imported. When the import fails, diffusers crashes instead of falling back to native PyTorch attention.

## Solution

Wrap all external attention backend imports in try-except blocks that catch `ImportError` and `OSError`. On failure:

1. Log a warning message explaining the issue
2. Set the corresponding `_CAN_USE_*` flag to `False`
3. Set the imported functions to `None`

This allows diffusers to gracefully degrade to PyTorch's native SDPA (scaled_dot_product_attention) instead of crashing.

## Affected backends

- flash_attn (Flash Attention)
- flash_attn_3 (Flash Attention 3)
- aiter (AMD Instinct)
- sageattention (SageAttention)
- flex_attention (PyTorch Flex Attention)
- torch_npu (Huawei NPU)
- torch_xla (TPU/XLA)
- xformers (Meta xFormers)

## Testing

Tested with PyTorch 2.8.0 and flash_attn 2.7.4.post1 (compiled for PyTorch 2.4). Before: crashes on import. After: logs warning and uses native attention.

* address review: use single logger and catch RuntimeError

- Move logger to module level instead of creating per-backend loggers
- Add RuntimeError to exception list alongside ImportError and OSError

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Apply style fixes

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 5e94d62 commit 1f6ac1c

File tree

1 file changed

+75
-22
lines changed

1 file changed

+75
-22
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 75 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,8 @@
6262
_REQUIRED_XLA_VERSION = "2.2"
6363
_REQUIRED_XFORMERS_VERSION = "0.0.29"
6464

65+
logger = get_logger(__name__) # pylint: disable=invalid-name
66+
6567
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
6668
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
6769
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
@@ -73,8 +75,18 @@
7375

7476

7577
if _CAN_USE_FLASH_ATTN:
76-
from flash_attn import flash_attn_func, flash_attn_varlen_func
77-
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
78+
try:
79+
from flash_attn import flash_attn_func, flash_attn_varlen_func
80+
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
81+
except (ImportError, OSError, RuntimeError) as e:
82+
# Handle ABI mismatch or other import failures gracefully.
83+
# This can happen when flash_attn was compiled against a different PyTorch version.
84+
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
85+
_CAN_USE_FLASH_ATTN = False
86+
flash_attn_func = None
87+
flash_attn_varlen_func = None
88+
_wrapped_flash_attn_backward = None
89+
_wrapped_flash_attn_forward = None
7890
else:
7991
flash_attn_func = None
8092
flash_attn_varlen_func = None
@@ -83,26 +95,47 @@
8395

8496

8597
if _CAN_USE_FLASH_ATTN_3:
86-
from flash_attn_interface import flash_attn_func as flash_attn_3_func
87-
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
98+
try:
99+
from flash_attn_interface import flash_attn_func as flash_attn_3_func
100+
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
101+
except (ImportError, OSError, RuntimeError) as e:
102+
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
103+
_CAN_USE_FLASH_ATTN_3 = False
104+
flash_attn_3_func = None
105+
flash_attn_3_varlen_func = None
88106
else:
89107
flash_attn_3_func = None
90108
flash_attn_3_varlen_func = None
91109

92110
if _CAN_USE_AITER_ATTN:
93-
from aiter import flash_attn_func as aiter_flash_attn_func
111+
try:
112+
from aiter import flash_attn_func as aiter_flash_attn_func
113+
except (ImportError, OSError, RuntimeError) as e:
114+
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
115+
_CAN_USE_AITER_ATTN = False
116+
aiter_flash_attn_func = None
94117
else:
95118
aiter_flash_attn_func = None
96119

97120
if _CAN_USE_SAGE_ATTN:
98-
from sageattention import (
99-
sageattn,
100-
sageattn_qk_int8_pv_fp8_cuda,
101-
sageattn_qk_int8_pv_fp8_cuda_sm90,
102-
sageattn_qk_int8_pv_fp16_cuda,
103-
sageattn_qk_int8_pv_fp16_triton,
104-
sageattn_varlen,
105-
)
121+
try:
122+
from sageattention import (
123+
sageattn,
124+
sageattn_qk_int8_pv_fp8_cuda,
125+
sageattn_qk_int8_pv_fp8_cuda_sm90,
126+
sageattn_qk_int8_pv_fp16_cuda,
127+
sageattn_qk_int8_pv_fp16_triton,
128+
sageattn_varlen,
129+
)
130+
except (ImportError, OSError, RuntimeError) as e:
131+
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
132+
_CAN_USE_SAGE_ATTN = False
133+
sageattn = None
134+
sageattn_qk_int8_pv_fp8_cuda = None
135+
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
136+
sageattn_qk_int8_pv_fp16_cuda = None
137+
sageattn_qk_int8_pv_fp16_triton = None
138+
sageattn_varlen = None
106139
else:
107140
sageattn = None
108141
sageattn_qk_int8_pv_fp16_cuda = None
@@ -113,26 +146,48 @@
113146

114147

115148
if _CAN_USE_FLEX_ATTN:
116-
# We cannot import the flex_attention function from the package directly because it is expected (from the
117-
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
118-
# compiled function.
119-
import torch.nn.attention.flex_attention as flex_attention
149+
try:
150+
# We cannot import the flex_attention function from the package directly because it is expected (from the
151+
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
152+
# compiled function.
153+
import torch.nn.attention.flex_attention as flex_attention
154+
except (ImportError, OSError, RuntimeError) as e:
155+
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
156+
_CAN_USE_FLEX_ATTN = False
157+
flex_attention = None
158+
else:
159+
flex_attention = None
120160

121161

122162
if _CAN_USE_NPU_ATTN:
123-
from torch_npu import npu_fusion_attention
163+
try:
164+
from torch_npu import npu_fusion_attention
165+
except (ImportError, OSError, RuntimeError) as e:
166+
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
167+
_CAN_USE_NPU_ATTN = False
168+
npu_fusion_attention = None
124169
else:
125170
npu_fusion_attention = None
126171

127172

128173
if _CAN_USE_XLA_ATTN:
129-
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
174+
try:
175+
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
176+
except (ImportError, OSError, RuntimeError) as e:
177+
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
178+
_CAN_USE_XLA_ATTN = False
179+
xla_flash_attention = None
130180
else:
131181
xla_flash_attention = None
132182

133183

134184
if _CAN_USE_XFORMERS_ATTN:
135-
import xformers.ops as xops
185+
try:
186+
import xformers.ops as xops
187+
except (ImportError, OSError, RuntimeError) as e:
188+
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
189+
_CAN_USE_XFORMERS_ATTN = False
190+
xops = None
136191
else:
137192
xops = None
138193

@@ -158,8 +213,6 @@ def wrap(func):
158213
_register_fake = register_fake_no_op
159214

160215

161-
logger = get_logger(__name__) # pylint: disable=invalid-name
162-
163216
# TODO(aryan): Add support for the following:
164217
# - Sage Attention++
165218
# - block sparse, radial and other attention methods

0 commit comments

Comments (0)