Commit b1f9f01
[2/n] Add sparse softmax to the Triton flash attention kernel (#1078)
### What does this PR do?
Type of change: New feature
Add N:M structured sparsity support to the Triton flash attention kernel
(`modelopt/torch/kernels/triton_fa.py`). For every M consecutive key
positions in the attention score tile, the kernel keeps the top-N scores and
sets the rest to -inf before the softmax. Sparsification is applied during
prefill only.
**Supported patterns:** any N:M with M=4 (N=1, 2, 3) or M=8 (N=1..4).
- Sink tokens and dense window blocks preserve attention sinks and local
attention (a reference sketch of the masking rule follows this list).
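For reference, the following is a minimal PyTorch sketch of the top-N-per-group
masking rule described above, applied to a full score matrix rather than a
Triton tile. The function name, tensor shapes, and the standalone softmax are
illustrative only and are not part of the kernel API.
```python
import torch


def nm_sparse_softmax_reference(scores: torch.Tensor, n: int = 2, m: int = 4) -> torch.Tensor:
    """Keep the top-n scores in every group of m consecutive key positions,
    set the rest to -inf, then softmax along the key dimension.

    `scores` has shape (..., seq_q, seq_k) with seq_k divisible by m.
    """
    *lead, seq_q, seq_k = scores.shape
    assert seq_k % m == 0, "key length must be a multiple of m"
    # Group the key dimension into chunks of m: (..., seq_q, seq_k // m, m).
    grouped = scores.reshape(*lead, seq_q, seq_k // m, m)
    # Keep only the top-n scores per group; everything else becomes -inf,
    # so the softmax assigns those positions zero probability.
    vals, idx = grouped.topk(n, dim=-1)
    masked = torch.full_like(grouped, float("-inf")).scatter_(-1, idx, vals)
    return torch.softmax(masked.reshape(*lead, seq_q, seq_k), dim=-1)


# Each group of 4 key positions ends up with exactly 2 nonzero probabilities.
probs = nm_sparse_softmax_reference(torch.randn(1, 8, 16, 16), n=2, m=4)
```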
**Performance (TFLOPS at seq_len=16384, RTX 6000):**
| Pattern | TFLOPS | % of Dense |
|---------|--------|------------|
| Dense | 89.3 | 100% |
| 2:4 (M=4) | 69.5 | 78% |
| 4:8 (M=8) | 57.3 | 64% |
### Usage
```python
import torch
from modelopt.torch.kernels import attention

# 2:4 sparsity (keep top 2 of every 4 K positions)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len,
                sparsity_n=2, sparsity_m=4)

# 4:8 sparsity with sink tokens and dense window
out = attention(q, k, v, b_start_loc, b_seq_len, max_len,
                sparsity_n=4, sparsity_m=8,
                num_sink_tokens=4, dense_window_blocks=2)

# Dense (default, zero overhead)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len)

# Via mtsa.sparsify() on HuggingFace models
import modelopt.torch.sparsity.attention_sparsity as mtsa
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

# Default config
mtsa.sparsify(model, mtsa.SPARSE_SOFTMAX_DEFAULT)
```
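To make the sink-token and dense-window options concrete, here is a toy sketch
of which score positions would be exempt from N:M sparsification and kept fully
dense. It is an illustrative assumption only: the kernel expresses the local
window in kernel blocks (`dense_window_blocks`), whereas this sketch uses a
token-level window, and `dense_exempt_mask` is not a real API.
```python
import torch


def dense_exempt_mask(seq_q: int, seq_k: int, num_sink_tokens: int = 4,
                      dense_window_tokens: int = 128) -> torch.Tensor:
    """Boolean (seq_q, seq_k) mask that is True where scores stay dense and
    are not subject to N:M top-N selection (hypothetical, token-granular)."""
    q = torch.arange(seq_q).unsqueeze(1)
    k = torch.arange(seq_k).unsqueeze(0)
    sink = k < num_sink_tokens                             # first keys: attention sinks
    local = (q - k >= 0) & (q - k < dense_window_tokens)   # recent keys: local attention
    return sink | local
```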
### Testing
<!-- Mention how have you tested your change if applicable. -->
### Before your PR is "*Ready for review*"
Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)
and your commits are signed (`git commit -s -S`).
Make sure you read and follow the [Security Best
Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors)
(e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(...,
weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: ✅ / ❌ / N/A <!--- If ❌, explain
why. -->
- If you copied code from any other sources or added a new PIP
dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ / ❌ / N/A
<!--- Mandatory -->
- Did you write any new necessary tests?: ✅ / ❌ / N/A <!--- Mandatory
for new features or examples. -->
- Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?:
✅ / ❌ / N/A <!--- Only for new features, API changes, critical bug fixes
or backward incompatible changes. -->
### Additional Information
<!-- E.g. related issue. -->
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* N:M structured sparse softmax for Triton flash-attention prefill with
configurable dense-window and sink-token handling.
* **API**
* attention(...) accepts sparsity_n, sparsity_m, num_sink_tokens,
dense_window_size; HF/Triton prefill path propagates them.
* **Configuration**
* New config fields and exported preset to enable/configure Triton N:M
sparse softmax with validation.
* **Tests**
* Added GPU tests covering N:M behavior, tile structure,
forward/backward correctness.
* **Documentation**
* CHANGELOG and example docs updated with usage and CLI options.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Kai Xu <kaix@nvidia.com>
12 files changed: 1125 additions & 279 deletions
File tree:
- examples/llm_sparsity/attention_sparsity
- modelopt/torch
  - kernels
  - sparsity/attention_sparsity
    - methods
- tests/gpu/torch/sparsity/attention_sparsity