|
22 | 22 | """ |
23 | 23 |
|
24 | 24 | import copy |
| 25 | +import itertools |
25 | 26 |
|
26 | 27 | import pytest |
27 | 28 | import torch |
@@ -143,6 +144,38 @@ def test_decode_branch_reports_decode_phase(self): |
143 | 144 | assert module._last_stats["phase"] == "decode" |
144 | 145 | assert len(module._last_stats["sparsity"]) == len(THRESHOLD_TRIALS) |
145 | 146 |
|
| 147 | + def test_decode_calibration_measures_full_cache_with_sink(self): |
| 148 | + """Decode calibration must scan the whole KV cache and report real sparsity. |
| 149 | +
|
| 150 | + A dominant sink at position 0 makes the distant KV tiles negligible, so a |
| 151 | + correct decode measurement skips almost all of them. This guards the two |
| 152 | + decode bugs that random inputs don't expose: |
| 153 | + * causal-offset ``kv_bound`` — without it the loop stops after the first |
| 154 | + ``BLOCK_M`` tokens, so ``total`` would be a fraction of the cache. |
| 155 | + * padding-row exclusion — without it the 127 padding rows veto every |
| 156 | + tile and sparsity is 0%. |
| 157 | + """ |
| 158 | + num_heads, seq_k, head_dim = 4, 2048, 64 |
| 159 | + block_n = 128 # the calibration kernel measures at 128x128 |
| 160 | + q = torch.ones(1, num_heads, 1, head_dim, device="cuda", dtype=torch.float16) |
| 161 | + k = torch.zeros(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) |
| 162 | + k[:, :, 0] = 20.0 # attention sink dominates every query |
| 163 | + v = torch.randn(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) |
| 164 | + |
| 165 | + module = _calibration_module(THRESHOLD_TRIALS) |
| 166 | + method = module._sparse_method_instance |
| 167 | + triton_attention_forward(module, q, k, v, attention_mask=None, scaling=1.0 / head_dim**0.5) |
| 168 | + |
| 169 | + counters = method._hf_calibration_counters |
| 170 | + total = int(counters[0, 0]) |
| 171 | + # Full cache scanned (not truncated to the first block). |
| 172 | + assert total == num_heads * (seq_k // block_n), total |
| 173 | + sparsity = (counters[:, 1].float() / counters[:, 0].clamp(min=1)).tolist() |
| 174 | + # Sink => the vast majority of tiles are negligible and skippable (not 0%). |
| 175 | + assert max(sparsity) > 0.8, sparsity |
| 176 | + # Skipped-tile fraction is non-decreasing as the threshold grows. |
| 177 | + assert all(later >= earlier for earlier, later in itertools.pairwise(sparsity)), sparsity |
| 178 | + |
146 | 179 |
|
147 | 180 | if __name__ == "__main__": |
148 | 181 | pytest.main([__file__, "-v"]) |
0 commit comments