# Add the Skip softmax for diffusion #1166

**Open**: `jingyu-ml` wants to merge 29 commits into `main` from `jingyux/diffusion-skip-softmax`.
**Commits** (all by `jingyu-ml`):

- `1d8ac33` Add the Skip softmax diffusion
- `1f8f0d3` Add test case
- `5873652` Fixed error
- `4c179a3` Fixed the test case
- `2c323df` Merge branch 'main' into jingyux/diffusion-skip-softmax
- `8702b7b` Removed the token import
- `bbe2123` Merge branch 'main' into jingyux/diffusion-skip-softmax
- `70099a5` removed the unused code
- `6cc96a4` Update the README
- `4de0d3b` Updated the example script
- `b3d3d4d` Update the readme
- `8dc6162` Update the calibration kernel
- `8aa32cc` ADd the readme
- `fbeabcf` Update the example script
- `6a4ab8b` Update the code
- `d7dd15c` Update the calibration loop
- `b86d311` Remove the eager attention
- `f5a9af9` Merge branch 'main' into jingyux/diffusion-skip-softmax
- `45bcad6` Update the calibration, fixed some bugs
- `22c5b85` Add the test case
- `aa44a9d` Fixed the lint error
- `e5293de` Updated the README
- `40fdd44` Merge branch 'main' into jingyux/diffusion-skip-softmax
- `40d61dd` Update the test case
- `3845b47` Fixed the CICD
- `560015c` Merge branch 'main' into jingyux/diffusion-skip-softmax
- `f86580c` Added the ltx2 warning
- `ee162b3` addressed the ltx2 issue and the import issue
- `eef0577` Merge branch 'main' into jingyux/diffusion-skip-softmax
# Skip-Softmax Sparse Attention for Diffusion Models

> [!WARNING]
> **Third-Party License Notice — LTX-2**
>
> LTX-2 packages (`ltx-core`, `ltx-pipelines`, `ltx-trainer`) are third-party dependencies
> developed and provided by [Lightricks](https://github.com/Lightricks/LTX-2). They are
> **NOT** covered by the Apache 2.0 license governing NVIDIA Model Optimizer.
>
> You **MUST** comply with the
> [LTX Community License Agreement](https://github.com/Lightricks/LTX-2/blob/main/LICENSE)
> when installing and using LTX-2 with NVIDIA Model Optimizer. Any derivative models or
> fine-tuned weights produced from LTX-2 (including quantized, distilled, or sparsified
> checkpoints) remain subject to the LTX Community License Agreement, not Apache 2.0.

Skip-softmax sparse attention (BLASST, <https://arxiv.org/pdf/2512.12087>) skips KV
tiles whose attention scores are negligible during the FlashAttention computation,
reducing FLOPs without retraining.
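To make the skip rule concrete, here is a toy dense NumPy reference (not the Triton kernel; the function name, tiling, and log2-space comparison are illustrative): per query row, a KV tile's contribution is dropped when the tile's row-wise max score falls more than `|threshold|` (in log2 units) below that row's global max, so its softmax mass is negligible.

```python
import numpy as np

def skip_softmax_attention(q, k, v, skip_threshold_log2=-0.7, tile=32):
    # Toy dense reference of the skip-softmax rule (illustrative, not the kernel).
    scores = (q @ k.T) / np.sqrt(q.shape[-1])      # (seq_q, seq_k)
    row_max = scores.max(axis=-1, keepdims=True)   # per-row global max
    weights = np.exp(scores - row_max)
    thr_nat = skip_threshold_log2 * np.log(2.0)    # log2 units -> natural log
    for s in range(0, k.shape[0], tile):
        tile_row_max = scores[:, s:s + tile].max(axis=-1, keepdims=True)
        keep = tile_row_max >= row_max + thr_nat   # skip criterion per (row, tile)
        weights[:, s:s + tile] *= keep
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v
```

With a very negative threshold nothing is skipped and the result equals dense softmax attention; the real kernel applies the same comparison tile-by-tile inside the FlashAttention loop instead of materializing the full score matrix.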
Two modes are supported:

- **Fixed raw threshold** — pass a log2-space threshold directly to the Triton
  kernel. No calibration needed. Good for quick testing and sweeps.
- **Calibrated threshold** — an exponential model
  (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once via the
  Triton calibration kernel; the target sparsity can then be adjusted at runtime
  without recalibration. Log-space fitting (`fit_logspace=True`) is recommended
  for diffusion models, where scale factors span many orders of magnitude.
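The point of the calibrated mode is that re-targeting sparsity at runtime is just re-evaluating the fitted model. A minimal sketch (the `(a, b)` constants here are made up for illustration, not real calibration output):

```python
import math

# Hypothetical fitted constants from a prior calibration run.
a, b = 2.5e-4, 6.0

def kernel_threshold(target_sparsity: float, seq_k: int) -> float:
    # Evaluate the calibrated exponential model, then divide by the KV
    # sequence length so the threshold adapts to seqlen (as described above).
    scale_factor = a * math.exp(b * target_sparsity)
    return scale_factor / seq_k

# Changing the target requires no recalibration, only re-evaluation:
thresholds = [kernel_threshold(t, seq_k=4096) for t in (0.3, 0.5, 0.7)]
```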
## Supported Models

| Model | Script | Notes |
|-------|--------|-------|
| WAN 2.2 5B | `wan22_skip_softmax.py` | Single transformer, self-attention only |
| WAN 2.2 14B | `wan22_skip_softmax.py` | Dual transformer (auto-detected) |
| LTX-2 | (coming soon) | Via `ltx_triton_attention.py` backend |
## Quick Start

```bash
# Fixed raw threshold (no calibration, fast)
python wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --raw-threshold -0.7 \
    --prompt "A cat playing piano" --output out.mp4

# With calibration
python wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --calibrate --target-sparsity 0.5 \
    --prompt "A cat playing piano" --output out.mp4

# Dense baseline (no sparsity, for comparison)
python wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --baseline \
    --prompt "A cat playing piano" --output baseline.mp4

# Report runtime sparsity (per-layer tile skip ratios)
python wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --raw-threshold -0.7 --report-avg-sparsity \
    --prompt "A cat playing piano" --output out.mp4
```
## Architecture

### Inference Path (Triton kernel with tile skipping)

```text
SparseAttentionModule.forward()
└─ triton_skip_softmax._triton_inference_context()
   ├─ Priority: raw_threshold > scale_factor (calibrated) > static threshold
   ├─ _set_triton_backends(raw_threshold=X) or (scale_factor=X)
   ├─ attention_backend("modelopt_triton")
   └─ _diffusers_triton_attention() → attention()
      └─ _attn_fwd kernel: skip tiles where tile_row_max < row_max + threshold
```
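The priority line in the diagram reduces to a simple resolution order, sketched below (function and mode names are illustrative, not the library API):

```python
def resolve_threshold(raw_threshold=None, scale_factor=None,
                      static_threshold=0.1):
    # Mirrors the priority above: an explicit raw threshold wins, then a
    # calibrated scale factor, then the static default as a fallback.
    if raw_threshold is not None:
        return ("raw_threshold", raw_threshold)
    if scale_factor is not None:
        return ("scale_factor", scale_factor)
    return ("static", static_threshold)
```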
### Calibration Path (Triton calibration kernel)

```text
mtsa.sparsify(transformer, config, forward_loop)
├─ apply_mode() → replace attention with SparseAttentionModule
└─ calibrate()
   ├─ DynamicThresholdCalibrator._set_thresholds()
   │  └─ method._threshold_trials = [1e-6, ..., 9.9e-1]
   ├─ forward_loop(model)
   │  └─ SparseAttentionModule.forward()
   │     └─ triton_skip_softmax._triton_calibration_context()
   │        ├─ set_triton_skip_softmax_config(calibration_mode=True)
   │        ├─ attention_backend("modelopt_triton")
   │        └─ _diffusers_triton_attention() → attention_calibrate()
   │           └─ _attn_fwd_calibrate kernel:
   │              - Full attention (no skipping) for correct output
   │              - Vectorized multi-threshold sparsity measurement
   │              - Per-program output buffers (no atomic contention)
   │              - Python-side reduction: sum across programs
   ├─ Fit: scale_factor = a * exp(b * sparsity)
   │  └─ fit_logspace=True: fits in log space (minimizes relative error)
   └─ Apply a, b to all modules
      └─ Inference: threshold = scale_factor / seq_k
```
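The log-space fit step can be sketched in a few lines (helper name is an assumption; the real implementation lives in the calibrator): in log space, `scale_factor = a * exp(b * sparsity)` becomes the linear relation `log(sf) = log(a) + b * s`, so an ordinary degree-1 polynomial fit minimizes *relative* rather than absolute error.

```python
import numpy as np

def fit_scale_factor_model(sparsities, scale_factors):
    # Fit scale_factor = a * exp(b * sparsity) by linear regression in log
    # space. Minimizing error on log(sf) weights all magnitudes equally,
    # which matters when scale factors span many orders of magnitude.
    s = np.asarray(sparsities, dtype=float)
    log_sf = np.log(np.asarray(scale_factors, dtype=float))
    b, log_a = np.polyfit(s, log_sf, 1)   # slope = b, intercept = log(a)
    return float(np.exp(log_a)), float(b)
```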
## Core Files

### Triton Kernels (`modelopt/torch/kernels/`)

| File | Role |
|------|------|
| `triton_fa.py` | `_attn_fwd`: forward kernel with optional tile skipping and sparsity measurement. `_attn_fwd_calibrate`: calibration kernel with vectorized multi-threshold testing and per-program buffers (zero atomic contention). `attention()` and `attention_calibrate()` Python APIs. |
### Sparse Attention Methods (`modelopt/torch/sparsity/attention_sparsity/methods/`)

| File | Role |
|------|------|
| `triton_skip_softmax.py` | Primary method for diffusion models. Calibration context → Triton calibration kernel. Inference context → Triton forward kernel. Supports `scale_factor` (calibrated), `raw_threshold` (direct), and static `skip_softmax_threshold`. |
| `flash_skip_softmax.py` | PyTorch-based method for HF LLMs (not used by diffusers/LTX). |
| `registry.py` | Base class `SparseAttentionMethod` with `calibration_params`, `target_sparse_ratio`, `set_calibration_mode()`. |
### Kernel Backends (`modelopt/torch/sparsity/attention_sparsity/kernels/`)

| File | Role |
|------|------|
| `diffusers_triton_attention.py` | Registers the `modelopt_triton` backend in diffusers. Handles calibration mode (→ `attention_calibrate`) and inference mode (→ `attention` with `scale_factor/seq_k` or `raw_threshold`). Accumulates runtime sparsity counters. |
| `ltx_triton_attention.py` | Patches `ltx_core.Attention` modules for Triton dispatch. Same calibration/inference modes. |
| `hf_triton_attention.py` | HuggingFace `attn_implementation="modelopt_triton"` backend for LLMs. |
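The runtime sparsity accumulation mentioned for the diffusers backend amounts to simple per-layer bookkeeping; a sketch of the idea (class and method names are assumptions, not the real API):

```python
from collections import defaultdict

class SparsityCounter:
    # Illustrative bookkeeping for --report-avg-sparsity: each attention call
    # reports (skipped_tiles, total_tiles) and per-layer skip ratios are
    # computed over the whole run.
    def __init__(self):
        self._skipped = defaultdict(int)
        self._total = defaultdict(int)

    def update(self, layer, skipped, total):
        self._skipped[layer] += skipped
        self._total[layer] += total

    def report(self):
        # Fraction of KV tiles skipped per layer.
        return {name: self._skipped[name] / self._total[name]
                for name in self._total}
```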
### Calibration (`modelopt/torch/sparsity/attention_sparsity/calibration/`)

| File | Role |
|------|------|
| `calibrate.py` | Orchestrates calibration. Skips the RULER dataset when the user provides a `forward_loop` (diffusion models). Applies the fitted (a, b) to all modules. |
| `calibrator.py` | `DynamicThresholdCalibrator`: collects (scale_factor, sparsity) pairs via the Triton calibration kernel and fits the exponential model `scale_factor = a * exp(b * sparsity)`. Supports `fit_logspace=True` for log-space fitting (recommended for diffusion models). |
### Config & Conversion

| File | Role |
|------|------|
| `config.py` | `SparseAttentionAttributeConfig` with `skip_softmax_threshold`, `skip_softmax_raw_threshold`, and calibration settings. `CalibrationConfig` with a `fit_logspace` field. |
| `conversion.py` | `_register_diffusers_backends_if_needed()` auto-registers Triton backends on `sparsify()`. |
| `sparse_attention.py` | `SparseAttentionModule` wrapper — delegates to the method's `get_sparse_context()`. |
## Threshold Modes

| Mode | How the threshold reaches the kernel | Use case |
|------|--------------------------------------|----------|
| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly as `skip_threshold_log2` — no conversion | Quick testing, sweeps |
| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`; the backend computes `threshold = scale_factor / seq_k`, then the kernel converts it via `log2(threshold) * sm_scale` | Production use with automatic seqlen adaptation |
| **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated is set |
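The three conversions in the table can be sketched as one dispatch (helper name and the exact placement of `sm_scale` are assumptions based on the table, not the library API):

```python
import math

def kernel_log2_threshold(sm_scale, raw_threshold=None, scale_factor=None,
                          seq_k=None, static_lambda=0.1):
    # Raw mode: passed through to the kernel unchanged.
    if raw_threshold is not None:
        return raw_threshold
    # Calibrated mode: divide by seqlen, then convert to the kernel's
    # log2-times-softmax-scale units.
    if scale_factor is not None:
        return math.log2(scale_factor / seq_k) * sm_scale
    # Static lambda fallback.
    return math.log2(static_lambda) * sm_scale
```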
## Known Issues

- **14B dual transformer calibration**: The transformers are calibrated sequentially — transformer_2's calibration runs while transformer_1 is already sparsified, introducing asymmetric calibration conditions.
- **Minimum achievable sparsity**: Even the strictest threshold may only yield 30–40% sparsity on diffusion models (many tiles are inherently negligible). Targets below this floor cause extrapolation; an inference-time warning is emitted.
**Review comment:** The **Architecture** section is not appropriate for the README. **Threshold Modes** and **Known Issues** could be kept.