Skip to content

Commit 98ce883

Browse files
kaix-nvrohansjoshi
authored and committed
Add sink-pattern decode calibration test (full cache + nonzero sparsity)
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 37b77a0 commit 98ce883

1 file changed

Lines changed: 33 additions & 0 deletions

File tree

tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
"""
2323

2424
import copy
25+
import itertools
2526

2627
import pytest
2728
import torch
@@ -143,6 +144,38 @@ def test_decode_branch_reports_decode_phase(self):
143144
assert module._last_stats["phase"] == "decode"
144145
assert len(module._last_stats["sparsity"]) == len(THRESHOLD_TRIALS)
145146

147+
def test_decode_calibration_measures_full_cache_with_sink(self):
    """Check decode-phase calibration on a KV cache dominated by an attention sink.

    A single query token attends to a 2048-token cache whose key at position 0
    is far larger than every other (all-zero) key, so the remaining KV tiles
    should contribute negligibly and be reported as skippable.

    Random inputs do not expose the two decode regressions this guards against:

    * causal-offset ``kv_bound`` -- if it is missing, the scan stops after the
      first ``BLOCK_M`` tokens and ``total`` covers only part of the cache.
    * padding-row exclusion -- if it is missing, the 127 padding rows veto
      every tile and the measured sparsity collapses to 0%.
    """
    num_heads, seq_k, head_dim = 4, 2048, 64
    block_n = 128  # tile width used by the calibration kernel (128x128)
    tensor_opts = {"device": "cuda", "dtype": torch.float16}

    # Single-token query => decode phase.
    query = torch.ones(1, num_heads, 1, head_dim, **tensor_opts)
    keys = torch.zeros(1, num_heads, seq_k, head_dim, **tensor_opts)
    keys[:, :, 0].fill_(20.0)  # position 0 is the sink that dominates every query
    values = torch.randn(1, num_heads, seq_k, head_dim, **tensor_opts)

    module = _calibration_module(THRESHOLD_TRIALS)
    # Grab the method instance before running the forward pass.
    sparse_method = module._sparse_method_instance
    softmax_scale = 1.0 / head_dim**0.5
    triton_attention_forward(
        module, query, keys, values, attention_mask=None, scaling=softmax_scale
    )

    counters = sparse_method._hf_calibration_counters

    # Column 0 holds the scanned-tile count; the whole cache must be covered,
    # not just the first block.
    total = int(counters[0, 0])
    expected_tiles = num_heads * (seq_k // block_n)
    assert total == expected_tiles, total

    # Column 1 over column 0 gives the skipped fraction per threshold trial.
    skipped = counters[:, 1].float()
    scanned = counters[:, 0].clamp(min=1)  # guard against division by zero
    sparsity = (skipped / scanned).tolist()

    # With a sink, most tiles must be skippable (i.e. not 0%).
    assert max(sparsity) > 0.8, sparsity

    # The skipped fraction must never drop as the threshold grows.
    adjacent = zip(sparsity, sparsity[1:])
    assert all(earlier <= later for earlier, later in adjacent), sparsity
178+
146179

147180
# Running this file directly executes its tests through pytest in verbose mode.
if __name__ == "__main__":
    pytest.main(args=[__file__, "-v"])

0 commit comments

Comments
 (0)