1111pytest tests/others/test_attention_backends.py
1212```
1313
14- Tests were conducted on an H100 with PyTorch 2.9.1 (CUDA 12.9).
14+ Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
15+ "native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
1516
1617Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
1718with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
2324import pytest
2425import torch
2526
26- from ..testing_utils import numpy_cosine_similarity_distance
27-
2827
2928pytestmark = pytest .mark .skipif (
3029 os .getenv ("RUN_ATTENTION_BACKEND_TESTS" , "false" ) == "false" , reason = "Feature not mature enough."
3736FORWARD_CASES = [
3837 (
3938 "flash_hub" ,
40- torch .tensor ([0.0820 , 0.0859 , 0.0918 , 0.1016 , 0.0957 , 0.0996 , 0.0996 , 0.1016 , 0.2188 , 0.2266 , 0.2363 , 0.2500 , 0.2539 , 0.2461 , 0.2422 , 0.2695 ], dtype = torch .bfloat16 ),
41- 1e-4
39+ torch .tensor ([0.0820 , 0.0859 , 0.0918 , 0.1016 , 0.0957 , 0.0996 , 0.0996 , 0.1016 , 0.2188 , 0.2266 , 0.2363 , 0.2500 , 0.2539 , 0.2461 , 0.2422 , 0.2695 ], dtype = torch .bfloat16 )
4240 ),
4341 (
4442 "_flash_3_hub" ,
4543 torch .tensor ([0.0820 , 0.0859 , 0.0938 , 0.1016 , 0.0977 , 0.0996 , 0.1016 , 0.1016 , 0.2188 , 0.2246 , 0.2344 , 0.2480 , 0.2539 , 0.2480 , 0.2441 , 0.2715 ], dtype = torch .bfloat16 ),
46- 1e-4
4744 ),
4845 (
4946 "native" ,
50- torch .tensor ([0.0820 , 0.0859 , 0.0938 , 0.1016 , 0.0957 , 0.0996 , 0.0996 , 0.1016 , 0.2188 , 0.2266 , 0.2363 , 0.2500 , 0.2539 , 0.2480 , 0.2461 , 0.2734 ], dtype = torch .bfloat16 ),
51- 1e-4
52- ),
47+ torch .tensor ([0.0820 , 0.0859 , 0.0938 , 0.1016 , 0.0957 , 0.0996 , 0.0996 , 0.1016 , 0.2188 , 0.2266 , 0.2363 , 0.2500 , 0.2539 , 0.2480 , 0.2461 , 0.2734 ], dtype = torch .bfloat16 )
48+ ),
5349 (
5450 "_native_cudnn" ,
5551 torch .tensor ([0.0781 , 0.0840 , 0.0879 , 0.0957 , 0.0898 , 0.0957 , 0.0957 , 0.0977 , 0.2168 , 0.2246 , 0.2324 , 0.2500 , 0.2539 , 0.2480 , 0.2441 , 0.2695 ], dtype = torch .bfloat16 ),
56- 5e-4
5752 ),
5853 (
5954 "aiter" ,
6055 torch .tensor ([0.0781 , 0.0820 , 0.0879 , 0.0957 , 0.0898 , 0.0938 , 0.0957 , 0.0957 , 0.2285 , 0.2363 , 0.2461 , 0.2637 , 0.2695 , 0.2617 , 0.2617 , 0.2891 ], dtype = torch .bfloat16 ),
61- 1e-4
6256 )
6357]
6458
6559COMPILE_CASES = [
6660 (
6761 "flash_hub" ,
6862 torch .tensor ([0.0410 , 0.0410 , 0.0449 , 0.0508 , 0.0488 , 0.0586 , 0.0605 , 0.0586 , 0.2324 , 0.2422 , 0.2539 , 0.2734 , 0.2832 , 0.2812 , 0.2773 , 0.3047 ], dtype = torch .bfloat16 ),
69- True ,
70- 1e-4
63+ True
7164 ),
7265 (
7366 "_flash_3_hub" ,
7467 torch .tensor ([0.0410 , 0.0410 , 0.0449 , 0.0508 , 0.0508 , 0.0605 , 0.0625 , 0.0605 , 0.2344 , 0.2461 , 0.2578 , 0.2734 , 0.2852 , 0.2812 , 0.2773 , 0.3047 ], dtype = torch .bfloat16 ),
7568 True ,
76- 1e-4
7769 ),
7870 (
7971 "native" ,
8072 torch .tensor ([0.0410 , 0.0410 , 0.0449 , 0.0508 , 0.0508 , 0.0605 , 0.0605 , 0.0605 , 0.2344 , 0.2461 , 0.2578 , 0.2773 , 0.2871 , 0.2832 , 0.2773 , 0.3066 ], dtype = torch .bfloat16 ),
8173 True ,
82- 1e-4
8374 ),
8475 (
8576 "_native_cudnn" ,
8677 torch .tensor ([0.0410 , 0.0410 , 0.0430 , 0.0508 , 0.0488 , 0.0586 , 0.0605 , 0.0586 , 0.2344 , 0.2461 , 0.2578 , 0.2773 , 0.2871 , 0.2832 , 0.2793 , 0.3086 ], dtype = torch .bfloat16 ),
8778 True ,
88- 5e-4 ,
8979 ),
9080 (
9181 "aiter" ,
9282 torch .tensor ([0.0391 , 0.0391 , 0.0430 , 0.0488 , 0.0469 , 0.0566 , 0.0586 , 0.0566 , 0.2402 , 0.2539 , 0.2637 , 0.2812 , 0.2930 , 0.2910 , 0.2891 , 0.3164 ], dtype = torch .bfloat16 ),
9383 True ,
94- 1e-4
9584 )
9685]
9786# fmt: on
@@ -115,11 +104,11 @@ def _backend_is_probably_supported(pipe, name: str):
115104 return False
116105
117106
118- def _check_if_slices_match (output , expected_slice , expected_diff = 1e-4 ):
107+ def _check_if_slices_match (output , expected_slice ):
119108 img = output .images .detach ().cpu ()
120109 generated_slice = img .flatten ()
121110 generated_slice = torch .cat ([generated_slice [:8 ], generated_slice [- 8 :]])
122- assert numpy_cosine_similarity_distance (generated_slice , expected_slice ) < expected_diff
111+ assert torch . allclose (generated_slice , expected_slice , atol = 1e-4 )
123112
124113
125114@pytest .fixture (scope = "session" )
@@ -137,23 +126,23 @@ def pipe(device):
137126 return pipe
138127
139128
140- @pytest .mark .parametrize ("backend_name,expected_slice,expected_diff " , FORWARD_CASES , ids = [c [0 ] for c in FORWARD_CASES ])
141- def test_forward (pipe , backend_name , expected_slice , expected_diff ):
129+ @pytest .mark .parametrize ("backend_name,expected_slice" , FORWARD_CASES , ids = [c [0 ] for c in FORWARD_CASES ])
130+ def test_forward (pipe , backend_name , expected_slice ):
142131 out = _backend_is_probably_supported (pipe , backend_name )
143132 if isinstance (out , bool ):
144133 pytest .xfail (f"Backend '{ backend_name } ' not supported in this environment." )
145134
146135 modified_pipe = out [0 ]
147136 out = modified_pipe (** INFER_KW , generator = torch .manual_seed (0 ))
148- _check_if_slices_match (out , expected_slice , expected_diff )
137+ _check_if_slices_match (out , expected_slice )
149138
150139
151140@pytest .mark .parametrize (
152- "backend_name,expected_slice,error_on_recompile,expected_diff " ,
141+ "backend_name,expected_slice,error_on_recompile" ,
153142 COMPILE_CASES ,
154143 ids = [c [0 ] for c in COMPILE_CASES ],
155144)
156- def test_forward_with_compile (pipe , backend_name , expected_slice , error_on_recompile , expected_diff ):
145+ def test_forward_with_compile (pipe , backend_name , expected_slice , error_on_recompile ):
157146 if "native" in backend_name and error_on_recompile and not is_torch_version (">=" , "2.9.0" ):
158147 pytest .xfail (f"Test with { backend_name = } is compatible with a higher version of torch." )
159148
@@ -171,4 +160,4 @@ def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recom
171160 ):
172161 out = modified_pipe (** INFER_KW , generator = torch .manual_seed (0 ))
173162
174- _check_if_slices_match (out , expected_slice , expected_diff )
163+ _check_if_slices_match (out , expected_slice )
0 commit comments