Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions docs/source/models/visual-generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ TensorRT-LLM **VisualGen** provides a unified inference stack for diffusion mode
| `black-forest-labs/FLUX.2-dev` | Text-to-Image |
| `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` | Text-to-Video |
| `Wan-AI/Wan2.1-T2V-14B-Diffusers` | Text-to-Video |
| `FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers` | Text-to-Video (VSA) |
| `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` | Image-to-Video |
| `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` | Image-to-Video |
| `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | Text-to-Video |
Expand All @@ -43,19 +44,22 @@ Models are auto-detected from the checkpoint directory. Diffusers-format models

### Feature Matrix

| Model | FP8 blockwise | NVFP4 | TeaCache | Cache-DiT | CFG Parallelism | Ulysses Parallelism | Parallel VAE | CUDA Graph | torch.compile | trtllm-serve | Attention2D | Ring Attention | Tensor Parallelism |
|---|---|---|---|---|---|---|---|---|---|---|--|--|--|
| **FLUX.1** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes |
| **FLUX.2** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes |
| **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| **LTX-2** | Yes | Yes | No | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No |
| **Qwen-Image** [^2] | Yes | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No |
| **Cosmos3** | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes |
| Model | FP8 blockwise | NVFP4 | TeaCache | Cache-DiT | CFG Parallelism | Ulysses Parallelism | Parallel VAE | CUDA Graph | torch.compile | trtllm-serve | Attention2D | Ring Attention | Tensor Parallelism | VSA |
|---|---|---|---|---|---|---|---|---|---|---|--|--|--|--|
| **FLUX.1** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | No |
| **FLUX.2** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | No |
| **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No |
| **Wan 2.1 VSA** [^3] | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes | Yes |
| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No |
| **LTX-2** | Yes | Yes | No | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | No |
| **Qwen-Image** [^2] | Yes | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | No |
| **Cosmos3** | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes | No |

[^1]: FLUX models use embedded guidance and do not have a separate negative prompt path, so CFG parallelism is not applicable.

[^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required.
[^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required.

[^3]: `FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers` — VSA-fine-tuned checkpoint with learned sparse-attention gates. Requires `CUTEDSL` on Blackwell sm_100+ (falls back to dense SDPA on older hardware). Ring and Attention2D not supported (no LSE output); Ulysses supported.

## Quick Start

Expand Down Expand Up @@ -222,6 +226,44 @@ args = VisualGenArgs(

**Wan 2.2 dual-transformer note:** Wan 2.2 uses two expert transformers (high-noise and low-noise stacks). All `CacheDiTConfig` parameters apply to both stacks, except `max_warmup_steps` and `max_cached_steps`: the low-noise stack always uses fixed internal caps (`max_warmup_steps=2`, `max_cached_steps=20`) regardless of user config.

### Video Sparse Attention (VSA)

VSA reduces the compute cost of self-attention in video diffusion models by selectively attending to only the most relevant spatial-temporal blocks. It uses a two-branch design: a lightweight coarse mean-pool branch computes block-level attention scores to identify the top-K most relevant token blocks, then a fine branch runs a block-sparse CuTe kernel over only those blocks. The two outputs are blended with learned gates.

**Requirements:**
- VSA-fine-tuned checkpoint: [`FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers`](https://huggingface.co/FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers). Standard Wan checkpoints do not have the learned VSA gates.
- Blackwell GPU (sm_100+) for the CuTe JIT kernel. Falls back to dense SDPA on older hardware with no accuracy loss.
- `CUTEDSL` attention backend.
- Not compatible with Ring attention or Attention2D (VSA does not produce per-split LSE). Ulysses is supported.

**`vsa_sparsity`** controls the fraction of K/V blocks skipped in the fine branch (0.0 = dense, 0.9 = 90% blocks skipped). Higher sparsity gives more speedup at the cost of some quality.

Python API:

```python
from tensorrt_llm import VisualGenArgs
from tensorrt_llm.visual_gen.args import AttentionConfig, VideoSparseAttentionConfig

args = VisualGenArgs(
model="FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers",
attention_config=AttentionConfig(
backend="CUTEDSL",
sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=0.9),
),
)
```

YAML (for use with `--visual_gen_args` or `trtllm-serve`):

```yaml
attention_config:
backend: CUTEDSL
sparse_attention_config:
algorithm: vsa
vsa_sparsity: 0.90
```


### Multi-GPU Parallelism

Configured under `VisualGenArgs.parallel_config`. Modes can be combined:
Expand Down
16 changes: 15 additions & 1 deletion tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@
simplified metadata that doesn't require KV caching.
"""

from .cute_dsl import CuTeDSLAttention
from .cute_dsl import (
VSA_TILE_SIZE,
CuTeDSLAttention,
VSAAttention,
VSAMetadata,
VSAMetadataBuilder,
get_vsa_forward_context,
set_vsa_forward_context,
)
from .flash_attn4 import FlashAttn4Attention
from .interface import AttentionBackend, AttentionTensorLayout
from .parallel import Attention2DAttention, RingAttention, UlyssesAttention, wrap_parallel_attention
Expand All @@ -35,11 +43,17 @@
"get_visual_gen_attention_backend",
"create_attention",
"CuTeDSLAttention",
"VSAAttention",
"FlashAttn4Attention",
"TrtllmAttention",
"TrtllmAttentionMetadata",
"UlyssesAttention",
"VanillaAttention",
"RingAttention",
"wrap_parallel_attention",
"VSAMetadata",
"VSAMetadataBuilder",
"VSA_TILE_SIZE",
"get_vsa_forward_context",
"set_vsa_forward_context",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CuTe DSL attention backend family for visual generation models.

fmha.py — CuTeDSLAttention (dense cubin path, head_dim=128)
vsa.py — VSAAttention (Video Sparse Attention, CuTe JIT + SDPA fallback)
"""

from .fmha import CuTeDSLAttention, _cute_dsl_import_error
from .vsa import (
VSA_KERNEL_MAX_CUBES,
VSA_TILE_SIZE,
VSAAttention,
VSAMetadata,
VSAMetadataBuilder,
VSAPreprocessor,
get_vsa_forward_context,
set_vsa_forward_context,
)

__all__ = [
"CuTeDSLAttention",
"VSAAttention",
"VSAMetadata",
"VSAMetadataBuilder",
"VSAPreprocessor",
"VSA_TILE_SIZE",
"VSA_KERNEL_MAX_CUBES",
"set_vsa_forward_context",
"get_vsa_forward_context",
"_cute_dsl_import_error",
]
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CuTe DSL (NVIDIA kernels) Backend for Visual Generation Models
CuTe DSL (NVIDIA kernels) Dense FMHA Backend for Visual Generation Models

Uses pre-compiled cubins derived from CUTLASS CuTe DSL FMHA.
Expects NHD layout ([B, S, H, D]) and supports float16/bfloat16.
For the VSA sparse path use VSAAttention in vsa.py.
"""

import math
Expand All @@ -26,8 +27,8 @@

from tensorrt_llm.visual_gen.args import QuantAttentionConfig

from ...attention_backend.interface import PredefinedAttentionMask
from .interface import AttentionBackend, AttentionTensorLayout
from ....attention_backend.interface import PredefinedAttentionMask
from ..interface import AttentionBackend, AttentionTensorLayout

_cute_dsl_import_error = None
try:
Expand Down
Loading
Loading