Commit a76f561

Skip-softmax sparse attention for diffusion models (LTX-2, Wan2.2, etc.)

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>

1 parent e4b054b · commit a76f561

24 files changed
Lines changed: 2611 additions & 76 deletions

.github/workflows/example_tests.yml

Lines changed: 1 addition & 1 deletion

```diff
@@ -63,7 +63,7 @@ jobs:
   strategy: &torch_strategy
     fail-fast: false
     matrix:
-      example: [llm_distill, llm_qat, llm_sparsity]
+      example: [llm_distill, llm_qat, llm_sparsity, diffusers_sparsity]
       include:
         - example: speculative_decoding
           docker_image: "26.01"
```

examples/diffusers/README.md

Lines changed: 54 additions & 0 deletions

```diff
@@ -13,6 +13,7 @@ Cache Diffusion is a technique that reuses cached outputs from previous diffusio
 | Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | |
 | Getting Started | Learn how to optimize your models using quantization/cache diffusion to reduce precision and improve inference efficiency | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
 | Support Matrix | View the support matrix to see quantization/cache diffusion compatibility and feature availability across different models | \[[Link](#support-matrix)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
+| Sparse Attention (Skip-Softmax) | Skip-softmax sparse attention for diffusion models | \[[Link](#sparse-attention-skip-softmax)\] | |
 | Cache Diffusion | Caching technique to accelerate inference without compromising quality | \[[Link](#cache-diffusion)\] | |
 | Post Training Quantization (PTQ) | Example scripts on how to run PTQ on diffusion models | \[[Link](#post-training-quantization-ptq)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
 | Quantization Aware Training (QAT) | Example scripts on how to run QAT on diffusion models | \[[Link](#quantization-aware-training-qat)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
```
@@ -290,6 +291,59 @@

By following these steps, your PEFT LoRA model should be efficiently quantized using ModelOpt, ready for deployment while maximizing performance.

## Sparse Attention (Skip-Softmax)

Skip-softmax sparse attention skips KV tiles whose attention scores are negligible during the softmax computation, reducing FLOPs without retraining. An exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once; the target sparsity can then be adjusted at runtime without recalibration.

### Getting Started
```python
import modelopt.torch.sparsity.attention_sparsity as mtsa

# 1. Define config with calibration
config = {
    "sparse_cfg": {
        "calibration": {
            "target_sparse_ratio": {"prefill": 0.5},
        },
        # "*.attn1" matches the self-attention modules in Diffusers transformer blocks
        "*.attn1": {
            "method": "triton_skip_softmax",
            "backend": "triton",
            "is_causal": False,
            "collect_stats": True,
            "enable": True,
        },
        # cross-attention ("*.attn2") and everything else stay dense
        "*.attn2": {"enable": False},
        "default": {"enable": False},
    },
}

# 2. Provide a calibration forward loop
def forward_loop(model):
    pipeline(prompt="a cat", num_frames=81, num_inference_steps=40, ...)

# 3. Sparsify + calibrate
mtsa.sparsify(transformer, config, forward_loop=forward_loop)

# 4. Generate as usual — sparsity is applied automatically
output = pipeline(prompt="a dog on the beach", ...)
```
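The keys under `sparse_cfg` are glob-style patterns over fully qualified module names. As a rough sketch of that selection logic (hypothetical, written with Python's `fnmatch`; ModelOpt's actual matcher may differ):

```python
# Hypothetical sketch of glob-style config matching (not ModelOpt's real code).
import fnmatch

sparse_cfg = {
    "*.attn1": {"method": "triton_skip_softmax", "enable": True},
    "*.attn2": {"enable": False},
    "default": {"enable": False},
}

def pick_cfg(module_name: str) -> dict:
    # first glob pattern that matches the fully qualified module name wins
    for pattern, cfg in sparse_cfg.items():
        if pattern != "default" and fnmatch.fnmatch(module_name, pattern):
            return cfg
    return sparse_cfg["default"]

print(pick_cfg("blocks.0.attn1"))  # {'method': 'triton_skip_softmax', 'enable': True}
print(pick_cfg("blocks.0.ffn"))    # {'enable': False} (falls back to "default")
```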
### Example Scripts

#### Wan 2.2 [Script](./sparsity/wan22_skip_softmax.py)

The 14B model automatically sparsifies both `transformer` and `transformer_2`.
```bash
# 14B model shown; pass Wan-AI/Wan2.2-TI2V-5B-Diffusers for the 5B model
python sparsity/wan22_skip_softmax.py \
    --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \
    --calibrate --target-sparsity 0.5 --calib-size 4 \
    --prompt "A sunset over mountains" --output out.mp4
```
## Cache Diffusion

Cache Diffusion methods, such as [DeepCache](https://arxiv.org/abs/2312.00858), [Block Caching](https://arxiv.org/abs/2312.03209) and [T-Gate](https://arxiv.org/abs/2404.02747), optimize performance by reusing cached outputs from previous steps instead of recalculating them. This **training-free** caching approach is compatible with a variety of models, like **DiT** and **UNet**, enabling considerable acceleration without compromising quality.
Lines changed: 76 additions & 0 deletions (new file)
# Skip-Softmax Sparse Attention for Diffusion Models

> [!WARNING]
> **Third-Party License Notice — LTX-2**
>
> LTX-2 packages (`ltx-core`, `ltx-pipelines`, `ltx-trainer`) are third-party dependencies
> developed and provided by [Lightricks](https://github.com/Lightricks/LTX-2). They are
> **NOT** covered by the Apache 2.0 license governing NVIDIA Model Optimizer.
>
> You **MUST** comply with the
> [LTX Community License Agreement](https://github.com/Lightricks/LTX-2/blob/main/LICENSE)
> when installing and using LTX-2 with NVIDIA Model Optimizer. Any derivative models or
> fine-tuned weights produced from LTX-2 (including quantized, distilled, or sparsified
> checkpoints) remain subject to the LTX Community License Agreement, not Apache 2.0.
Skip-softmax sparse attention (BLASST, <https://arxiv.org/pdf/2512.12087>) skips KV tiles whose attention scores are negligible during the FlashAttention computation, reducing FLOPs without retraining.
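The toy below is a minimal dense-math sketch of the idea, not the Triton kernel; the tile size and threshold are illustrative, and it assumes the KV length is a multiple of the tile size. A tile whose best logit sits far enough below the row maximum contributes at most `threshold` per element to the softmax, so masking it barely perturbs the output:

```python
# Toy reference of skip-softmax attention: drop KV tiles whose best logit is
# so far below the row max that their softmax weight is provably tiny.
import torch

def skip_softmax_attention(q, k, v, tile=64, threshold=1e-3):
    scores = (q @ k.T) / q.shape[-1] ** 0.5             # [seq_q, seq_k] logits
    row_max = scores.max(dim=-1, keepdim=True).values   # plays FlashAttention's running max
    skipped = 0.0
    for start in range(0, k.shape[0], tile):
        blk = scores[:, start:start + tile]             # view into scores
        # exp(max(blk) - row_max) upper-bounds every softmax weight in this tile
        negligible = (blk.max(dim=-1, keepdim=True).values - row_max).exp() < threshold
        blk.masked_fill_(negligible, float("-inf"))     # "skip" the tile
        skipped += negligible.float().mean().item()
    out = scores.softmax(dim=-1) @ v
    return out, skipped / (k.shape[0] // tile)          # output, avg tile skip ratio

q, k, v = (torch.randn(256, 64) for _ in range(3))
out, ratio = skip_softmax_attention(q, k, v)
print(f"tile skip ratio: {ratio:.2f}")
```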
Two modes are supported:

- **Fixed raw threshold** — pass a log2-space threshold directly to the Triton kernel. No calibration needed. Good for quick testing and sweeps.
- **Calibrated threshold** — an exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once via the Triton calibration kernel; the target sparsity can then be adjusted at runtime without recalibration. Log-space fitting (`fit_logspace=True`) is recommended for diffusion models whose scale factors span many orders of magnitude (see the fitting sketch after this list).
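The sketch below illustrates why fitting in log space is the robust choice: taking the log turns the exponential model into a straight line, so ordinary least squares weights small and large scale factors evenly. The `(sparsity, scale_factor)` pairs are made up for illustration, and `np.polyfit` stands in for whatever fitting routine the calibrator actually uses:

```python
# Hypothetical illustration of fit_logspace=True; the calibration points are invented.
import numpy as np

s = np.array([0.3, 0.5, 0.7, 0.9])            # observed sparsities
sf = np.array([2e-4, 5e-3, 1.1e-1, 2.6e0])    # observed scale factors (many decades)

# scale_factor = a * exp(b * s)  <=>  ln(scale_factor) = ln(a) + b * s
b, ln_a = np.polyfit(s, np.log(sf), deg=1)    # line fit in log space
a = np.exp(ln_a)

def scale_factor(target: float) -> float:
    """Evaluate the calibrated model at any runtime target sparsity."""
    return a * np.exp(b * target)

print(scale_factor(0.6))  # re-targeting needs no recalibration
```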
## Supported Models

| Model | Script | Notes |
|-------|--------|-------|
| WAN 2.2 5B | `wan22_skip_softmax.py` | Single transformer, self-attention only |
| WAN 2.2 14B | `wan22_skip_softmax.py` | Dual transformer (auto-detected) |
| LTX-2 | (coming soon) | Via `ltx_triton_attention.py` backend |
## Quick Start

```bash
# Fixed raw threshold (no calibration, fast)
python wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --raw-threshold -0.7 \
    --prompt "A cat playing piano" --output out.mp4

# With calibration
python wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --calibrate --target-sparsity 0.5 \
    --prompt "A cat playing piano" --output out.mp4

# Dense baseline (no sparsity, for comparison)
python wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --baseline \
    --prompt "A cat playing piano" --output baseline.mp4

# Report runtime sparsity (per-layer tile skip ratios)
python wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --raw-threshold -0.7 --report-avg-sparsity \
    --prompt "A cat playing piano" --output out.mp4
```
## Threshold Modes

| Mode | How the threshold reaches the kernel | Use case |
|------|---------------------------------------|----------|
| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly as `skip_threshold_log2`; no conversion | Quick testing, sweeps |
| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`; the backend computes `threshold = scale_factor / seq_k`; the kernel applies `log2(threshold) * sm_scale` | Production use with automatic sequence-length adaptation |
| **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated is set |
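For concreteness, here is the arithmetic of the calibrated row traced end to end. The coefficients `a` and `b` are hypothetical placeholders, and `sm_scale = 1/sqrt(head_dim)` is assumed, as in standard scaled-dot-product attention:

```python
# Arithmetic behind the calibrated mode; all constants here are illustrative.
import math

a, b = 1.2e-5, 14.0                 # hypothetical calibrated coefficients
target, seq_k, head_dim = 0.5, 32760, 128

scale_factor = a * math.exp(b * target)                  # exponential model
threshold = scale_factor / seq_k                         # backend: adapt to KV length
sm_scale = 1.0 / math.sqrt(head_dim)
skip_threshold_log2 = math.log2(threshold) * sm_scale    # value the kernel consumes

print(f"{scale_factor=:.3e}  {threshold=:.3e}  {skip_threshold_log2=:.3f}")
```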
## Known Issues

- **14B dual-transformer calibration**: the two transformers are calibrated sequentially, so `transformer_2` is calibrated while `transformer` is already sparsified, introducing asymmetric calibration conditions.
- **Minimum achievable sparsity**: even the strictest threshold may still yield 30-40% sparsity on diffusion models, because many tiles are inherently negligible. Targets below this floor force the calibrated model to extrapolate; an inference-time warning is emitted (a sketch of such a guard follows this list).
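A hypothetical guard illustrating the second issue; the calibrated range is invented, and the library's actual warning logic may differ:

```python
# Hypothetical extrapolation guard; the calibrated range below is made up.
import warnings

def check_target(target: float, calibrated_range=(0.35, 0.9)) -> None:
    lo, hi = calibrated_range
    if not lo <= target <= hi:
        warnings.warn(
            f"target sparsity {target} lies outside the calibrated range "
            f"[{lo}, {hi}]; the exponential model will extrapolate."
        )

check_target(0.2)  # below the ~30-40% floor -> warning
```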
