# Attention Sparsity for HuggingFace Models

In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Two methods are supported:

- **Skip-Softmax**: Threshold-based skipping of near-zero attention scores during softmax (requires `attn_implementation="eager"`)
- **Sparse24 Triton**: Fine-grained 2:4 sparsity on attention scores via a fused Triton kernel with autograd support (uses `attn_implementation="modelopt_triton"`)

## Getting Started
```python
custom_config = {
    ...
}

model = mtsa.sparsify(model, config=custom_config)
```
## Fine-grained 2:4 Sparse Attention

In addition to skip-softmax, Model Optimizer supports **fine-grained 2:4 sparsity** on attention scores via a fused Triton kernel. For every 4 attention scores along the key dimension, the kernel keeps only the top 2 and zeros out the rest, achieving a fixed 50% sparsity with no calibration needed.
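The top-2-of-4 selection can be sketched in plain PyTorch. This is only a reference illustration of the 2:4 pattern, not the fused Triton kernel itself (the function name `sparse24_mask` is ours, not part of the Model Optimizer API):

```python
import torch

def sparse24_mask(scores: torch.Tensor) -> torch.Tensor:
    """Zero out the 2 smallest of every 4 scores along the last (key) dim.

    Illustrative reference for 2:4 sparsity; the real kernel fuses this
    selection into the attention computation.
    """
    *lead, n = scores.shape
    assert n % 4 == 0, "key dimension must be a multiple of 4"
    groups = scores.reshape(*lead, n // 4, 4)
    # Indices of the top-2 scores within each group of 4
    top2 = groups.topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(-1, top2, True)
    return (groups * mask).reshape(*lead, n)

scores = torch.tensor([[0.9, 0.1, 0.5, 0.2, 0.3, 0.8, 0.7, 0.05]])
print(sparse24_mask(scores))
# keeps 0.9 and 0.5 in the first group of 4, 0.8 and 0.7 in the second
```

Exactly half the scores survive in every group of 4, which is what makes the sparsity ratio fixed rather than data-dependent.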
### Quick Example

```python
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity.config import SPARSE24_TRITON

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
)

model = mtsa.sparsify(model, config=SPARSE24_TRITON)
```
> [!NOTE]
> Unlike skip-softmax, sparse24 does **not** require `attn_implementation="eager"`. The `mtsa.sparsify` call automatically registers the Triton kernel as `attn_implementation="modelopt_triton"`.
### Running via Command Line

```bash
python hf_sa.py \
    --pyt_ckpt_path meta-llama/Llama-3.1-8B \
    --sparse_attn sparse24_triton \
    --backend triton
```
### Key Differences from Skip-Softmax

| | Skip-Softmax | Sparse24 Triton |
|---|---|---|
| Method | Threshold-based softmax skipping | 2:4 structured sparsity on attention scores |
| Attention backend | `eager` (patches `F.softmax`) | `modelopt_triton` (fused Triton kernel) |
| Calibration | Optional (RULER-based) | Not needed (fixed top-2-of-4 selection) |
| Sparsity ratio | Variable (depends on threshold) | Fixed 50% |
| Diagonal preservation | N/A | Yes (tiles near the causal diagonal are kept dense) |
| Training support | No | Yes (autograd-compatible forward/backward) |
| Decode support | Yes | Yes (same kernel, `is_causal=False`) |
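The sparsity-ratio row can be made concrete with a small standalone demo: threshold skipping drops however many scores fall below the cutoff, while 2:4 selection always drops exactly half. The threshold value below is a made-up illustration, not a recommended setting:

```python
import torch

torch.manual_seed(0)
# A single row of softmax-normalized attention scores
scores = torch.softmax(torch.randn(1, 128), dim=-1)

# Skip-softmax style: sparsity depends on the threshold and score distribution
threshold = 1e-3  # hypothetical cutoff, for illustration only
skip_sparsity = (scores < threshold).float().mean().item()

# Sparse24 style: exactly 2 of every 4 scores are dropped, always 50%
groups = scores.reshape(-1, 4)
top2 = groups.topk(2, dim=-1).indices
mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, top2, True)
s24_sparsity = 1.0 - mask.float().mean().item()

print(f"skip-softmax sparsity: {skip_sparsity:.2f} (threshold-dependent)")
print(f"sparse24 sparsity:     {s24_sparsity:.2f} (fixed)")
```

This fixed ratio is why sparse24 needs no calibration: there is no threshold to tune.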
### Training with Sparse24 Attention

The Triton kernel supports autograd. When `requires_grad=True`, the HF integration automatically uses the backward-capable path:

```python
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16)
model = mtsa.sparsify(model, config=SPARSE24_TRITON)
model.train()

# Gradients flow through the sparse attention
output = model(input_ids=ids, labels=labels)
output.loss.backward()  # dQ, dK, dV computed via Triton backward kernels
```

### Custom Sparse24 Configuration

```python
custom_config = {
    "sparse_cfg": {
        "*attn*": {
            "method": "sparse24_triton",
            "backend": "triton",
            "skip_diagonal_blocks": True,  # Keep diagonal tiles dense (recommended)
            "enable": True,
        },
        "default": {"enable": False},
    },
}

model = mtsa.sparsify(model, config=custom_config)
```

Set `skip_diagonal_blocks: False` to apply 2:4 sparsity to all tiles, including the diagonal (more aggressive, but may hurt quality for local attention patterns).
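Conceptually, `skip_diagonal_blocks` is a tile-level decision: tiles intersecting the causal diagonal bypass the 2:4 selection and stay dense. A rough sketch under assumed simplifications (the tile size, the function name `sparsify_tiles`, and the square-tile diagonal test are ours; the actual kernel uses its own block sizes and tiling logic):

```python
import torch

TILE = 4  # illustrative tile size only; the real kernel picks its own blocks

def sparsify_tiles(scores: torch.Tensor, skip_diagonal_blocks: bool = True) -> torch.Tensor:
    """Apply top-2-of-4 sparsity per tile, optionally keeping diagonal tiles dense."""
    n_q, n_k = scores.shape
    out = scores.clone()
    for qi in range(0, n_q, TILE):
        for ki in range(0, n_k, TILE):
            if skip_diagonal_blocks and qi == ki:
                continue  # tile on the causal diagonal stays dense
            block = out[qi:qi + TILE, ki:ki + TILE]
            # 2:4 selection along the key dimension within the tile
            groups = block.reshape(-1, 4)
            top2 = groups.topk(2, dim=-1).indices
            mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, top2, True)
            out[qi:qi + TILE, ki:ki + TILE] = (groups * mask).reshape(block.shape)
    return out
```

Keeping diagonal tiles dense protects the strong local attention near each query's own position, which is typically where the largest scores concentrate.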
## References

- [Model Optimizer Documentation](https://nvidia.github.io/Model-Optimizer/)