# Skip-Softmax Sparse Attention for Diffusion Models

> [!WARNING]
> **Third-Party License Notice — LTX-2**
>
> LTX-2 packages (`ltx-core`, `ltx-pipelines`, `ltx-trainer`) are third-party dependencies
> developed and provided by [Lightricks](https://github.com/Lightricks/LTX-2). They are
> **NOT** covered by the Apache 2.0 license governing NVIDIA Model Optimizer.
>
> You **MUST** comply with the
> [LTX Community License Agreement](https://github.com/Lightricks/LTX-2/blob/main/LICENSE)
> when installing and using LTX-2 with NVIDIA Model Optimizer. Any derivative models or
> fine-tuned weights produced from LTX-2 (including quantized, distilled, or sparsified
> checkpoints) remain subject to the LTX Community License Agreement, not Apache 2.0.

Skip-softmax sparse attention (BLASST, <https://arxiv.org/pdf/2512.12087>) skips KV
tiles whose attention scores are negligible during the FlashAttention computation,
reducing FLOPs without retraining.

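As a mental model of the skip test, here is a dense NumPy sketch of a FlashAttention-style online-softmax loop for a single query row, with KV-tile skipping in the log2 domain. This is illustrative only: the real implementation is a Triton kernel, and its exact skip criterion and bookkeeping may differ.

```python
import numpy as np

def skip_softmax_attention(q, k, v, tile_size, skip_threshold_log2):
    """Single-query-row online softmax with tile skipping (reference sketch).

    q: (d,); k, v: (n, d). Scores are kept in the log2 domain (exp2
    arithmetic, as Triton kernels commonly use); a KV tile is skipped when
    its best score falls more than |skip_threshold_log2| below the running
    row maximum, i.e. when its softmax weights would be negligible.
    """
    sm_scale = 1.0 / np.sqrt(q.shape[-1])
    log2e = 1.4426950408889634            # log2(e): e**x == 2**(x * log2e)
    m = -np.inf                           # running row max (log2 domain)
    l = 0.0                               # running softmax denominator
    acc = np.zeros_like(q, dtype=np.float64)
    skipped = 0
    for start in range(0, k.shape[0], tile_size):
        kt, vt = k[start:start + tile_size], v[start:start + tile_size]
        s = (kt @ q) * sm_scale * log2e   # tile scores in log2 domain
        tile_max = s.max()
        # Skip test: every exp2(s - m) in this tile would be below
        # 2**skip_threshold_log2 relative to the running maximum.
        if tile_max - m < skip_threshold_log2:
            skipped += 1
            continue
        m_new = max(m, tile_max)
        alpha = np.exp2(m - m_new) if np.isfinite(m) else 0.0
        p = np.exp2(s - m_new)
        l = l * alpha + p.sum()           # rescale old state, add new tile
        acc = acc * alpha + p @ vt
        m = m_new
    return acc / l, skipped
```

With a very negative threshold no tile is ever skipped and the loop reduces to exact attention; raising the threshold trades accuracy for skipped tiles.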
Two modes are supported:
- **Fixed raw threshold** — pass a log2-space threshold directly to the Triton
  kernel. No calibration needed. Good for quick testing and sweeps.
- **Calibrated threshold** — an exponential model
  (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once via the
  Triton calibration kernel; the target sparsity can then be adjusted at runtime
  without recalibration. Log-space fitting (`fit_logspace=True`) is recommended
  for diffusion models, where `scale_factor` values span many orders of magnitude.

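The exponential model becomes linear after taking logs, which is what log-space fitting exploits. A minimal sketch of such a fit, with hypothetical helper names (this is not the Model Optimizer API):

```python
import numpy as np

def fit_exponential_model(sparsities, scale_factors):
    """Fit scale_factor = a * exp(b * sparsity) by a linear fit in log space.

    log(scale_factor) = log(a) + b * sparsity, so an ordinary least-squares
    line recovers (a, b) without the largest samples dominating the fit,
    which matters when scale factors span many orders of magnitude.
    """
    s = np.asarray(sparsities, dtype=np.float64)
    log_f = np.log(np.asarray(scale_factors, dtype=np.float64))
    b, log_a = np.polyfit(s, log_f, 1)    # slope, intercept
    return float(np.exp(log_a)), float(b)

def predict_scale_factor(a, b, target_sparsity):
    """Evaluate the calibrated model at a new runtime target sparsity."""
    return a * np.exp(b * target_sparsity)
```

Once `(a, b)` are known, changing the target sparsity only re-evaluates `predict_scale_factor`; no new calibration pass is needed.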
## Supported Models

| Model | Script | Notes |
|-------|--------|-------|
| WAN 2.2 5B | `wan22_skip_softmax.py` | Single transformer, self-attention only |
| WAN 2.2 14B | `wan22_skip_softmax.py` | Dual transformer (auto-detected) |
| LTX-2 | (coming soon) | Via `ltx_triton_attention.py` backend |

## Quick Start

```bash
# Fixed raw threshold (no calibration, fast)
python wan22_skip_softmax.py \
  --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
  --raw-threshold -0.7 \
  --prompt "A cat playing piano" --output out.mp4

# With calibration
python wan22_skip_softmax.py \
  --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
  --calibrate --target-sparsity 0.5 \
  --prompt "A cat playing piano" --output out.mp4

# Dense baseline (no sparsity, for comparison)
python wan22_skip_softmax.py \
  --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
  --baseline \
  --prompt "A cat playing piano" --output baseline.mp4

# Report runtime sparsity (per-layer tile skip ratios)
python wan22_skip_softmax.py \
  --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
  --raw-threshold -0.7 --report-avg-sparsity \
  --prompt "A cat playing piano" --output out.mp4
```

## Threshold Modes

| Mode | How the threshold reaches the kernel | Use case |
|------|--------------------------------------|----------|
| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly as `skip_threshold_log2` — no conversion | Quick testing, sweeps |
| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`; the backend computes `threshold = scale_factor / seq_k`, then the kernel applies `log2(threshold) * sm_scale` | Production use with automatic seqlen adaptation |
| **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated is set |

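The three rows above amount to one precedence chain. A sketch of the resolution logic (an illustrative helper, not the actual backend code):

```python
import math

def resolve_skip_threshold_log2(sm_scale, raw_threshold=None,
                                scale_factor=None, seq_k=None,
                                static_lambda=0.1):
    """Resolve the log2-space threshold the kernel receives, per mode.

    - Raw mode: the CLI value is forwarded unchanged.
    - Calibrated mode: threshold = scale_factor / seq_k (so the effective
      threshold adapts to KV sequence length), then converted to log2 space
      and scaled by the softmax scale.
    - Static lambda: the default fallback, converted the same way.
    """
    if raw_threshold is not None:
        return raw_threshold                               # no conversion
    if scale_factor is not None and seq_k is not None:
        return math.log2(scale_factor / seq_k) * sm_scale  # calibrated
    return math.log2(static_lambda) * sm_scale             # static fallback
```

The division by `seq_k` is what gives the calibrated mode its automatic seqlen adaptation: a fixed raw threshold ignores sequence length, while the calibrated threshold loosens as the KV sequence grows.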
## Known Issues

- **14B dual transformer calibration**: Transformers are calibrated sequentially — transformer_2's calibration runs while transformer_1 is already sparsified, introducing asymmetric calibration conditions.
- **Minimum achievable sparsity**: Even the most conservative threshold may still yield 30-40% sparsity on diffusion models, because many tiles are inherently negligible and get skipped regardless. Targets below this floor force the calibrated model to extrapolate; an inference-time warning is emitted when that happens.