Commit feec81a: Add skip-softmax for diffusion (#1166)
### What does this PR do?
**Type of change:** new feature, new example
## Summary
- Add skip-softmax sparse attention (BLASST) for diffusion models via
dedicated Triton kernels — an inference kernel with tile skipping and a
calibration kernel with vectorized multi-threshold sparsity measurement
- Add `triton_skip_softmax` method with exponential model calibration
(`scale_factor = a * exp(b * sparsity)`) and log-space fitting for
diffusion models
- Add Triton kernel backends for diffusers and LTX attention dispatch
- Fix calibration to skip RULER dataset generation when the user provides
their own `forward_loop` (required for non-LLM models)
## Changes
### Triton kernels (`modelopt/torch/kernels/triton_fa.py`)
- **`_attn_fwd`**: Forward kernel with optional tile skipping — tiles
whose max attention score is far below the running softmax max are
skipped entirely (no V load, no softmax, no accumulation). Runtime
sparsity measurement via atomic counters.
- **`_attn_fwd_calibrate`**: Calibration kernel that computes full
attention while measuring how many tiles would be skipped at each of N
thresholds simultaneously. Uses per-program output buffers (zero atomic
contention) and vectorized multi-threshold comparison.
- **`attention()`** / **`attention_calibrate()`**: Python wrappers for
inference and calibration kernels.
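The tile-skipping idea in `_attn_fwd` can be illustrated with an eager-mode sketch. This is a deliberate simplification, not the shipped Triton kernel: it works in natural-log space rather than log2, skips a K/V tile for all queries at once, and returns the fraction of tiles skipped instead of using atomic counters.

```python
import torch

def skip_softmax_attention(q, k, v, skip_threshold=-10.0, tile=64):
    """Illustrative online-softmax attention with tile skipping.

    A K/V tile whose max attention score is far below the running softmax
    max cannot contribute meaningfully, so the V load, exp, and
    accumulation for that tile are skipped entirely.
    """
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale       # [S_q, S_k]
    out = torch.zeros_like(q)
    m = torch.full(q.shape[:-1], float("-inf"))      # running row max
    denom = torch.zeros(q.shape[:-1])                # running denominator
    skipped, total = 0, 0
    for start in range(0, k.shape[-2], tile):
        s_tile = scores[..., start:start + tile]
        tile_max = s_tile.max(dim=-1).values
        total += 1
        # Skip: tile max is far below the running max for every query.
        if (tile_max - m).max() < skip_threshold:
            skipped += 1
            continue
        m_new = torch.maximum(m, tile_max)
        p = torch.exp(s_tile - m_new.unsqueeze(-1))
        corr = torch.exp(m - m_new)                  # rescale old state
        denom = denom * corr + p.sum(dim=-1)
        out = out * corr.unsqueeze(-1) + p @ v[..., start:start + tile, :]
        m = m_new
    return out / denom.unsqueeze(-1), skipped / total
```

With `skip_threshold=float("-inf")` no tile is ever skipped and the result matches dense softmax attention, which is a useful sanity check.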
### Kernel backends
(`modelopt/torch/sparsity/attention_sparsity/kernels/`)
- **`diffusers_triton_attention.py`**: Registers `modelopt_triton`
backend in diffusers' attention dispatch. Handles [B, S, H, D] → varlen
layout conversion, calibration/inference mode switching, thread-local
configuration, and counter accumulation.
- **`ltx_triton_attention.py`**: Patches `ltx_core.Attention` modules
for Triton dispatch with the same calibration/inference modes.
### Method
(`modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`)
- `TritonSkipSoftmaxMethod`: Context managers for calibration (→
calibration kernel) and inference (→ forward kernel with tile skipping).
Three threshold priority levels: raw threshold > calibrated scale_factor
> static threshold.
### Calibration
(`modelopt/torch/sparsity/attention_sparsity/calibration/`)
- **`calibrator.py`**: `DynamicThresholdCalibrator` with `fit_logspace`
option — fits exponential model in log space (minimizes relative error)
for diffusion models where scale_factors span many orders of magnitude.
Records observed sparsity range for extrapolation warnings.
- **`calibrate.py`**: Skips RULER dataset when `forward_loop` is
provided; passes `fit_logspace` through from config.
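The log-space exponential fit described above reduces to linear least squares on the logs. The sketch below is illustrative (function name and the sample data points are hypothetical, not the `DynamicThresholdCalibrator` API):

```python
import numpy as np

def fit_exponential_logspace(sparsities, scale_factors):
    """Fit scale_factor = a * exp(b * sparsity) in log space.

    Taking logs gives log(scale_factor) = log(a) + b * sparsity, a linear
    model. Least squares on the logs minimizes *relative* error, which is
    what matters when scale_factors span many orders of magnitude.
    """
    x = np.asarray(sparsities, dtype=float)
    y = np.log(np.asarray(scale_factors, dtype=float))
    b, log_a = np.polyfit(x, y, 1)            # slope, intercept
    return float(np.exp(log_a)), float(b)     # a, b

# Illustrative (made-up) calibration points; at runtime the model predicts
# a scale_factor for the requested target sparsity.
a, b = fit_exponential_logspace([0.2, 0.4, 0.6, 0.8],
                                [3.0e-4, 2.4e-3, 1.9e-2, 1.5e-1])
scale_factor = a * np.exp(b * 0.5)
```

A plain least-squares fit of `a * exp(b * x)` in linear space would be dominated by the largest scale_factors; the log-space variant weights each decade equally, which is why it is the recommended mode for diffusion models.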
### Config & conversion
- **`config.py`**: `CalibrationConfig.fit_logspace` field (default
False, recommended True for diffusion models).
`skip_softmax_raw_threshold` field for direct threshold mode.
- **`conversion.py`**: Auto-registers diffusers/LTX Triton backends on
`sparsify()`. Updated summary display.
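For orientation, a hypothetical config fragment using the fields named above (the exact `CalibrationConfig` schema and key names may differ from the shipped API):

```python
# Hypothetical sketch only; field names taken from the summary above.
config = {
    "calibration": {
        "target_sparsity": 0.5,
        "fit_logspace": True,  # default False; recommended for diffusion
    },
    # Direct threshold mode, bypassing calibration entirely:
    # "skip_softmax_raw_threshold": -0.7,
}
```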
### Example
- **`wan22_skip_softmax.py`**: End-to-end example for WAN 2.2 5B/14B
with baseline, raw-threshold, and calibrated modes. Supports runtime
sparsity reporting.
## Threshold modes
| Mode | How it works | Use case |
|------|--------------|----------|
| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly to the kernel as `skip_threshold_log2` | Quick testing, sweeps |
| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then `threshold = scale_factor / seq_k` at runtime | Production use with seqlen adaptation |
| **Static** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback |
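The priority chain over these modes (raw threshold > calibrated scale_factor > static fallback) can be sketched as follows. The function name and exact log2 handling are assumptions for illustration, not the shipped method:

```python
import math

def resolve_skip_threshold(seq_k, sm_scale, raw_threshold=None,
                           scale_factor=None, static_lambda=0.1):
    """Resolve the skip threshold per the three-mode table above."""
    if raw_threshold is not None:
        # Raw mode: value goes straight to the kernel as skip_threshold_log2.
        return raw_threshold
    if scale_factor is not None:
        # Calibrated mode: threshold adapts to the key sequence length.
        return scale_factor / seq_k
    # Static fallback: log2(lambda) * sm_scale.
    return math.log2(static_lambda) * sm_scale
```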
## Usage
```bash
# Fixed raw threshold (no calibration)
python examples/diffusers/sparsity/wan22_skip_softmax.py \
  --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
  --raw-threshold -0.7 \
  --prompt "A cat playing piano" --output out.mp4

# With calibration (log-space fit for diffusion models)
python examples/diffusers/sparsity/wan22_skip_softmax.py \
  --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
  --calibrate --target-sparsity 0.5 \
  --prompt "A cat playing piano" --output out.mp4

# Dense baseline for comparison
python examples/diffusers/sparsity/wan22_skip_softmax.py \
  --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
  --baseline \
  --prompt "A cat playing piano" --output baseline.mp4
```
### Before your PR is "*Ready for review*"
Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and that your commits are signed (`git commit -s -S`).
Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoid hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow the guidance in `CONTRIBUTING.md`?: ✅
- Did you write any new necessary tests?: ✅
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ❌
## Summary by CodeRabbit
* **New Features**
  * Added skip-softmax sparse attention support for Diffusers models, enabling efficient video generation
  * Added support for both eager and Triton attention backends for sparse attention
  * Added a new example script for Wan 2.2 text-to-video generation with sparse attention optimization
* **Documentation**
  * Updated documentation with a sparse attention configuration guide and usage examples
* **Tests**
  * Added comprehensive unit tests for kernel backend registration and skip-softmax functionality
---------
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>

Parent commit: 76b6fd5
35 files changed: 4,180 additions and 76 deletions
File tree:
- .github/workflows
- examples/diffusers/sparsity
- modelopt/torch/export
- modelopt/torch/kernels
- modelopt/torch/quantization/src/conv
- modelopt/torch/sparsity/attention_sparsity/{calibration, kernels, methods, plugins}
- tests/_test_utils/torch
- tests/examples/diffusers_sparsity
- tests/gpu/torch/sparsity/attention_sparsity
- tests/unit/torch/kernels
- tests/unit/torch/sparsity/attention_sparsity