
Commit feec81a

Add the Skip softmax for diffusion (#1166)
### What does this PR do?

**Type of change:** new feature, new example

## Summary

- Add skip-softmax sparse attention (BLASST) for diffusion models via dedicated Triton kernels: an inference kernel with tile skipping and a calibration kernel with vectorized multi-threshold sparsity measurement
- Add a `triton_skip_softmax` method with exponential-model calibration (`scale_factor = a * exp(b * sparsity)`) and log-space fitting for diffusion models
- Add Triton kernel backends for diffusers and LTX attention dispatch
- Fix calibration to skip RULER dataset generation when the user provides their own `forward_loop` (required for non-LLM models)

## Changes

### Triton kernels (`modelopt/torch/kernels/triton_fa.py`)

- **`_attn_fwd`**: forward kernel with optional tile skipping. Tiles whose max attention score falls far below the running softmax max are skipped entirely (no V load, no softmax, no accumulation). Runtime sparsity is measured via atomic counters.
- **`_attn_fwd_calibrate`**: calibration kernel that computes full attention while measuring how many tiles would be skipped at each of N thresholds simultaneously. Uses per-program output buffers (zero atomic contention) and a vectorized multi-threshold comparison.
- **`attention()`** / **`attention_calibrate()`**: Python wrappers for the inference and calibration kernels.

### Kernel backends (`modelopt/torch/sparsity/attention_sparsity/kernels/`)

- **`diffusers_triton_attention.py`**: registers a `modelopt_triton` backend in diffusers' attention dispatch. Handles [B, S, H, D] → varlen layout conversion, calibration/inference mode switching, thread-local configuration, and counter accumulation.
- **`ltx_triton_attention.py`**: patches `ltx_core.Attention` modules for Triton dispatch with the same calibration/inference modes.
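For intuition, the tile-skipping rule described for `_attn_fwd` can be modeled outside Triton. The sketch below is an illustrative single-query NumPy version under assumed semantics (a log2-domain online softmax that skips a tile when its max score falls below the running max plus `skip_threshold_log2`); it is not the kernel itself, and the function name is made up for illustration.

```python
import numpy as np

LOG2E = 1.4426950408889634  # log2(e): converts natural-log scores to the log2 domain


def skip_softmax_attention(q, K, V, sm_scale, skip_threshold_log2, tile=64):
    """Single-query online softmax over KV tiles with tile skipping (sketch)."""
    m = -np.inf                 # running max (log2 domain)
    denom = 0.0                 # running softmax denominator
    acc = np.zeros(V.shape[1])  # running weighted sum of V rows
    skipped = total = 0
    for s in range(0, K.shape[0], tile):
        Kt, Vt = K[s:s + tile], V[s:s + tile]
        scores = (Kt @ q) * sm_scale * LOG2E  # tile scores in log2 domain
        total += 1
        # Tile negligible: its best score is too far below the running max,
        # so skip the V load, softmax, and accumulation entirely.
        if scores.max() < m + skip_threshold_log2:
            skipped += 1
            continue
        m_new = max(m, scores.max())
        alpha = np.exp2(m - m_new)            # rescale the previous accumulator
        p = np.exp2(scores - m_new)
        acc = acc * alpha + p @ Vt
        denom = denom * alpha + p.sum()
        m = m_new
    return acc / denom, skipped / total       # output row, tile skip ratio
```

With `skip_threshold_log2 = -inf` no tile is ever skipped and the result matches dense softmax attention exactly; a finite negative threshold trades accuracy for skipped tiles.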
### Method (`modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`)

- `TritonSkipSoftmaxMethod`: context managers for calibration (→ calibration kernel) and inference (→ forward kernel with tile skipping). Three threshold priority levels: raw threshold > calibrated scale_factor > static threshold.

### Calibration (`modelopt/torch/sparsity/attention_sparsity/calibration/`)

- **`calibrator.py`**: `DynamicThresholdCalibrator` with a `fit_logspace` option that fits the exponential model in log space (minimizing relative error) for diffusion models where scale_factors span many orders of magnitude. Records the observed sparsity range for extrapolation warnings.
- **`calibrate.py`**: skips the RULER dataset when a `forward_loop` is provided; passes `fit_logspace` through from the config.

### Config & conversion

- **`config.py`**: adds the `CalibrationConfig.fit_logspace` field (default `False`; `True` recommended for diffusion models) and the `skip_softmax_raw_threshold` field for direct threshold mode.
- **`conversion.py`**: auto-registers the diffusers/LTX Triton backends on `sparsify()`. Updated summary display.

### Example

- **`wan22_skip_softmax.py`**: end-to-end example for WAN 2.2 5B/14B with baseline, raw-threshold, and calibrated modes. Supports runtime sparsity reporting.
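As a sketch of the `fit_logspace` idea: fitting `scale_factor = a * exp(b * sparsity)` in log space turns the problem into a linear fit on `ln(scale_factor)`, which minimizes relative rather than absolute error. The function names below are illustrative, not the actual `DynamicThresholdCalibrator` API.

```python
import numpy as np


def fit_exponential_logspace(sparsities, scale_factors):
    """Fit scale_factor = a * exp(b * sparsity) by linear regression in log space."""
    # ln(scale_factor) = ln(a) + b * sparsity  ->  ordinary least squares
    b, ln_a = np.polyfit(np.asarray(sparsities, dtype=float),
                         np.log(scale_factors), deg=1)
    return float(np.exp(ln_a)), float(b)


def predict_scale_factor(a, b, target_sparsity):
    """Evaluate the calibrated model at a (possibly new) target sparsity."""
    return a * np.exp(b * target_sparsity)
```

Because the fit happens in log space, calibration points whose scale_factors differ by orders of magnitude contribute comparably to the loss, which is exactly the regime described above for diffusion models.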
## Threshold modes

| Mode | How it works | Use case |
|------|--------------|----------|
| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly to the kernel as `skip_threshold_log2` | Quick testing, sweeps |
| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then `threshold = scale_factor / seq_k` at runtime | Production use with seqlen adaptation |
| **Static** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback |

## Usage

```bash
# Fixed raw threshold (no calibration)
python examples/diffusers/sparsity/wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --raw-threshold -0.7 \
    --prompt "A cat playing piano" --output out.mp4

# With calibration (log-space fit for diffusion models)
python examples/diffusers/sparsity/wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --calibrate --target-sparsity 0.5 \
    --prompt "A cat playing piano" --output out.mp4

# Dense baseline for comparison
python examples/diffusers/sparsity/wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --baseline \
    --prompt "A cat playing piano" --output baseline.mp4
```

### Before your PR is "*Ready for review*"

Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: ✅
- Did you write any new necessary tests?: ✅
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ❌

## Summary by CodeRabbit

* **New Features**
  * Added skip-softmax sparse attention support for Diffusers models, enabling efficient video generation
  * Added support for both eager and Triton attention backends for sparse attention
  * Added a new example script for Wan 2.2 text-to-video generation with sparse attention optimization
* **Documentation**
  * Updated documentation with a sparse attention configuration guide and usage examples
* **Tests**
  * Added unit tests for kernel backend registration and skip-softmax functionality

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
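The calibration fix noted in the summary (skip RULER dataset generation when the caller supplies a `forward_loop`) reduces to a guard like the following sketch; the names are illustrative, not the actual `calibrate.py` code.

```python
def run_calibration(model, forward_loop=None, **ruler_kwargs):
    """Use the caller's forward_loop if given; otherwise build the default
    RULER-based loop, which assumes an LLM tokenizer and therefore cannot
    serve non-LLM models such as diffusion transformers."""
    if forward_loop is None:
        forward_loop = build_ruler_forward_loop(**ruler_kwargs)  # LLM-only path
    forward_loop(model)  # drive the model so calibration hooks observe activations


def build_ruler_forward_loop(**kwargs):
    """Placeholder for the LLM dataset path; never reached when a
    forward_loop is provided."""
    raise NotImplementedError("RULER generation is LLM-only")
```

With this guard, a diffusion pipeline can calibrate by passing its own loop and never touches the RULER path.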
1 parent 76b6fd5 commit feec81a

35 files changed

Lines changed: 4180 additions & 76 deletions

.github/codecov.yml

Lines changed: 12 additions & 0 deletions
```diff
@@ -11,3 +11,15 @@ coverage:
     target: auto
     threshold: 1% # Allow at most 1% coverage drop from main branch.
   patch: false
+
+# Exclude GPU-only Triton kernel files from ALL codecov calculations (project
+# and patch checks, all flags). Rationale: these files are dominated by
+# @triton.jit kernel bodies that CPU unit tests cannot exercise. GPU tests
+# cover them end-to-end (see tests/gpu/torch/sparsity/attention_sparsity/) but
+# the `gpu`-flag upload may race with the PR status check, so relying on flag
+# combination alone leaves the project check flaky. Dropping these files here
+# makes the check deterministic: local `pytest --cov` and GPU runs still
+# measure them; only the codecov PR status ignores them.
+ignore:
+  - "modelopt/torch/kernels/triton_fa.py"
+  - "modelopt/torch/kernels/hf_triton_attention.py"
```

.github/workflows/example_tests.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -63,7 +63,7 @@ jobs:
     strategy: &torch_strategy
       fail-fast: false
       matrix:
-        example: [llm_distill, llm_qat, llm_sparsity]
+        example: [llm_distill, llm_qat, llm_sparsity, diffusers_sparsity]
       include:
         - example: speculative_decoding
           docker_image: "26.01"
```

examples/diffusers/README.md

Lines changed: 54 additions & 0 deletions
````diff
@@ -13,6 +13,7 @@ Cache Diffusion is a technique that reuses cached outputs from previous diffusio
 | Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | |
 | Getting Started | Learn how to optimize your models using quantization/cache diffusion to reduce precision and improve inference efficiency | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
 | Support Matrix | View the support matrix to see quantization/cache diffusion compatibility and feature availability across different models | \[[Link](#support-matrix)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
+| Sparse Attention (Skip-Softmax) | Skip-softmax sparse attention for diffusion models | \[[Link](#sparse-attention-skip-softmax)\] | |
 | Cache Diffusion | Caching technique to accelerate inference without compromising quality | \[[Link](#cache-diffusion)\] | |
 | Post Training Quantization (PTQ) | Example scripts on how to run PTQ on diffusion models | \[[Link](#post-training-quantization-ptq)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
 | Quantization Aware Training (QAT) | Example scripts on how to run QAT on diffusion models | \[[Link](#quantization-aware-training-qat)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
@@ -290,6 +291,59 @@ mto.restore(pipe.unet, your_quantized_ckpt)
 
 By following these steps, your PEFT LoRA model should be efficiently quantized using ModelOpt, ready for deployment while maximizing performance.
 
+## Sparse Attention (Skip-Softmax)
+
+Skip-softmax sparse attention skips KV tiles whose attention scores are negligible during the softmax computation, reducing FLOPs without retraining. An exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once, then the target sparsity can be adjusted at runtime without recalibration.
+
+### Getting Started
+
+```python
+import modelopt.torch.sparsity.attention_sparsity as mtsa
+
+# 1. Define config with calibration
+config = {
+    "sparse_cfg": {
+        "calibration": {
+            "target_sparse_ratio": {"prefill": 0.5},
+        },
+        "*.attn1": {
+            "method": "triton_skip_softmax",
+            "backend": "triton",
+            "is_causal": False,
+            "collect_stats": True,
+            "enable": True,
+        },
+        "*.attn2": {"enable": False},
+        "default": {"enable": False},
+    },
+}
+
+# 2. Provide a calibration forward loop
+def forward_loop(model):
+    pipeline(prompt="a cat", num_frames=81, num_inference_steps=40, ...)
+
+# 3. Sparsify + calibrate
+mtsa.sparsify(transformer, config, forward_loop=forward_loop)
+
+# 4. Generate as usual; sparsity is applied automatically
+output = pipeline(prompt="a dog on the beach", ...)
+```
+
+### Example Scripts
+
+#### Wan 2.2 [Script](./sparsity/wan22_skip_softmax.py)
+
+The 14B model automatically sparsifies both `transformer` and `transformer_2`.
+
+```bash
+# 5B/14B model
+python sparsity/wan22_skip_softmax.py \
+    --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers|Wan-AI/Wan2.2-TI2V-5B-Diffusers \
+    --calibrate --target-sparsity 0.5 --calib-size 4 \
+    --prompt "A sunset over mountains" --output out.mp4
+```
+
 ## Cache Diffusion
 
 Cache Diffusion methods, such as [DeepCache](https://arxiv.org/abs/2312.00858), [Block Caching](https://arxiv.org/abs/2312.03209) and [T-Gate](https://arxiv.org/abs/2404.02747), optimize performance by reusing cached outputs from previous steps instead of recalculating them. This **training-free** caching approach is compatible with a variety of models, like **DiT** and **UNet**, enabling considerable acceleration without compromising quality.
````
Lines changed: 76 additions & 0 deletions
````diff
@@ -0,0 +1,76 @@
+# Skip-Softmax Sparse Attention for Diffusion Models
+
+> [!WARNING]
+> **Third-Party License Notice: LTX-2**
+>
+> LTX-2 packages (`ltx-core`, `ltx-pipelines`, `ltx-trainer`) are third-party dependencies
+> developed and provided by [Lightricks](https://github.com/Lightricks/LTX-2). They are
+> **NOT** covered by the Apache 2.0 license governing NVIDIA Model Optimizer.
+>
+> You **MUST** comply with the
+> [LTX Community License Agreement](https://github.com/Lightricks/LTX-2/blob/main/LICENSE)
+> when installing and using LTX-2 with NVIDIA Model Optimizer. Any derivative models or
+> fine-tuned weights produced from LTX-2 (including quantized, distilled, or sparsified
+> checkpoints) remain subject to the LTX Community License Agreement, not Apache 2.0.
+
+Skip-softmax sparse attention (BLASST, <https://arxiv.org/pdf/2512.12087>) skips KV
+tiles whose attention scores are negligible during the FlashAttention computation,
+reducing FLOPs without retraining.
+
+Two modes are supported:
+
+- **Fixed raw threshold**: pass a log2-space threshold directly to the Triton
+  kernel. No calibration needed. Good for quick testing and sweeps.
+- **Calibrated threshold**: an exponential model
+  (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once via the
+  Triton calibration kernel, then the target sparsity can be adjusted at runtime
+  without recalibration. Log-space fitting (`fit_logspace=True`) is recommended
+  for diffusion models where scale_factors span many orders of magnitude.
+
+## Supported Models
+
+| Model | Script | Notes |
+|-------|--------|-------|
+| WAN 2.2 5B | `wan22_skip_softmax.py` | Single transformer, self-attention only |
+| WAN 2.2 14B | `wan22_skip_softmax.py` | Dual transformer (auto-detected) |
+| LTX-2 | (coming soon) | Via `ltx_triton_attention.py` backend |
+
+## Quick Start
+
+```bash
+# Fixed raw threshold (no calibration, fast)
+python wan22_skip_softmax.py \
+    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
+    --raw-threshold -0.7 \
+    --prompt "A cat playing piano" --output out.mp4
+
+# With calibration
+python wan22_skip_softmax.py \
+    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
+    --calibrate --target-sparsity 0.5 \
+    --prompt "A cat playing piano" --output out.mp4
+
+# Dense baseline (no sparsity, for comparison)
+python wan22_skip_softmax.py \
+    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
+    --baseline \
+    --prompt "A cat playing piano" --output baseline.mp4
+
+# Report runtime sparsity (per-layer tile skip ratios)
+python wan22_skip_softmax.py \
+    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
+    --raw-threshold -0.7 --report-avg-sparsity \
+    --prompt "A cat playing piano" --output out.mp4
+```
+
+## Threshold Modes
+
+| Mode | How threshold reaches the kernel | Use case |
+|------|----------------------------------|----------|
+| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly as `skip_threshold_log2`; no conversion | Quick testing, sweeps |
+| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then the backend computes `threshold = scale_factor / seq_k`, then the kernel converts `log2(threshold) * sm_scale` | Production use with automatic seqlen adaptation |
+| **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated |
````
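The calibrated-mode chain in the threshold table can be written out explicitly. This is a sketch of the arithmetic only, with a made-up helper name, not the backend's actual code.

```python
import math


def calibrated_kernel_threshold(scale_factor, seq_k, sm_scale):
    """scale_factor -> runtime threshold -> log2-space kernel threshold."""
    threshold = scale_factor / seq_k        # adapts to the KV sequence length
    return math.log2(threshold) * sm_scale  # value the kernel compares against
```

Because `threshold` shrinks as `seq_k` grows, one calibrated `scale_factor` yields a sequence-length-adapted kernel threshold, which is the "automatic seqlen adaptation" mentioned in the table.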
````diff
+
+## Known Issues
+
+- **14B dual transformer calibration**: transformers are calibrated sequentially, so transformer_2's calibration runs while transformer_1 is already sparsified, introducing asymmetric calibration conditions.
+- **Minimum achievable sparsity**: even the strictest threshold may yield 30-40% sparsity on diffusion models (many tiles are inherently negligible). Targets below this floor cause extrapolation; an inference-time warning is emitted.
````