1111pytest tests/others/test_attention_backends.py
1212```
1313
14- Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
15- "native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
14+ Tests were conducted on an H100 with PyTorch 2.9.1 (CUDA 12.9).
1615
1716Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
1817with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
2423import pytest
2524import torch
2625
26+ from ..testing_utils import numpy_cosine_similarity_distance
27+
2728
2829pytestmark = pytest .mark .skipif (
2930 os .getenv ("RUN_ATTENTION_BACKEND_TESTS" , "false" ) == "false" , reason = "Feature not mature enough."
3637FORWARD_CASES = [
3738 (
3839 "flash_hub" ,
39- torch .tensor ([0.0820 , 0.0859 , 0.0918 , 0.1016 , 0.0957 , 0.0996 , 0.0996 , 0.1016 , 0.2188 , 0.2266 , 0.2363 , 0.2500 , 0.2539 , 0.2461 , 0.2422 , 0.2695 ], dtype = torch .bfloat16 )
40+ torch .tensor ([0.0820 , 0.0859 , 0.0918 , 0.1016 , 0.0957 , 0.0996 , 0.0996 , 0.1016 , 0.2188 , 0.2266 , 0.2363 , 0.2500 , 0.2539 , 0.2461 , 0.2422 , 0.2695 ], dtype = torch .bfloat16 ),
41+ 1e-4
4042 ),
4143 (
4244 "_flash_3_hub" ,
4345 torch .tensor ([0.0820 , 0.0859 , 0.0938 , 0.1016 , 0.0977 , 0.0996 , 0.1016 , 0.1016 , 0.2188 , 0.2246 , 0.2344 , 0.2480 , 0.2539 , 0.2480 , 0.2441 , 0.2715 ], dtype = torch .bfloat16 ),
46+ 1e-4
4447 ),
4548 (
4649 "native" ,
47- torch .tensor ([0.0820 , 0.0859 , 0.0938 , 0.1016 , 0.0957 , 0.0996 , 0.0996 , 0.1016 , 0.2188 , 0.2266 , 0.2363 , 0.2500 , 0.2539 , 0.2480 , 0.2461 , 0.2734 ], dtype = torch .bfloat16 )
48- ),
50+ torch .tensor ([0.0820 , 0.0859 , 0.0938 , 0.1016 , 0.0957 , 0.0996 , 0.0996 , 0.1016 , 0.2188 , 0.2266 , 0.2363 , 0.2500 , 0.2539 , 0.2480 , 0.2461 , 0.2734 ], dtype = torch .bfloat16 ),
51+ 1e-4
52+ ),
4953 (
5054 "_native_cudnn" ,
5155 torch .tensor ([0.0781 , 0.0840 , 0.0879 , 0.0957 , 0.0898 , 0.0957 , 0.0957 , 0.0977 , 0.2168 , 0.2246 , 0.2324 , 0.2500 , 0.2539 , 0.2480 , 0.2441 , 0.2695 ], dtype = torch .bfloat16 ),
56+ 5e-4
5257 ),
5358 (
5459 "aiter" ,
5560 torch .tensor ([0.0781 , 0.0820 , 0.0879 , 0.0957 , 0.0898 , 0.0938 , 0.0957 , 0.0957 , 0.2285 , 0.2363 , 0.2461 , 0.2637 , 0.2695 , 0.2617 , 0.2617 , 0.2891 ], dtype = torch .bfloat16 ),
61+ 1e-4
5662 )
5763]
5864
5965COMPILE_CASES = [
6066 (
6167 "flash_hub" ,
6268 torch .tensor ([0.0410 , 0.0410 , 0.0449 , 0.0508 , 0.0488 , 0.0586 , 0.0605 , 0.0586 , 0.2324 , 0.2422 , 0.2539 , 0.2734 , 0.2832 , 0.2812 , 0.2773 , 0.3047 ], dtype = torch .bfloat16 ),
63- True
69+ True ,
70+ 1e-4
6471 ),
6572 (
6673 "_flash_3_hub" ,
6774 torch .tensor ([0.0410 , 0.0410 , 0.0449 , 0.0508 , 0.0508 , 0.0605 , 0.0625 , 0.0605 , 0.2344 , 0.2461 , 0.2578 , 0.2734 , 0.2852 , 0.2812 , 0.2773 , 0.3047 ], dtype = torch .bfloat16 ),
6875 True ,
76+ 1e-4
6977 ),
7078 (
7179 "native" ,
7280 torch .tensor ([0.0410 , 0.0410 , 0.0449 , 0.0508 , 0.0508 , 0.0605 , 0.0605 , 0.0605 , 0.2344 , 0.2461 , 0.2578 , 0.2773 , 0.2871 , 0.2832 , 0.2773 , 0.3066 ], dtype = torch .bfloat16 ),
7381 True ,
82+ 1e-4
7483 ),
7584 (
7685 "_native_cudnn" ,
7786 torch .tensor ([0.0410 , 0.0410 , 0.0430 , 0.0508 , 0.0488 , 0.0586 , 0.0605 , 0.0586 , 0.2344 , 0.2461 , 0.2578 , 0.2773 , 0.2871 , 0.2832 , 0.2793 , 0.3086 ], dtype = torch .bfloat16 ),
7887 True ,
88+ 5e-4 ,
7989 ),
8090 (
8191 "aiter" ,
8292 torch .tensor ([0.0391 , 0.0391 , 0.0430 , 0.0488 , 0.0469 , 0.0566 , 0.0586 , 0.0566 , 0.2402 , 0.2539 , 0.2637 , 0.2812 , 0.2930 , 0.2910 , 0.2891 , 0.3164 ], dtype = torch .bfloat16 ),
8393 True ,
94+ 1e-4
8495 )
8596]
8697# fmt: on
@@ -104,11 +115,11 @@ def _backend_is_probably_supported(pipe, name: str):
104115 return False
105116
106117
107- def _check_if_slices_match (output , expected_slice ):
118+ def _check_if_slices_match (output , expected_slice , expected_diff = 1e-4 ):
108119 img = output .images .detach ().cpu ()
109120 generated_slice = img .flatten ()
110121 generated_slice = torch .cat ([generated_slice [:8 ], generated_slice [- 8 :]])
111- assert torch . allclose (generated_slice , expected_slice , atol = 1e-4 )
122+ assert numpy_cosine_similarity_distance (generated_slice , expected_slice ) < expected_diff
112123
113124
114125@pytest .fixture (scope = "session" )
@@ -126,23 +137,23 @@ def pipe(device):
126137 return pipe
127138
128139
129- @pytest .mark .parametrize ("backend_name,expected_slice" , FORWARD_CASES , ids = [c [0 ] for c in FORWARD_CASES ])
130- def test_forward (pipe , backend_name , expected_slice ):
140+ @pytest .mark .parametrize ("backend_name,expected_slice,expected_diff " , FORWARD_CASES , ids = [c [0 ] for c in FORWARD_CASES ])
141+ def test_forward (pipe , backend_name , expected_slice , expected_diff ):
131142 out = _backend_is_probably_supported (pipe , backend_name )
132143 if isinstance (out , bool ):
133144 pytest .xfail (f"Backend '{ backend_name } ' not supported in this environment." )
134145
135146 modified_pipe = out [0 ]
136147 out = modified_pipe (** INFER_KW , generator = torch .manual_seed (0 ))
137- _check_if_slices_match (out , expected_slice )
148+ _check_if_slices_match (out , expected_slice , expected_diff )
138149
139150
140151@pytest .mark .parametrize (
141- "backend_name,expected_slice,error_on_recompile" ,
152+ "backend_name,expected_slice,error_on_recompile,expected_diff " ,
142153 COMPILE_CASES ,
143154 ids = [c [0 ] for c in COMPILE_CASES ],
144155)
145- def test_forward_with_compile (pipe , backend_name , expected_slice , error_on_recompile ):
156+ def test_forward_with_compile (pipe , backend_name , expected_slice , error_on_recompile , expected_diff ):
146157 if "native" in backend_name and error_on_recompile and not is_torch_version (">=" , "2.9.0" ):
147158 pytest .xfail (f"Test with { backend_name = } is compatible with a higher version of torch." )
148159
@@ -160,4 +171,4 @@ def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recom
160171 ):
161172 out = modified_pipe (** INFER_KW , generator = torch .manual_seed (0 ))
162173
163- _check_if_slices_match (out , expected_slice )
174+ _check_if_slices_match (out , expected_slice , expected_diff )
0 commit comments