@@ -63,8 +63,7 @@ def cosine_similarity(a, b):
6363# ---------------------------------------------------------------------------
6464# Reference implementation: naive decode attention (no paging)
6565# ---------------------------------------------------------------------------
66- def naive_decode_attention_ref (q , k_pages , v_pages , kv_indptr , kv_indices ,
67- sm_scale , kv_block_size ):
66+ def naive_decode_attention_ref (q , k_pages , v_pages , kv_indptr , kv_indices , sm_scale , kv_block_size ):
6867 """
6968 Naive Python reference for decode attention with paged KV cache.
7069
@@ -154,7 +153,6 @@ def build_decode_test_data(
154153 np .random .seed (seed )
155154 paddle .seed (seed )
156155
157- total_kv_len = sum (seq_lens )
158156 num_blocks_needed = sum ((s + block_size - 1 ) // block_size for s in seq_lens )
159157 num_blocks = max (num_blocks_needed + 4 , 8 )
160158
@@ -194,7 +192,7 @@ def build_decode_test_data(
194192 attn_lse = paddle .empty ([batch_size , num_heads , max_kv_splits ], dtype = "float32" )
195193 o = paddle .empty ([batch_size , num_heads , Lv ], dtype = dtype )
196194
197- sm_scale = head_dim_k ** - 0.5
195+ sm_scale = head_dim_k ** - 0.5
198196
199197 return {
200198 "q" : q ,
@@ -290,8 +288,7 @@ def test_empty(self):
290288 ids = [c [0 ] for c in _DECODE_CASES ],
291289)
292290@pytest .mark .parametrize ("dtype" , ["float16" , "bfloat16" ])
293- def test_decode_attention_correctness (name , batch , num_heads , kv_heads , Lk , Lv ,
294- seq_lens , block_size , dtype ):
291+ def test_decode_attention_correctness (name , batch , num_heads , kv_heads , Lk , Lv , seq_lens , block_size , dtype ):
295292 """Triton decode attention output should match naive reference."""
296293 data = build_decode_test_data (
297294 batch_size = batch ,
@@ -336,12 +333,10 @@ def test_decode_attention_correctness(name, batch, num_heads, kv_heads, Lk, Lv,
336333 cos_sim = cosine_similarity (triton_out , ref_out )
337334
338335 atol = BF16_ATOL if dtype == "bfloat16" else FP16_ATOL
339- assert max_diff < atol , (
340- f"[{ name } /{ dtype } ] max_diff={ max_diff :.6f} exceeds atol={ atol } "
341- )
342- assert cos_sim > COSINE_SIM_THRESHOLD , (
343- f"[{ name } /{ dtype } ] cos_sim={ cos_sim :.6f} below threshold={ COSINE_SIM_THRESHOLD } "
344- )
336+ assert max_diff < atol , f"[{ name } /{ dtype } ] max_diff={ max_diff :.6f} exceeds atol={ atol } "
337+ assert (
338+ cos_sim > COSINE_SIM_THRESHOLD
339+ ), f"[{ name } /{ dtype } ] cos_sim={ cos_sim :.6f} below threshold={ COSINE_SIM_THRESHOLD } "
345340
346341
347342# ===========================================================================
@@ -380,10 +375,7 @@ def test_decode_attention_determinism():
380375 results .append (o .astype ("float32" ).numpy ())
381376
382377 for i in range (1 , len (results )):
383- np .testing .assert_array_equal (
384- results [0 ], results [i ],
385- err_msg = f"Run 0 vs run { i } differ — non-deterministic!"
386- )
378+ np .testing .assert_array_equal (results [0 ], results [i ], err_msg = f"Run 0 vs run { i } differ — non-deterministic!" )
387379
388380
389381# ===========================================================================
0 commit comments