Skip to content

Commit cc5eaf0

Browse files
committed
enhance attention backend tests
1 parent 17c0e79 commit cc5eaf0

2 files changed

Lines changed: 29 additions & 14 deletions

File tree

tests/others/test_attention_backends.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
pytest tests/others/test_attention_backends.py
1212
```
1313
14-
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
15-
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
14+
Tests were conducted on an H100 with PyTorch 2.9.1 (CUDA 12.9).
1615
1716
Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
1817
with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
@@ -24,6 +23,8 @@
2423
import pytest
2524
import torch
2625

26+
from ..testing_utils import numpy_cosine_similarity_distance
27+
2728

2829
pytestmark = pytest.mark.skipif(
2930
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
@@ -36,51 +37,61 @@
3637
FORWARD_CASES = [
3738
(
3839
"flash_hub",
39-
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
40+
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16),
41+
1e-4
4042
),
4143
(
4244
"_flash_3_hub",
4345
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
46+
1e-4
4447
),
4548
(
4649
"native",
47-
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
48-
),
50+
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16),
51+
1e-4
52+
),
4953
(
5054
"_native_cudnn",
5155
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
56+
5e-4
5257
),
5358
(
5459
"aiter",
5560
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
61+
1e-4
5662
)
5763
]
5864

5965
COMPILE_CASES = [
6066
(
6167
"flash_hub",
6268
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
63-
True
69+
True,
70+
1e-4
6471
),
6572
(
6673
"_flash_3_hub",
6774
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
6875
True,
76+
1e-4
6977
),
7078
(
7179
"native",
7280
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
7381
True,
82+
1e-4
7483
),
7584
(
7685
"_native_cudnn",
7786
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
7887
True,
88+
5e-4,
7989
),
8090
(
8191
"aiter",
8292
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
8393
True,
94+
1e-4
8495
)
8596
]
8697
# fmt: on
@@ -104,11 +115,11 @@ def _backend_is_probably_supported(pipe, name: str):
104115
return False
105116

106117

107-
def _check_if_slices_match(output, expected_slice):
118+
def _check_if_slices_match(output, expected_slice, expected_diff=1e-4):
108119
img = output.images.detach().cpu()
109120
generated_slice = img.flatten()
110121
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
111-
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
122+
assert numpy_cosine_similarity_distance(generated_slice, expected_slice) < expected_diff
112123

113124

114125
@pytest.fixture(scope="session")
@@ -126,23 +137,23 @@ def pipe(device):
126137
return pipe
127138

128139

129-
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
130-
def test_forward(pipe, backend_name, expected_slice):
140+
@pytest.mark.parametrize("backend_name,expected_slice,expected_diff", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
141+
def test_forward(pipe, backend_name, expected_slice, expected_diff):
131142
out = _backend_is_probably_supported(pipe, backend_name)
132143
if isinstance(out, bool):
133144
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
134145

135146
modified_pipe = out[0]
136147
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
137-
_check_if_slices_match(out, expected_slice)
148+
_check_if_slices_match(out, expected_slice, expected_diff)
138149

139150

140151
@pytest.mark.parametrize(
141-
"backend_name,expected_slice,error_on_recompile",
152+
"backend_name,expected_slice,error_on_recompile,expected_diff",
142153
COMPILE_CASES,
143154
ids=[c[0] for c in COMPILE_CASES],
144155
)
145-
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
156+
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile, expected_diff):
146157
if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
147158
pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
148159

@@ -160,4 +171,4 @@ def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recom
160171
):
161172
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
162173

163-
_check_if_slices_match(out, expected_slice)
174+
_check_if_slices_match(out, expected_slice, expected_diff)

tests/testing_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ def torch_all_close(a, b, *args, **kwargs):
131131

132132

133133
def numpy_cosine_similarity_distance(a, b):
134+
if isinstance(a, torch.Tensor):
135+
a = a.detach().cpu().float().numpy()
136+
if isinstance(b, torch.Tensor):
137+
b = b.detach().cpu().float().numpy()
134138
similarity = np.dot(a, b) / (norm(a) * norm(b))
135139
distance = 1.0 - similarity.mean()
136140

0 commit comments

Comments (0)