Commit b1f9f01

[2/n] Add sparse softmax to the Triton flash attention kernel (#1078)
### What does this PR do?

**Type of change:** New feature

Add N:M structured sparsity support to the Triton flash attention kernel (`modelopt/torch/kernels/triton_fa.py`). For every M consecutive key positions in the attention score tile, the kernel keeps the top-N values and sets the rest to -inf before softmax. This is applied during prefill only.

**Supported patterns:** Any N:M where M=4 (N=1,2,3) or M=8 (N=1..4).

- Sink tokens and dense window blocks for preserving local attention and attention sinks

**Performance (TFLOPS at seq_len=16384, RTX 6000):**

| Pattern | TFLOPS | % of Dense |
|---------|--------|------------|
| Dense | 89.3 | 100% |
| 2:4 (M=4) | 69.5 | 78% |
| 4:8 (M=8) | 57.3 | 64% |

### Usage

```python
from modelopt.torch.kernels import attention

# 2:4 sparsity (keep top 2 of every 4 K positions)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len, sparsity_n=2, sparsity_m=4)

# 4:8 sparsity with sink tokens and dense window
out = attention(
    q, k, v, b_start_loc, b_seq_len, max_len,
    sparsity_n=4, sparsity_m=8, num_sink_tokens=4, dense_window_blocks=2,
)

# Dense (default, zero overhead)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len)

# Via mtsa.sparsify() on HuggingFace models
import modelopt.torch.sparsity.attention_sparsity as mtsa
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16, device_map="cuda"
)

# Default config
mtsa.sparsify(model, mtsa.SPARSE_SOFTMAX_DEFAULT)
```

### Testing

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅ / ❌ / N/A
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ / ❌ / N/A
- Did you write any new necessary tests?: ✅ / ❌ / N/A
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A

### Additional Information

## Summary by CodeRabbit

* **New Features**
  * N:M structured sparse softmax for Triton flash-attention prefill with configurable dense-window and sink-token handling.
* **API**
  * `attention(...)` accepts `sparsity_n`, `sparsity_m`, `num_sink_tokens`, `dense_window_size`; the HF/Triton prefill path propagates them.
* **Configuration**
  * New config fields and an exported preset to enable/configure Triton N:M sparse softmax, with validation.
* **Tests**
  * Added GPU tests covering N:M behavior, tile structure, and forward/backward correctness.
* **Documentation**
  * CHANGELOG and example docs updated with usage and CLI options.
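For intuition, here is a minimal dense PyTorch sketch of the N:M masking semantics described above, not the Triton kernel itself. The function name `nm_sparse_softmax_reference` is made up for illustration, and the kernel's exact tie-breaking and sink/dense-window behavior may differ:

```python
import torch


def nm_sparse_softmax_reference(
    scores: torch.Tensor,  # [..., q_len, k_len] raw attention scores (pre-softmax)
    n: int,
    m: int,
    num_sink_tokens: int = 0,
    dense_window_size: int = 0,
) -> torch.Tensor:
    """N:M masking along the key dimension followed by softmax (reference semantics only)."""
    q_len, k_len = scores.shape[-2], scores.shape[-1]
    assert k_len % m == 0, "this sketch assumes k_len is a multiple of m"

    # Rank scores inside each group of m consecutive key positions; ranks >= n are dropped.
    grouped = scores.reshape(*scores.shape[:-1], k_len // m, m)
    rank = grouped.argsort(dim=-1, descending=True).argsort(dim=-1)
    drop = (rank >= n).reshape_as(scores)

    # Exemptions: attention sinks (first keys) and keys close to the query stay dense.
    k_pos = torch.arange(k_len, device=scores.device)
    q_pos = torch.arange(q_len, device=scores.device)
    keep_dense = (k_pos < num_sink_tokens) | (
        (q_pos[:, None] - k_pos[None, :]).abs() < dense_window_size
    )
    drop = drop & ~keep_dense

    # Dropped positions get -inf so they contribute nothing after softmax.
    return torch.softmax(scores.masked_fill(drop, float("-inf")), dim=-1)
```

With, e.g., n=2, m=4, at least half of every group survives, so no softmax row becomes all -inf.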
---------

Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 291498b commit b1f9f01

12 files changed

Lines changed: 1125 additions & 279 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ NVIDIA Model Optimizer Changelog
 
 - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics.
 - Added iterator interface using CalibrationDataReader in ONNX quantization workflow.
+- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
 - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml>`_ for more details.
 
 **Bug Fixes**

examples/llm_sparsity/attention_sparsity/README.md

Lines changed: 52 additions & 6 deletions
@@ -1,6 +1,11 @@
 # Attention Sparsity for HuggingFace Models
 
-In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. Two attention backends are supported:
+In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Two sparsity methods are supported:
+
+- **Skip-softmax** (`flash_skip_softmax`): Skips attention tiles whose contribution is negligible, based on a threshold. Based on the [BLASST](https://arxiv.org/pdf/2512.12087) algorithm.
+- **N:M sparse softmax** (`triton_sparse_softmax`): For every M consecutive key positions, keeps the top-N attention scores and sets the rest to -inf before softmax.
+
+Two attention backends are available:
 
 - **pytorch** (default): Patches `F.softmax` to apply skip-softmax sparsity (requires `attn_implementation="eager"`)
 - **triton**: Uses a fused Triton Flash Attention kernel with in-kernel sparsity (uses `attn_implementation="modelopt_triton"`)
@@ -29,9 +34,9 @@ model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)
 
 ## Configuration Options
 
-Two pre-defined configurations are available:
+### Skip-Softmax
 
-### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT)
+#### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT)
 
 Uses a fixed threshold value. Simple but may not be optimal for all sequence lengths.
 
@@ -41,7 +46,7 @@ from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAU
 model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)
 ```
 
-### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB)
+#### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB)
 
 Uses RULER-based calibration to determine an optimal dynamic threshold that adapts to sequence length. Recommended for production use.
 
@@ -51,6 +56,46 @@ from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB
 model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB)
 ```
 
+### N:M Sparse Softmax (SPARSE_SOFTMAX_DEFAULT)
+
+Applies N:M structured sparsity to attention scores using the Triton backend. For every M consecutive key positions, keeps only the top-N scores and sets the rest to -inf. Supports M=4 (N=1,2,3) and M=8 (N=1..7). Attention sinks and a local dense window can be configured to preserve important positions.
+
+```python
+from modelopt.torch.sparsity.attention_sparsity.config import SPARSE_SOFTMAX_DEFAULT
+
+model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.1-8B",
+    torch_dtype=torch.bfloat16,
+    device_map="cuda",
+)
+
+model = mtsa.sparsify(model, config=SPARSE_SOFTMAX_DEFAULT)
+```
+
+Custom N:M configuration:
+
+```python
+sparse_cfg = {
+    "sparse_cfg": {
+        "*attn*": {
+            "method": "triton_sparse_softmax",
+            "sparsity_n": 2,  # Keep top-2 of every 4
+            "sparsity_m": 4,  # Group size
+            "num_sink_tokens": 4,  # Keep first 4 tokens dense (attention sinks)
+            "dense_window_size": 128,  # Keep tokens within distance 128 dense
+            "backend": "triton",
+            "enable": True,
+        },
+        "default": {"enable": False},
+    },
+}
+
+model = mtsa.sparsify(model, config=sparse_cfg)
+```
+
+> [!Note]
+> N:M sparse softmax requires the Triton backend (`backend="triton"`). The `attn_implementation` is automatically set to `"modelopt_triton"` by `mtsa.sparsify()`. N:M sparsity is applied during prefill only — decode tokens are not sparsified.
+
 ## Prerequisites
 
 ### Local Installation
@@ -104,8 +149,8 @@ The calibration process:
 | Argument | Default | Description |
 |----------|---------|-------------|
 | `--pyt_ckpt_path` | Required | HuggingFace model path or name |
-| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax` or `skip_softmax_calib` |
-| `--backend` | `pytorch` | Backend: `pytorch` (only supported backend) |
+| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax`, `skip_softmax_calib`, or `sparse_softmax` |
+| `--backend` | `pytorch` | Backend: `pytorch` (skip-softmax) or `triton` (N:M sparse softmax) |
 | `--seq_len` | `2048` | Maximum sequence length for input prompts |
 | `--export_dir` | `None` | Directory to export the sparsified model |
 
@@ -166,3 +211,4 @@ model = mtsa.sparsify(model, config=custom_config)
 
 - [Model Optimizer Documentation](https://nvidia.github.io/Model-Optimizer/)
 - [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER)
+- [BLASST: Block-Level Adaptive Structured Sparse Training](https://arxiv.org/pdf/2512.12087) — skip-softmax algorithm

examples/llm_sparsity/attention_sparsity/hf_sa.py

Lines changed: 9 additions & 5 deletions
@@ -31,6 +31,7 @@
 from modelopt.torch.sparsity.attention_sparsity.config import (
     SKIP_SOFTMAX_CALIB,
     SKIP_SOFTMAX_DEFAULT,
+    SPARSE_SOFTMAX_DEFAULT,
 )
 from modelopt.torch.utils.memory_monitor import launch_memory_monitor
 
@@ -43,6 +44,7 @@
 SPARSE_ATTN_CFG_CHOICES = {
     "skip_softmax": SKIP_SOFTMAX_DEFAULT,
     "skip_softmax_calib": SKIP_SOFTMAX_CALIB,
+    "sparse_softmax": SPARSE_SOFTMAX_DEFAULT,
 }
 
 
@@ -168,9 +170,10 @@ def main(args):
 
     # Apply CLI overrides to sparse_cfg
    sparse_cfg = sparse_config.get("sparse_cfg", {})
-    for layer_cfg in sparse_cfg.values():
-        if isinstance(layer_cfg, dict) and "method" in layer_cfg:
-            layer_cfg["backend"] = args.backend
+    if args.backend is not None:
+        for layer_cfg in sparse_cfg.values():
+            if isinstance(layer_cfg, dict) and "method" in layer_cfg:
+                layer_cfg["backend"] = args.backend
     if args.target_sparse_ratio is not None:
         calib = sparse_cfg.setdefault("calibration", {})
         assert isinstance(calib, dict)
@@ -240,9 +243,10 @@ def main(args):
     parser.add_argument(
         "--backend",
         type=str,
-        default="pytorch",
+        default=None,
         choices=["pytorch", "triton"],
-        help="Backend for sparse attention (default: pytorch). 'triton' uses the fused Triton kernel.",
+        help="Backend for sparse attention. Overrides the config default if set. "
+        "'triton' uses the fused Triton kernel.",
     )
 
     # Sequence length arguments

modelopt/torch/kernels/hf_triton_attention.py

Lines changed: 11 additions & 0 deletions
@@ -105,6 +105,17 @@ def triton_attention_forward(
         kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
         kw["max_input_len_k"] = seq_k
 
+    # N:M sparse softmax — prefill only (decode should not sparsify KV)
+    if not is_decode and getattr(module, "_apply_sparse_nm", False):
+        # _sparse_method_instance is set by SparseAttentionModule._init_sparse_method()
+        # in modelopt/torch/sparsity/attention_sparsity/sparse_attention.py
+        method = getattr(module, "_sparse_method_instance", None)
+        if method is not None:
+            kw["sparsity_n"] = getattr(method, "sparsity_n", 2)
+            kw["sparsity_m"] = getattr(method, "sparsity_m", 4)
+            kw["num_sink_tokens"] = getattr(method, "num_sink_tokens", 0)
+            kw["dense_window_size"] = getattr(method, "dense_window_size", 64)
+
     o = attention(q, k, v, **kw)
 
     attn_output = o.view(batch, seq_len, num_heads, head_dim)
