@@ -49,15 +49,9 @@ def test_attention_correctness_fp32(self, batch, heads, seq_len, head_dim, devic
4949 pytest .skip ("CUDA kernels not built" )
5050
5151 # Generate inputs
52- q = torch .randn (
53- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
54- )
55- k = torch .randn (
56- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
57- )
58- v = torch .randn (
59- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
60- )
52+ q = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
53+ k = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
54+ v = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
6155
6256 # Compute outputs
6357 output = naive_attention (q , k , v )
@@ -92,15 +86,9 @@ def test_attention_correctness_fp16(self, batch, heads, seq_len, head_dim, devic
9286 except ImportError :
9387 pytest .skip ("CUDA kernels not built" )
9488
95- q = torch .randn (
96- batch , heads , seq_len , head_dim , device = device , dtype = torch .float16
97- )
98- k = torch .randn (
99- batch , heads , seq_len , head_dim , device = device , dtype = torch .float16
100- )
101- v = torch .randn (
102- batch , heads , seq_len , head_dim , device = device , dtype = torch .float16
103- )
89+ q = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float16 )
90+ k = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float16 )
91+ v = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float16 )
10492
10593 output = naive_attention (q , k , v )
10694 reference , _ = compute_attention_reference (q , k , v )
@@ -149,17 +137,16 @@ def test_softmax_invariants(self, batch, heads, seq_len, head_dim, device):
149137
150138 # Property 2: Sum equals 1
151139 row_sums = softmax_output .sum (dim = - 1 )
152- assert torch .allclose (
153- row_sums , torch . ones_like ( row_sums ), rtol = 1e-5 , atol = 1e-5
154- ), "Softmax row sums should equal 1"
140+ assert torch .allclose (row_sums , torch . ones_like ( row_sums ), rtol = 1e-5 , atol = 1e-5 ), (
141+ "Softmax row sums should equal 1"
142+ )
155143
156144 # Property 3: Monotonicity (larger input -> larger output)
157145 for i in range (min (5 , seq_len - 1 )): # Check a few pairs
158146 idx1 , idx2 = i , i + 1
159147 mask = scores [..., idx1 ] > scores [..., idx2 ]
160148 assert (
161- softmax_output [..., idx1 ][mask ]
162- >= softmax_output [..., idx2 ][mask ] - 1e-6
149+ softmax_output [..., idx1 ][mask ] >= softmax_output [..., idx2 ][mask ] - 1e-6
163150 ).all (), "Softmax should preserve relative order"
164151
165152
@@ -189,15 +176,9 @@ def test_flash_attention_consistency(self, batch, heads, seq_len, head_dim, devi
189176 except ImportError :
190177 pytest .skip ("CUDA kernels not built" )
191178
192- q = torch .randn (
193- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
194- )
195- k = torch .randn (
196- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
197- )
198- v = torch .randn (
199- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
200- )
179+ q = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
180+ k = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
181+ v = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
201182
202183 flash_output = flash_attention (q , k , v )
203184 naive_output = naive_attention (q , k , v )
@@ -233,15 +214,9 @@ def test_causal_mask_correctness(self, batch, heads, seq_len, head_dim, device):
233214 except ImportError :
234215 pytest .skip ("CUDA kernels not built" )
235216
236- q = torch .randn (
237- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
238- )
239- k = torch .randn (
240- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
241- )
242- v = torch .randn (
243- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
244- )
217+ q = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
218+ k = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
219+ v = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
245220
246221 # Compute causal attention
247222 causal_output = flash_attention (q , k , v , is_causal = True )
@@ -260,9 +235,9 @@ def test_causal_mask_correctness(self, batch, heads, seq_len, head_dim, device):
260235
261236 # Verify attention weights are lower triangular
262237 upper_triangle = torch .triu (attn_weights , diagonal = 1 )
263- assert torch .allclose (
264- upper_triangle , torch . zeros_like ( upper_triangle ), atol = 1e-6
265- ), "Causal attention weights should be lower triangular"
238+ assert torch .allclose (upper_triangle , torch . zeros_like ( upper_triangle ), atol = 1e-6 ), (
239+ "Causal attention weights should be lower triangular"
240+ )
266241
267242
268243class TestTiledAttention :
@@ -290,15 +265,9 @@ def test_tiled_attention_consistency(self, batch, heads, seq_len, head_dim, devi
290265 except ImportError :
291266 pytest .skip ("CUDA kernels not built" )
292267
293- q = torch .randn (
294- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
295- )
296- k = torch .randn (
297- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
298- )
299- v = torch .randn (
300- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
301- )
268+ q = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
269+ k = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
270+ v = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
302271
303272 tiled_output = tiled_attention (q , k , v )
304273 naive_output = naive_attention (q , k , v )
@@ -332,15 +301,9 @@ def test_tiled_attention_correctness_fp16(self, batch, heads, seq_len, head_dim,
332301 except ImportError :
333302 pytest .skip ("CUDA kernels not built" )
334303
335- q = torch .randn (
336- batch , heads , seq_len , head_dim , device = device , dtype = torch .float16
337- )
338- k = torch .randn (
339- batch , heads , seq_len , head_dim , device = device , dtype = torch .float16
340- )
341- v = torch .randn (
342- batch , heads , seq_len , head_dim , device = device , dtype = torch .float16
343- )
304+ q = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float16 )
305+ k = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float16 )
306+ v = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float16 )
344307
345308 output = tiled_attention (q , k , v )
346309 reference , _ = compute_attention_reference (q , k , v )
@@ -373,8 +336,9 @@ def test_tiled_attention_scale_parameter(self, device):
373336 output_custom = tiled_attention (q , k , v , scale = custom_scale )
374337
375338 # Outputs should be different
376- assert not torch .allclose (output_default , output_custom ), \
339+ assert not torch .allclose (output_default , output_custom ), (
377340 "Custom scale should produce different output"
341+ )
378342
379343
380344class TestNaiveAttentionErrorHandling :
@@ -644,15 +608,9 @@ def test_batch_multihead_support(self, batch, heads, seq_len, head_dim, device):
644608 except ImportError :
645609 pytest .skip ("CUDA kernels not built" )
646610
647- q = torch .randn (
648- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
649- )
650- k = torch .randn (
651- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
652- )
653- v = torch .randn (
654- batch , heads , seq_len , head_dim , device = device , dtype = torch .float32
655- )
611+ q = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
612+ k = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
613+ v = torch .randn (batch , heads , seq_len , head_dim , device = device , dtype = torch .float32 )
656614
657615 output = flash_attention (q , k , v )
658616
0 commit comments