Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions docs/source/models/visual-generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,18 @@ Models are auto-detected from the checkpoint directory. Diffusers-format models
| **FLUX.1** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes |
| **FLUX.2** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes |
| **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| **LTX-2** | Yes | Yes | No | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No |
| **Qwen-Image** [^2] | Yes | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No |
| **Wan 2.2** | Yes | Yes | Yes [^2] | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| **LTX-2** | Yes | Yes | Yes [^3] | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No |
| **Qwen-Image** [^4] | Yes | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No |
| **Cosmos3** | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes |

[^1]: FLUX models use embedded guidance and do not have a separate negative prompt path, so CFG parallelism is not applicable.

[^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required.
[^2]: Wan 2.2 has two stage transformers; TeaCache requires explicit `teacache.coefficients` (high-noise) and `teacache.coefficients_2` (low-noise). There is no built-in coefficient table for Wan 2.2.

[^3]: LTX-2 has no built-in TeaCache coefficient table in TRT-LLM; set `teacache.coefficients` explicitly when enabling TeaCache.

[^4]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required.

## Quick Start

Expand Down
67 changes: 35 additions & 32 deletions tensorrt_llm/_torch/visual_gen/cache/teacache_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Any, Dict, Optional
from typing import Any, Dict, List, Tuple

import torch.nn as nn

Expand All @@ -17,52 +17,55 @@


class TeaCacheAccelerator(CacheAccelerator):
"""Whole-transformer skip using existing :class:`TeaCacheBackend`."""
"""Whole-transformer step skipping via :class:`TeaCacheBackend`."""

def __init__(self, teacache_cfg: TeaCacheConfig):
def __init__(self, pipeline: Any, teacache_cfg: TeaCacheConfig):
self._pipeline = pipeline
self._cfg = teacache_cfg
self._backend: Optional[TeaCacheBackend] = None
self._module: Optional[nn.Module] = None
self._backends: List[Tuple[nn.Module, TeaCacheBackend]] = []

def wrap(self, *args: Any, **kwargs: Any) -> None:
"""Enable TeaCache on ``model``.

Coefficient matching is done on the pipeline via
:meth:`BasePipeline._apply_teacache_coefficients` before this runs.

Keyword args: ``model`` (required).
"""
model: Optional[nn.Module] = kwargs.get("model")
if model is None:
if self._backends:
return
if self._backend is not None:
transformer = getattr(self._pipeline, "transformer", None)
if transformer is None:
return

logger.info("TeaCache: Initializing...")
self._backend = TeaCacheBackend(self._cfg)
self._backend.enable(model)
self._module = model
transformer_2 = getattr(self._pipeline, "transformer_2", None)
if transformer_2 is not None and self._cfg.coefficients_2 is not None:
cfg_low = self._cfg.model_copy(update={"coefficients": self._cfg.coefficients_2})
logger.info("TeaCache: Initializing (high-noise transformer)...")
backend_high = TeaCacheBackend(self._cfg)
backend_high.enable(transformer)
logger.info("TeaCache: Initializing (low-noise transformer_2)...")
backend_low = TeaCacheBackend(cfg_low)
backend_low.enable(transformer_2)
self._backends = [(transformer, backend_high), (transformer_2, backend_low)]
else:
logger.info("TeaCache: Initializing...")
backend = TeaCacheBackend(self._cfg)
backend.enable(transformer)
self._backends = [(transformer, backend)]

def unwrap(self) -> None:
if self._backend is None or self._module is None:
return
self._backend.disable(self._module)
self._backend = None
self._module = None
for module, backend in self._backends:
backend.disable(module)
self._backends = []

def refresh(self, num_inference_steps: int) -> None:
if self._backend:
self._backend.refresh(num_inference_steps)
for _, backend in self._backends:
backend.refresh(num_inference_steps)

def get_stats(self) -> Dict[str, Any]:
if not self._backend:
if not self._backends:
return {}
return self._backend.get_stats() or {}
if len(self._backends) == 1:
return self._backends[0][1].get_stats() or {}
return {f"transformer_{i}": b.get_stats() for i, (_, b) in enumerate(self._backends)}

def is_enabled(self) -> bool:
return bool(self._backend and self._backend.is_enabled())
return bool(self._backends and any(b.is_enabled() for _, b in self._backends))

@property
def tea_cache_backend(self) -> Optional[TeaCacheBackend]:
"""The wrapped TeaCache hook owner (for introspection or advanced use)."""
return self._backend
def backends(self) -> List[Tuple[nn.Module, TeaCacheBackend]]:
return self._backends
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ def post_load_weights(self) -> None:
)

# TeaCache or Cache-DiT
self._setup_cache_acceleration(self.transformer, FLUX_TEACACHE_COEFFICIENTS)
self._apply_teacache_coefficients(FLUX_TEACACHE_COEFFICIENTS)
self._setup_cache_acceleration()

@property
def default_generation_params(self):
Expand Down
47 changes: 24 additions & 23 deletions tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,33 +302,34 @@ def load_weights(self, weights: dict) -> None:
self.transformer.eval()

def post_load_weights(self) -> None:
"""Post-load setup: TeaCache registration."""
"""Post-load setup: cache acceleration (TeaCache or Cache-DiT)."""
super().post_load_weights()
if self.transformer is not None:
# Register TeaCache extractor for FLUX.2 (must be after device placement)
# Only set guidance_param_name for variants with guidance_embeds
guidance_param = "guidance" if self.transformer.guidance_embeds else None
forward_params = [
"hidden_states",
"encoder_hidden_states",
"timestep",
"img_ids",
"txt_ids",
"guidance",
"return_dict",
]
register_extractor_from_config(
ExtractorConfig(
model_class_name="Flux2Transformer2DModel",
timestep_embed_fn=self._compute_flux2_timestep_embedding,
guidance_param_name=guidance_param,
forward_params=forward_params,
return_dict_default=False,
if self.pipeline_config.cache_backend == "teacache":
# Register TeaCache extractor for FLUX.2 (must be after device placement)
# Only set guidance_param_name for variants with guidance_embeds
guidance_param = "guidance" if self.transformer.guidance_embeds else None
forward_params = [
"hidden_states",
"encoder_hidden_states",
"timestep",
"img_ids",
"txt_ids",
"guidance",
"return_dict",
]
register_extractor_from_config(
ExtractorConfig(
model_class_name="Flux2Transformer2DModel",
timestep_embed_fn=self._compute_flux2_timestep_embedding,
guidance_param_name=guidance_param,
forward_params=forward_params,
return_dict_default=False,
)
)
)

# TeaCache or Cache-DiT
self._setup_cache_acceleration(self.transformer, FLUX2_TEACACHE_COEFFICIENTS)
self._apply_teacache_coefficients(FLUX2_TEACACHE_COEFFICIENTS)
self._setup_cache_acceleration()

@property
def default_generation_params(self):
Expand Down
23 changes: 14 additions & 9 deletions tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast

from tensorrt_llm._torch.utils import make_weak_ref
from tensorrt_llm._torch.visual_gen.cache.teacache import CacheContext
from tensorrt_llm._torch.visual_gen.cache.teacache import CacheContext, register_extractor
from tensorrt_llm._torch.visual_gen.checkpoints.prefetch import prefetch_files_to_host_cache
from tensorrt_llm._torch.visual_gen.cuda_graph_runner import CUDAGraphRunner, CUDAGraphRunnerConfig
from tensorrt_llm._torch.visual_gen.output import CudaPhaseTimer, PipelineOutput
Expand Down Expand Up @@ -1015,17 +1015,22 @@ def post_load_weights(self) -> None:
"""Finalize after weight loading: TeaCache, Cache-DiT, derived attributes."""
super().post_load_weights()

# TODO: TeaCache disabled: LTX2_TEACACHE_COEFFICIENTS are unverified.
# To re-enable, uncomment the following lines and verify coefficients.
# register_extractor(
# "LTXModel",
# LTX2TeaCacheExtractor(self._compute_ltx2_timestep_embedding),
# )
# self._setup_teacache(self.transformer, coefficients=LTX2_TEACACHE_COEFFICIENTS)
# LTX-2: single transformer (one DiT for video+audio); TeaCache only with explicit coefficients.
if self.transformer is not None and self.pipeline_config.cache_backend == "teacache":
if self.pipeline_config.teacache.coefficients is None:
raise ValueError(
"TeaCache on LTX-2 requires explicit teacache.coefficients "
"(no built-in coefficient table)."
)
register_extractor(
"LTXModel",
LTX2TeaCacheExtractor(self._compute_ltx2_timestep_embedding),
)
self._setup_cache_acceleration()

# Cache-DiT
if self.transformer is not None and self.pipeline_config.cache_backend == "cache_dit":
self._setup_cache_acceleration(self.transformer, coefficients=None)
self._setup_cache_acceleration()

# Compression ratios from native scale factors
self.vae_spatial_compression_ratio = VIDEO_SCALE_FACTORS.width
Expand Down
32 changes: 18 additions & 14 deletions tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,6 @@ def __init__(self, pipeline_config):
self.is_wan22_14b = self.boundary_ratio is not None
self.is_wan22_5b = self.expand_timesteps

# Validate TeaCache compatibility before allocating GPU memory
if (self.is_wan22_14b or self.is_wan22_5b) and pipeline_config.cache_backend == "teacache":
raise ValueError(
"TeaCache is not supported for Wan 2.2 models. "
"Use cache_backend='none' or 'cache_dit' (not 'teacache')."
)

# Fixed latent for reproducible benchmarking (e.g. MLPerf).
# Set TRTLLM_VIDEO_FIXED_LATENT_PATH to a .pt file containing a pre-sampled
# noise tensor; it will be used in place of freshly sampled random latents for
Expand Down Expand Up @@ -331,20 +324,31 @@ def post_load_weights(self) -> None:
)

if not self.is_wan22_14b:
self._setup_cache_acceleration(
self.transformer, coefficients=WAN_TEACACHE_COEFFICIENTS
)
self.transformer_cache_backend = self.cache_accelerator
self._apply_teacache_coefficients(WAN_TEACACHE_COEFFICIENTS)
self._setup_cache_acceleration()
else:
if self.pipeline_config.cache_backend == "cache_dit":
self._setup_cache_acceleration(self.transformer, coefficients=None)
# TeaCache is not supported for Wan 2.2 unless using Cache-DiT.
self.transformer_cache_backend = self.cache_accelerator
self._setup_cache_acceleration()

if self.transformer_2 is not None:
if hasattr(self.transformer_2, "post_load_weights"):
self.transformer_2.post_load_weights()

# Wan 2.2 TeaCache after both transformers' post_load_weights (FP8 scales, etc.)
if (
self.transformer is not None
and self.transformer_2 is not None
and self.pipeline_config.cache_backend == "teacache"
):
tc = self.pipeline_config.teacache
if tc.coefficients is None or tc.coefficients_2 is None:
raise ValueError(
"Wan 2.2 TeaCache requires explicit teacache.coefficients and "
"teacache.coefficients_2 (high-noise and low-noise stage polynomials). "
"There is no built-in coefficient table for Wan 2.2."
)
self._setup_cache_acceleration()

def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None:
with torch.no_grad():
self.forward(
Expand Down
30 changes: 17 additions & 13 deletions tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,6 @@ def __init__(self, pipeline_config):
)
self.is_wan22_14b = self.boundary_ratio is not None

# Validate TeaCache compatibility before allocating GPU memory
if self.is_wan22_14b and pipeline_config.cache_backend == "teacache":
raise ValueError(
"TeaCache is not supported for Wan 2.2 models. "
"Use cache_backend='none' or 'cache_dit' (not 'teacache')."
)

super().__init__(pipeline_config)

def _compute_wan_timestep_embedding(self, module, timestep=None, **kwargs):
Expand Down Expand Up @@ -338,19 +331,30 @@ def post_load_weights(self) -> None:
)

if not self.is_wan22_14b:
self._setup_cache_acceleration(
self.transformer, coefficients=WAN_I2V_TEACACHE_COEFFICIENTS
)
self.transformer_cache_backend = self.cache_accelerator
self._apply_teacache_coefficients(WAN_I2V_TEACACHE_COEFFICIENTS)
self._setup_cache_acceleration()
else:
if self.pipeline_config.cache_backend == "cache_dit":
self._setup_cache_acceleration(self.transformer, coefficients=None)
self.transformer_cache_backend = self.cache_accelerator
self._setup_cache_acceleration()

if self.transformer_2 is not None:
if hasattr(self.transformer_2, "post_load_weights"):
self.transformer_2.post_load_weights()

if (
self.transformer is not None
and self.transformer_2 is not None
and self.pipeline_config.cache_backend == "teacache"
):
tc = self.pipeline_config.teacache
if tc.coefficients is None or tc.coefficients_2 is None:
raise ValueError(
"Wan 2.2 TeaCache requires explicit teacache.coefficients and "
"teacache.coefficients_2 (high-noise and low-noise stage polynomials). "
"There is no built-in coefficient table for Wan 2.2."
)
self._setup_cache_acceleration()

def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None:
dummy_image = PIL.Image.new("RGB", (width, height))
with torch.no_grad():
Expand Down
Loading
Loading