examples/diffusers/README.md (62 additions, 0 deletions)
| Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | |
| Getting Started | Learn how to optimize your models using quantization/cache diffusion to reduce precision and improve inference efficiency | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
| Support Matrix | View the support matrix to see quantization/cache diffusion compatibility and feature availability across different models | \[[Link](#support-matrix)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
| Sparse Attention (Skip-Softmax) | Skip-softmax sparse attention for diffusion models | \[[Link](#sparse-attention-skip-softmax)\] | |
| Cache Diffusion | Caching technique to accelerate inference without compromising quality | \[[Link](#cache-diffusion)\] | |
| Post Training Quantization (PTQ) | Example scripts on how to run PTQ on diffusion models | \[[Link](#post-training-quantization-ptq)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
| Quantization Aware Training (QAT) | Example scripts on how to run QAT on diffusion models | \[[Link](#quantization-aware-training-qat)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |

By following these steps, your PEFT LoRA model should be efficiently quantized using ModelOpt, ready for deployment while maximizing performance.

## Sparse Attention (Skip-Softmax)

Skip-softmax sparse attention skips KV tiles whose attention scores are negligible during the softmax computation, reducing FLOPs without retraining. An exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once, then the target sparsity can be adjusted at runtime without recalibration.
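The runtime retargeting this model enables can be sketched in a few lines; the coefficients `a` and `b` below are made-up stand-ins for whatever a real calibration run produces:

```python
import math

# Hypothetical fitted coefficients; real values come out of calibration.
a, b = 1.3e-4, 6.2

def scale_factor(target_sparsity: float) -> float:
    """Calibrated-once exponential model: scale_factor = a * exp(b * s)."""
    return a * math.exp(b * target_sparsity)

# Re-targeting sparsity at runtime needs no recalibration, just re-evaluation:
factors = {s: scale_factor(s) for s in (0.3, 0.5, 0.7)}
```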

### Getting Started

```python
import modelopt.torch.sparsity.attention_sparsity as mtsa

# 1. Define config with calibration
config = {
    "sparse_cfg": {
        "calibration": {
            "target_sparse_ratio": {"prefill": 0.5},
            "threshold_trials": [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3,
                                 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1,
                                 8e-1, 9e-1, 9.9e-1],
        },
        "*.attn1": {
            "method": "triton_skip_softmax",
            "backend": "triton",
            "is_causal": False,
            "collect_stats": True,
            "enable": True,
        },
        "*.attn2": {"enable": False},
        "default": {"enable": False},
    },
}

# 2. Provide a calibration forward loop
def forward_loop(model):
    pipeline(prompt="a cat", num_frames=81, num_inference_steps=40, ...)

# 3. Sparsify + calibrate
mtsa.sparsify(transformer, config, forward_loop=forward_loop)

# 4. Generate as usual — sparsity is applied automatically
output = pipeline(prompt="a dog on the beach", ...)
```

### Example Scripts

#### Wan 2.2 [Script](./sparsity/wan22_skip_softmax.py)

For the 14B model, the script automatically sparsifies both `transformer` and `transformer_2`.

```bash
# 5B model — calibrate + generate (4 prompts from OpenVid-1M, 151 frames, 40 steps)
python sparsity/wan22_skip_softmax.py \
--model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \
--calibrate --target-sparsity 0.5 --calib-size 4 \
--prompt "A sunset over mountains" --output out.mp4

# 14B model (both transformers sparsified)
python sparsity/wan22_skip_softmax.py \
--model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \
--calibrate --target-sparsity 0.5 --calib-size 4 \
--prompt "A sunset over mountains" --output out.mp4
```

## Cache Diffusion

Cache Diffusion methods, such as [DeepCache](https://arxiv.org/abs/2312.00858), [Block Caching](https://arxiv.org/abs/2312.03209) and [T-Gate](https://arxiv.org/abs/2404.02747), optimize performance by reusing cached outputs from previous steps instead of recalculating them. This **training-free** caching approach is compatible with a variety of models, like **DiT** and **UNet**, enabling considerable acceleration without compromising quality.
examples/diffusers/sparsity/README.md (154 additions, 0 deletions)
# Skip-Softmax Sparse Attention for Diffusion Models

> [!WARNING]
> **Third-Party License Notice — LTX-2**
>
> LTX-2 packages (`ltx-core`, `ltx-pipelines`, `ltx-trainer`) are third-party dependencies
> developed and provided by [Lightricks](https://github.com/Lightricks/LTX-2). They are
> **NOT** covered by the Apache 2.0 license governing NVIDIA Model Optimizer.
>
> You **MUST** comply with the
> [LTX Community License Agreement](https://github.com/Lightricks/LTX-2/blob/main/LICENSE)
> when installing and using LTX-2 with NVIDIA Model Optimizer. Any derivative models or
> fine-tuned weights produced from LTX-2 (including quantized, distilled, or sparsified
> checkpoints) remain subject to the LTX Community License Agreement, not Apache 2.0.

Skip-softmax sparse attention (BLASST, <https://arxiv.org/pdf/2512.12087>) skips KV
tiles whose attention scores are negligible during the FlashAttention computation,
reducing FLOPs without retraining.

Two modes are supported:
- **Fixed raw threshold** — pass a log2-space threshold directly to the Triton
kernel. No calibration needed. Good for quick testing and sweeps.
- **Calibrated threshold** — an exponential model
(`scale_factor = a * exp(b * target_sparsity)`) is calibrated once via the
Triton calibration kernel, then the target sparsity can be adjusted at runtime
without recalibration. Log-space fitting (`fit_logspace=True`) is recommended
for diffusion models where scale_factors span many orders of magnitude.
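A minimal sketch of the log-space fit, assuming synthetic (sparsity, scale_factor) pairs and using `numpy.polyfit` as a stand-in for the calibrator's actual solver:

```python
import numpy as np

# Made-up (sparsity, scale_factor) pairs such as a calibration run might collect;
# note the scale factors span several orders of magnitude.
sparsity = np.array([0.2, 0.4, 0.6, 0.8])
scale = np.array([1e-5, 1e-4, 1e-3, 1e-2])

# fit_logspace=True analogue: fit log(scale) = log(a) + b * sparsity linearly,
# which minimizes relative rather than absolute error across the magnitudes.
b, log_a = np.polyfit(sparsity, np.log(scale), 1)
a = np.exp(log_a)

# Recovered exponential model: scale_factor = a * exp(b * sparsity)
pred = a * np.exp(b * sparsity)
```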

## Supported Models

| Model | Script | Notes |
|-------|--------|-------|
| WAN 2.2 5B | `wan22_skip_softmax.py` | Single transformer, self-attention only |
| WAN 2.2 14B | `wan22_skip_softmax.py` | Dual transformer (auto-detected) |
| LTX-2 | (coming soon) | Via `ltx_triton_attention.py` backend |

## Quick Start

```bash
# Fixed raw threshold (no calibration, fast)
python wan22_skip_softmax.py \
--model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
--raw-threshold -0.7 \
--prompt "A cat playing piano" --output out.mp4

# With calibration
python wan22_skip_softmax.py \
--model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
--calibrate --target-sparsity 0.5 \
--prompt "A cat playing piano" --output out.mp4

# Dense baseline (no sparsity, for comparison)
python wan22_skip_softmax.py \
--model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
--baseline \
--prompt "A cat playing piano" --output baseline.mp4

# Report runtime sparsity (per-layer tile skip ratios)
python wan22_skip_softmax.py \
--model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
--raw-threshold -0.7 --report-avg-sparsity \
--prompt "A cat playing piano" --output out.mp4
```

## Architecture
> **Reviewer comment (Contributor):** The Architecture section is not appropriate for a README. Threshold Modes and Known Issues could be kept.

### Inference Path (Triton kernel with tile skipping)

```text
SparseAttentionModule.forward()
└─ triton_skip_softmax._triton_inference_context()
├─ Priority: raw_threshold > scale_factor (calibrated) > static threshold
├─ _set_triton_backends(raw_threshold=X) or (scale_factor=X)
├─ attention_backend("modelopt_triton")
└─ _diffusers_triton_attention() → attention()
└─ _attn_fwd kernel: skip tiles where tile_row_max < row_max + threshold
```
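The skip rule in the last line of the trace can be illustrated with a toy NumPy loop. The scores, tile size, and threshold here are invented, and the real kernel compares running maxima inside FlashAttention rather than materialized scores:

```python
import numpy as np

# Toy illustration of the _attn_fwd skip rule: a KV tile is skipped when its
# max log2-space score falls more than |threshold| below the row max.
TILE = 4
scores = np.log2(np.array([8.0, 4.0, 0.25, 0.125,      # tile 0 (significant)
                           0.06, 0.03, 0.01, 0.005]))  # tile 1 (negligible)
row_max = scores.max()
skip_threshold_log2 = -0.7  # e.g. --raw-threshold -0.7

kept = []
for t in range(len(scores) // TILE):
    tile = scores[t * TILE:(t + 1) * TILE]
    if tile.max() < row_max + skip_threshold_log2:
        continue  # tile contributes negligibly to softmax; skip its FLOPs
    kept.append(t)
```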

### Calibration Path (Triton calibration kernel)

```text
mtsa.sparsify(transformer, config, forward_loop)
├─ apply_mode() → replace attention with SparseAttentionModule
└─ calibrate()
├─ DynamicThresholdCalibrator._set_thresholds()
│ └─ method._threshold_trials = [1e-6, ..., 9.9e-1]
├─ forward_loop(model)
│ └─ SparseAttentionModule.forward()
│ └─ triton_skip_softmax._triton_calibration_context()
│ ├─ set_triton_skip_softmax_config(calibration_mode=True)
│ ├─ attention_backend("modelopt_triton")
│ └─ _diffusers_triton_attention() → attention_calibrate()
│ └─ _attn_fwd_calibrate kernel:
│ - Full attention (no skipping) for correct output
│ - Vectorized multi-threshold sparsity measurement
│ - Per-program output buffers (no atomic contention)
│ - Python-side reduction: sum across programs
├─ Fit: scale_factor = a * exp(b * sparsity)
│ └─ fit_logspace=True: fits in log space (minimizes relative error)
└─ Apply a, b to all modules
└─ Inference: threshold = scale_factor / seq_k
```
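The per-program reduction at the end of the trace can be sketched as follows; the buffer shapes and counts are made up, and NumPy stands in for the actual host-side code:

```python
import numpy as np

# Each Triton program writes its own (n_thresholds,) counter buffer, so the
# kernel needs no atomics; the host sums across programs afterwards.
n_programs, n_thresholds = 8, 19  # 19 matches the threshold_trials list
rng = np.random.default_rng(0)
skipped = rng.integers(0, 50, size=(n_programs, n_thresholds))  # tiles skipped
total = rng.integers(50, 100, size=(n_programs, 1))             # tiles seen

# Python-side reduction: measured sparsity per threshold trial.
sparsity_per_threshold = skipped.sum(axis=0) / total.sum()
```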

## Core Files

### Triton Kernels (`modelopt/torch/kernels/`)

| File | Role |
|------|------|
| `triton_fa.py` | `_attn_fwd`: forward kernel with optional tile skipping + sparsity measurement. `_attn_fwd_calibrate`: calibration kernel with vectorized multi-threshold testing and per-program buffers (zero atomic contention). `attention()` and `attention_calibrate()` Python APIs. |

### Sparse Attention Methods (`modelopt/torch/sparsity/attention_sparsity/methods/`)

| File | Role |
|------|------|
| `triton_skip_softmax.py` | Primary method for diffusion models. Calibration context → Triton calibration kernel. Inference context → Triton forward kernel. Supports `scale_factor` (calibrated), `raw_threshold` (direct), and static `skip_softmax_threshold`. |
| `flash_skip_softmax.py` | PyTorch-based method for HF LLMs (not used by diffusers/LTX). |
| `registry.py` | Base class `SparseAttentionMethod` with `calibration_params`, `target_sparse_ratio`, `set_calibration_mode()`. |

### Kernel Backends (`modelopt/torch/sparsity/attention_sparsity/kernels/`)

| File | Role |
|------|------|
| `diffusers_triton_attention.py` | Registers `modelopt_triton` backend in diffusers. Handles calibration mode (→ `attention_calibrate`) and inference mode (→ `attention` with `scale_factor/seq_k` or `raw_threshold`). Runtime sparsity counter accumulation. |
| `ltx_triton_attention.py` | Patches `ltx_core.Attention` modules for Triton dispatch. Same calibration/inference modes. |
| `hf_triton_attention.py` | HuggingFace `attn_implementation="modelopt_triton"` backend for LLMs. |

### Calibration (`modelopt/torch/sparsity/attention_sparsity/calibration/`)

| File | Role |
|------|------|
| `calibrate.py` | Orchestrates calibration. Skips RULER dataset when user provides `forward_loop` (diffusion models). Applies fitted (a, b) to all modules. |
| `calibrator.py` | `DynamicThresholdCalibrator`: collects (scale_factor, sparsity) pairs via Triton calibration kernel, fits exponential model `scale_factor = a * exp(b * sparsity)`. Supports `fit_logspace=True` for log-space fitting (recommended for diffusion models). |

### Config & Conversion

| File | Role |
|------|------|
| `config.py` | `SparseAttentionAttributeConfig` with `skip_softmax_threshold`, `skip_softmax_raw_threshold`, calibration settings. `CalibrationConfig` with `fit_logspace` field. |
| `conversion.py` | `_register_diffusers_backends_if_needed()` auto-registers Triton backends on `sparsify()`. |
| `sparse_attention.py` | `SparseAttentionModule` wrapper — delegates to method's `get_sparse_context()`. |

## Threshold Modes

| Mode | How threshold reaches the kernel | Use case |
|------|----------------------------------|----------|
| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly as `skip_threshold_log2` — no conversion | Quick testing, sweeps |
| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`; the backend computes `threshold = scale_factor / seq_k`, and the kernel converts this to `log2(threshold) * sm_scale` | Production use with automatic seqlen adaptation |
| **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated |
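The calibrated row of the table can be traced end to end with illustrative numbers (`a`, `b`, `seq_k`, and `head_dim` below are all assumptions, not calibrated values):

```python
import math

# Illustrative values only; real a, b come from calibration.
a, b = 1e-4, 6.0
target_sparsity = 0.5
seq_k = 32760                     # e.g. a long video token sequence
head_dim = 128
sm_scale = 1.0 / math.sqrt(head_dim)

scale_factor = a * math.exp(b * target_sparsity)       # fitted model
threshold = scale_factor / seq_k                       # backend: seqlen adaptation
skip_threshold_log2 = math.log2(threshold) * sm_scale  # value the kernel consumes
```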

## Known Issues

- **14B dual transformer calibration**: The two transformers are calibrated sequentially: `transformer_2` is calibrated while `transformer` is already sparsified, introducing asymmetric calibration conditions.
- **Minimum achievable sparsity**: Even the strictest threshold may yield 30-40% sparsity on diffusion models (many tiles are inherently negligible). Targets below this floor cause extrapolation; an inference-time warning is emitted.