@@ -67,6 +67,11 @@ def _max_abs_error(out, ref):
6767 return (out .float () - ref .float ()).abs ().max ().item ()
6868
6969
70+ # Maximum absolute error allowed between bf16 kernel output and the fp32 reference.
71+ # The benchmark cross-validates backends at 1e-2; the tests use the same threshold.
72+ MAX_ABS_TOL = 1e-2
73+
74+
7075# ---------------------------------------------------------------------------
7176# Test configurations adapted from FlashAttention
7277# ---------------------------------------------------------------------------
@@ -130,7 +135,7 @@ def test_mha_basic(self):
130135
131136 self .assertFalse (torch .isnan (out ).any (), "NaN in output" )
132137 self .assertLess (
133- _max_abs_error (out , ref ), 0.05 , f"D={ D } Lq={ Lq } Lk={ Lk } "
138+ _max_abs_error (out , ref ), MAX_ABS_TOL , f"D={ D } Lq={ Lq } Lk={ Lk } "
134139 )
135140
136141 def test_mha_causal (self ):
@@ -148,7 +153,7 @@ def test_mha_causal(self):
148153 ref = _reference_sdpa (q , k , v , is_causal = True )
149154
150155 self .assertFalse (torch .isnan (out ).any ())
151- self .assertLess (_max_abs_error (out , ref ), 0.05 )
156+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
152157
153158 def test_mha_bool_mask (self ):
154159 """MHA with explicit bool attention mask."""
@@ -168,7 +173,7 @@ def test_mha_bool_mask(self):
168173 ref = _reference_sdpa (q , k , v , attn_mask = mask )
169174
170175 self .assertFalse (torch .isnan (out ).any ())
171- self .assertLess (_max_abs_error (out , ref ), 0.05 )
176+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
172177
173178 def test_mha_non_pow2_head_dim (self ):
174179 """MHA with non-power-of-2 head dimensions."""
@@ -187,7 +192,7 @@ def test_mha_non_pow2_head_dim(self):
187192 ref = _reference_sdpa (q , k , v )
188193
189194 self .assertFalse (torch .isnan (out ).any ())
190- self .assertLess (_max_abs_error (out , ref ), 0.05 )
195+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
191196
192197 def test_mha_non_pow2_causal (self ):
193198 """MHA with non-pow2 head dim and causal masking."""
@@ -204,7 +209,7 @@ def test_mha_non_pow2_causal(self):
204209 ref = _reference_sdpa (q , k , v , is_causal = True )
205210
206211 self .assertFalse (torch .isnan (out ).any ())
207- self .assertLess (_max_abs_error (out , ref ), 0.05 )
212+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
208213
209214 # ------------------------------------------------------------------
210215 # GQA tests
@@ -230,7 +235,7 @@ def test_gqa_decode(self):
230235 self .assertEqual (out .shape , (B , H_q , Lq , D ))
231236 self .assertFalse (torch .isnan (out ).any ())
232237 self .assertLess (
233- _max_abs_error (out , ref ), 0.05 , f"{ label } D={ D } Lk={ Lk } "
238+ _max_abs_error (out , ref ), MAX_ABS_TOL , f"{ label } D={ D } Lk={ Lk } "
234239 )
235240
236241 def test_gqa_decode_with_mask (self ):
@@ -253,7 +258,7 @@ def test_gqa_decode_with_mask(self):
253258 ref = _reference_sdpa (q , k , v , attn_mask = mask )
254259
255260 self .assertFalse (torch .isnan (out ).any ())
256- self .assertLess (_max_abs_error (out , ref ), 0.05 )
261+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
257262
258263 def test_gqa_short_seqlen (self ):
259264 """GQA with short seqlen_q (2-8)."""
@@ -270,7 +275,7 @@ def test_gqa_short_seqlen(self):
270275 ref = _reference_sdpa (q , k , v )
271276
272277 self .assertFalse (torch .isnan (out ).any ())
273- self .assertLess (_max_abs_error (out , ref ), 0.05 )
278+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
274279
275280 def test_gqa_prefill (self ):
276281 """GQA prefill (long seqlen_q)."""
@@ -290,7 +295,7 @@ def test_gqa_prefill(self):
290295
291296 self .assertEqual (out .shape , (B , H_q , L , D ))
292297 self .assertFalse (torch .isnan (out ).any ())
293- self .assertLess (_max_abs_error (out , ref ), 0.05 )
298+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
294299
295300 def test_gqa_non_pow2_head_dim (self ):
296301 """GQA with non-power-of-2 head dimensions."""
@@ -308,7 +313,7 @@ def test_gqa_non_pow2_head_dim(self):
308313
309314 self .assertFalse (torch .isnan (out ).any ())
310315 self .assertLess (
311- _max_abs_error (out , ref ), 0.05 , f"D={ D } Lq={ Lq } Lk={ Lk } "
316+ _max_abs_error (out , ref ), MAX_ABS_TOL , f"D={ D } Lq={ Lq } Lk={ Lk } "
312317 )
313318
314319 def test_gqa_causal_prefill (self ):
@@ -326,7 +331,7 @@ def test_gqa_causal_prefill(self):
326331 ref = _reference_sdpa (q , k , v , is_causal = True )
327332
328333 self .assertFalse (torch .isnan (out ).any ())
329- self .assertLess (_max_abs_error (out , ref ), 0.05 )
334+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
330335
331336 def test_gqa_causal_decode_with_mask (self ):
332337 """GQA decode with causal-like bool mask (simulating KV cache)."""
@@ -352,7 +357,7 @@ def test_gqa_causal_decode_with_mask(self):
352357 ref = _reference_sdpa (q , k , v , attn_mask = mask )
353358
354359 self .assertFalse (torch .isnan (out ).any ())
355- self .assertLess (_max_abs_error (out , ref ), 0.05 )
360+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
356361
357362 def test_gqa_batch_size (self ):
358363 """GQA with batch_size > 1."""
@@ -368,7 +373,7 @@ def test_gqa_batch_size(self):
368373 ref = _reference_sdpa (q , k , v )
369374
370375 self .assertFalse (torch .isnan (out ).any ())
371- self .assertLess (_max_abs_error (out , ref ), 0.05 )
376+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
372377
373378 # ------------------------------------------------------------------
374379 # Qwen 3.5 MoE configuration
@@ -393,7 +398,7 @@ def test_qwen35_moe_config(self):
393398 self .assertEqual (out .shape , (B , H_q , Lq , D ))
394399 self .assertFalse (torch .isnan (out ).any ())
395400 self .assertLess (
396- _max_abs_error (out , ref ), 0.05 , f"Qwen config Lq={ Lq } Lk={ Lk } "
401+ _max_abs_error (out , ref ), MAX_ABS_TOL , f"Qwen config Lq={ Lq } Lk={ Lk } "
397402 )
398403
399404 # ------------------------------------------------------------------
@@ -427,7 +432,7 @@ def test_custom_scale(self):
427432 ref = _reference_sdpa (q , k , v , scale = scale )
428433
429434 self .assertFalse (torch .isnan (out ).any ())
430- self .assertLess (_max_abs_error (out , ref ), 0.05 )
435+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
431436
432437 def test_all_masked (self ):
433438 """All-masked block should produce zeros, not NaN."""
@@ -508,7 +513,7 @@ def test_non_pow2_no_mask(self):
508513 ref = _reference_sdpa (q , k , v )
509514
510515 self .assertFalse (torch .isnan (out ).any ())
511- self .assertLess (_max_abs_error (out , ref ), 0.05 )
516+ self .assertLess (_max_abs_error (out , ref ), MAX_ABS_TOL )
512517
513518
514519if __name__ == "__main__" :
0 commit comments