Commit e17b1b8

[Fix] Fix FA2 kernels ut (huggingface#42803)
* Fixed FA2 kernels UT
* Update
* Refactor FA2 kernel map
* Update
* Update
1 parent 40dc11c commit e17b1b8

4 files changed: 42 additions & 17 deletions


src/transformers/modeling_utils.py

Lines changed: 15 additions & 9 deletions
@@ -155,6 +155,12 @@
 _is_quantized = False
 _is_ds_init_called = False
 
+# Mapping from flash attention implementations to their kernel fallback repositories
+FLASH_ATTN_KERNEL_FALLBACK = {
+    "flash_attention_2": "kernels-community/flash-attn2",
+    "flash_attention_3": "kernels-community/vllm-flash-attn3",
+}
+
 
 def is_local_dist_rank_0():
     return (
@@ -1592,7 +1598,9 @@ def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool:
             return True
 
         if is_torch_xpu_available():
-            logger.info("Detect using FlashAttention2 (via kernel `kernels-community/flash-attn2`) on XPU.")
+            logger.info(
+                f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU."
+            )
             return True
 
         if importlib.util.find_spec("flash_attn") is None:
@@ -1824,14 +1832,12 @@ def _check_and_adjust_attn_implementation(
             and is_kernels_available()
             and not is_torch_npu_available()
         ):
-            if attn_implementation.endswith("2"):
-                applicable_attn_implementation = "kernels-community/flash-attn2"
-                if is_torch_xpu_available():
-                    # On XPU, kernels library is the native implementation
-                    # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
-                    requested_original_flash_attn = False
-            else:
-                applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
+            applicable_attn_implementation = FLASH_ATTN_KERNEL_FALLBACK[attn_implementation.removeprefix("paged|")]
+
+            if is_torch_xpu_available() and attn_implementation.removeprefix("paged|") == "flash_attention_2":
+                # On XPU, kernels library is the native implementation
+                # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
+                requested_original_flash_attn = False
 
         if is_paged:
             applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
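
The refactor above replaces the hard-coded if/else on the implementation name with a single lookup into FLASH_ATTN_KERNEL_FALLBACK. A minimal, standalone sketch of the resulting resolution logic follows; the helper name `resolve_kernel_fallback` is illustrative only and not part of the diff.

# Minimal sketch, assuming only the mapping introduced in this commit.
FLASH_ATTN_KERNEL_FALLBACK = {
    "flash_attention_2": "kernels-community/flash-attn2",
    "flash_attention_3": "kernels-community/vllm-flash-attn3",
}


def resolve_kernel_fallback(attn_implementation: str) -> str:
    # Strip the optional "paged|" prefix before the lookup, then restore it,
    # mirroring how _check_and_adjust_attn_implementation re-adds the prefix.
    base = attn_implementation.removeprefix("paged|")
    resolved = FLASH_ATTN_KERNEL_FALLBACK[base]
    return f"paged|{resolved}" if attn_implementation.startswith("paged|") else resolved


assert resolve_kernel_fallback("flash_attention_2") == "kernels-community/flash-attn2"
assert resolve_kernel_fallback("paged|flash_attention_3") == "paged|kernels-community/vllm-flash-attn3"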

src/transformers/testing_utils.py

Lines changed: 2 additions & 2 deletions
@@ -221,7 +221,7 @@
     import torch
     from safetensors.torch import load_file
 
-    from .modeling_utils import PreTrainedModel
+    from .modeling_utils import FLASH_ATTN_KERNEL_FALLBACK, PreTrainedModel
 
     IS_ROCM_SYSTEM = torch.version.hip is not None
     IS_CUDA_SYSTEM = torch.version.cuda is not None
@@ -620,7 +620,7 @@ def require_flash_attn(test_case):
     try:
         from kernels import get_kernel
 
-        get_kernel("kernels-community/flash-attn2")
+        get_kernel(FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"])
     except Exception as _:
         kernels_available = False
 
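For context, a hedged sketch of the guard pattern `require_flash_attn` relies on after this change: probe for the native `flash_attn` package first, then try to fetch the fallback kernel through the `kernels` library. The helper names below are illustrative, not the actual decorator from testing_utils.

# Illustrative sketch only; assumes the `kernels` package exposes get_kernel(repo_id).
import importlib.util
import unittest

FLASH_ATTN_KERNEL_FALLBACK = {"flash_attention_2": "kernels-community/flash-attn2"}


def _flash_attn_usable() -> bool:
    if importlib.util.find_spec("flash_attn") is not None:
        return True
    try:
        from kernels import get_kernel

        # Fetch the fallback kernel; any failure means flash attention cannot be used.
        get_kernel(FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"])
        return True
    except Exception:
        return False


def require_flash_attn_sketch(test_case):
    # Skip tests when neither flash_attn nor its kernels fallback is available.
    return unittest.skipUnless(_flash_attn_usable(), "test requires flash attention")(test_case)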

tests/test_modeling_common.py

Lines changed: 17 additions & 3 deletions
@@ -52,7 +52,7 @@
     unset_hf_deepspeed_config,
 )
 from transformers.modeling_layers import GradientCheckpointingLayer
-from transformers.modeling_utils import _get_tied_weight_keys
+from transformers.modeling_utils import FLASH_ATTN_KERNEL_FALLBACK, _get_tied_weight_keys
 from transformers.models.auto import get_values
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
@@ -3243,6 +3243,20 @@ def flash_attn_can_dispatch_composite_models(self, attn_implementation: str):
             self.skipTest(f"bfloat16 not supported on {torch_device} (on the specific device currently used)")
 
         dtype = torch.bfloat16
+
+        def _expected_attn_implementations(attention_implementation: str) -> set[str]:
+            # Allow kernels fallbacks for flash attention tests.
+            requested = attention_implementation
+            base = requested.removeprefix("paged|")
+            prefix = "paged|" if requested.startswith("paged|") else ""
+
+            expected = {requested}
+            if base in FLASH_ATTN_KERNEL_FALLBACK:
+                expected.add(f"{prefix}{FLASH_ATTN_KERNEL_FALLBACK[base]}")
+            return expected
+
+        expected_attn_implementations = _expected_attn_implementations(attn_implementation)
+
         for model_class in self.all_model_classes:
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
@@ -3275,15 +3289,15 @@ def flash_attn_can_dispatch_composite_models(self, attn_implementation: str):
             for key in model_fa.config:
                 if isinstance(getattr(model_fa.config, key), PreTrainedConfig):
                     sub_config = getattr(model_fa.config, key)
-                    self.assertTrue(sub_config._attn_implementation == attn_implementation)
+                    self.assertIn(sub_config._attn_implementation, expected_attn_implementations)
 
             has_fa = False
             for name, submodule in model_fa.named_modules():
                 class_name = submodule.__class__.__name__
                 if (
                     "Attention" in class_name
                     and getattr(submodule, "config", None)
-                    and submodule.config._attn_implementation == attn_implementation
+                    and submodule.config._attn_implementation in expected_attn_implementations
                 ):
                     has_fa = True
                     break
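
To make the intent of `_expected_attn_implementations` concrete, here is a standalone version with the sets it produces; the assertions assume the mapping added by this commit.

# Standalone sketch of the helper added in the hunk above.
FLASH_ATTN_KERNEL_FALLBACK = {
    "flash_attention_2": "kernels-community/flash-attn2",
    "flash_attention_3": "kernels-community/vllm-flash-attn3",
}


def expected_attn_implementations(requested: str) -> set[str]:
    base = requested.removeprefix("paged|")
    prefix = "paged|" if requested.startswith("paged|") else ""
    expected = {requested}
    if base in FLASH_ATTN_KERNEL_FALLBACK:
        expected.add(f"{prefix}{FLASH_ATTN_KERNEL_FALLBACK[base]}")
    return expected


assert expected_attn_implementations("flash_attention_2") == {"flash_attention_2", "kernels-community/flash-attn2"}
assert expected_attn_implementations("paged|flash_attention_3") == {
    "paged|flash_attention_3",
    "paged|kernels-community/vllm-flash-attn3",
}
assert expected_attn_implementations("sdpa") == {"sdpa"}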

tests/utils/test_modeling_utils.py

Lines changed: 8 additions & 3 deletions
@@ -129,7 +129,11 @@
     _prepare_4d_attention_mask,
     _prepare_4d_causal_attention_mask,
 )
-from transformers.modeling_utils import _find_disjoint, _find_identical
+from transformers.modeling_utils import (
+    FLASH_ATTN_KERNEL_FALLBACK,
+    _find_disjoint,
+    _find_identical,
+)
 from transformers.pytorch_utils import isin_mps_friendly
 
 # Fake pretrained models for tests
@@ -3028,7 +3032,7 @@ def test_kernels_fallback(self):
         )
 
         self.assertTrue(
-            "You do not have `flash_attn` installed, using `kernels-community/flash-attn2` from the `kernels` library instead!"
+            f"You do not have `flash_attn` installed, using `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}` from the `kernels` library instead!"
            in cl.out
        )
 
@@ -3040,7 +3044,8 @@ def test_not_available_kernels(self):
 
         with self.assertRaises(ImportError) as cm:
             _ = AutoModel.from_pretrained(
-                "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="kernels-community/flash-attn2"
+                "hf-tiny-model-private/tiny-random-MCTCTModel",
+                attn_implementation=FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"],
             )
 
         self.assertTrue("`kernels` is either not installed or uses an incompatible version." in str(cm.exception))
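
A short sketch of the test pattern these last two hunks switch to: the expected strings are derived from the shared mapping rather than hard-coded, so the assertions track whatever repository the fallback points at. `expected_fallback_message` is an illustrative name, not something added by the diff.

# Illustrative only; assumes the mapping introduced in this commit.
FLASH_ATTN_KERNEL_FALLBACK = {"flash_attention_2": "kernels-community/flash-attn2"}


def expected_fallback_message(implementation: str = "flash_attention_2") -> str:
    # Rebuild the warning text from the shared constant instead of hard-coding the repo id.
    repo = FLASH_ATTN_KERNEL_FALLBACK[implementation]
    return f"You do not have `flash_attn` installed, using `{repo}` from the `kernels` library instead!"


assert "kernels-community/flash-attn2" in expected_fallback_message()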
