Skip to content

Commit 23251d6

Browse files
committed
Revert "gracefully error out when attn-backend x cp combo isn't supported."
This reverts commit c8abb5d.
1 parent c8abb5d commit 23251d6

2 files changed

Lines changed: 14 additions & 29 deletions

File tree

tests/others/test_attention_backends.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
pytest tests/others/test_attention_backends.py
1212
```
1313
14-
Tests were conducted on an H100 with PyTorch 2.9.1 (CUDA 12.9).
14+
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
15+
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
1516
1617
Tests for the aiter backend were conducted, and slices for the aiter backend tests were collected, on an MI355X
1718
with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
@@ -23,8 +24,6 @@
2324
import pytest
2425
import torch
2526

26-
from ..testing_utils import numpy_cosine_similarity_distance
27-
2827

2928
pytestmark = pytest.mark.skipif(
3029
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
@@ -37,61 +36,51 @@
3736
FORWARD_CASES = [
3837
(
3938
"flash_hub",
40-
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16),
41-
1e-4
39+
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
4240
),
4341
(
4442
"_flash_3_hub",
4543
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
46-
1e-4
4744
),
4845
(
4946
"native",
50-
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16),
51-
1e-4
52-
),
47+
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
48+
),
5349
(
5450
"_native_cudnn",
5551
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
56-
5e-4
5752
),
5853
(
5954
"aiter",
6055
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
61-
1e-4
6256
)
6357
]
6458

6559
COMPILE_CASES = [
6660
(
6761
"flash_hub",
6862
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
69-
True,
70-
1e-4
63+
True
7164
),
7265
(
7366
"_flash_3_hub",
7467
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
7568
True,
76-
1e-4
7769
),
7870
(
7971
"native",
8072
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
8173
True,
82-
1e-4
8374
),
8475
(
8576
"_native_cudnn",
8677
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
8778
True,
88-
5e-4,
8979
),
9080
(
9181
"aiter",
9282
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
9383
True,
94-
1e-4
9584
)
9685
]
9786
# fmt: on
@@ -115,11 +104,11 @@ def _backend_is_probably_supported(pipe, name: str):
115104
return False
116105

117106

118-
def _check_if_slices_match(output, expected_slice, expected_diff=1e-4):
107+
def _check_if_slices_match(output, expected_slice):
119108
img = output.images.detach().cpu()
120109
generated_slice = img.flatten()
121110
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
122-
assert numpy_cosine_similarity_distance(generated_slice, expected_slice) < expected_diff
111+
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
123112

124113

125114
@pytest.fixture(scope="session")
@@ -137,23 +126,23 @@ def pipe(device):
137126
return pipe
138127

139128

140-
@pytest.mark.parametrize("backend_name,expected_slice,expected_diff", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
141-
def test_forward(pipe, backend_name, expected_slice, expected_diff):
129+
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
130+
def test_forward(pipe, backend_name, expected_slice):
142131
out = _backend_is_probably_supported(pipe, backend_name)
143132
if isinstance(out, bool):
144133
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
145134

146135
modified_pipe = out[0]
147136
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
148-
_check_if_slices_match(out, expected_slice, expected_diff)
137+
_check_if_slices_match(out, expected_slice)
149138

150139

151140
@pytest.mark.parametrize(
152-
"backend_name,expected_slice,error_on_recompile,expected_diff",
141+
"backend_name,expected_slice,error_on_recompile",
153142
COMPILE_CASES,
154143
ids=[c[0] for c in COMPILE_CASES],
155144
)
156-
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile, expected_diff):
145+
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
157146
if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
158147
pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
159148

@@ -171,4 +160,4 @@ def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recom
171160
):
172161
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
173162

174-
_check_if_slices_match(out, expected_slice, expected_diff)
163+
_check_if_slices_match(out, expected_slice)

tests/testing_utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,6 @@ def torch_all_close(a, b, *args, **kwargs):
131131

132132

133133
def numpy_cosine_similarity_distance(a, b):
134-
if isinstance(a, torch.Tensor):
135-
a = a.detach().cpu().float().numpy()
136-
if isinstance(b, torch.Tensor):
137-
b = b.detach().cpu().float().numpy()
138134
similarity = np.dot(a, b) / (norm(a) * norm(b))
139135
distance = 1.0 - similarity.mean()
140136

0 commit comments

Comments
 (0)