
Commit 37625ce

teacache for previously disabled pipelines; update examples
Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>
1 parent 5554be2 commit 37625ce

13 files changed

Lines changed: 402 additions & 47 deletions


docs/source/models/visual-generation.md

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -41,11 +41,15 @@ Models are auto-detected from the checkpoint directory. Diffusers-format models
4141
| **FLUX.1** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes |
4242
| **FLUX.2** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes |
4343
| **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
44-
| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes |
45-
| **LTX-2** | Yes | Yes | No | Yes | Yes | No | No | Yes | Yes |
44+
| **Wan 2.2** | Yes | Yes | Yes [^2] | Yes | Yes | Yes | Yes | Yes | Yes |
45+
| **LTX-2** | Yes | Yes | Yes [^3] | Yes | Yes | No | No | Yes | Yes |
4646

4747
[^1]: FLUX models use embedded guidance and do not have a separate negative prompt path, so CFG parallelism is not applicable.
4848

49+
[^2]: Wan 2.2 uses two stage-specific transformers (high-noise and low-noise), so TeaCache requires explicit `teacache.coefficients` (high-noise) and `teacache.coefficients_2` (low-noise). There is no built-in coefficient table for Wan 2.2.
50+
51+
[^3]: LTX-2 has no built-in TeaCache coefficient table in TRT-LLM; set `teacache.coefficients` explicitly when enabling TeaCache.
52+
4953
## Quick Start
5054

5155
Here is a simple example to generate a video with Wan 2.1:
@@ -109,7 +113,7 @@ args = VisualGenArgs(
109113

110114
### TeaCache
111115

112-
TeaCache caches transformer outputs when timestep embeddings change slowly between denoising steps, skipping redundant computation. Enable with `teacache.enable_teacache: true` (YAML config). The `teacache_thresh` parameter controls the similarity threshold.
116+
TeaCache caches transformer outputs when timestep embeddings change slowly between denoising steps, skipping redundant computation. Enable with `teacache.enable_teacache: true` (YAML config). The `teacache_thresh` parameter controls the similarity threshold. For Wan 2.2, set both `coefficients` and `coefficients_2` (YAML or CLI). For LTX-2, set `coefficients` when enabling TeaCache (no built-in table). Other models (e.g. FLUX.1, FLUX.2, Wan 2.1) can omit `coefficients` to use the built-in checkpoint table.
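
The per-model rules above can be summarized in a small sketch. The dictionaries mirror the `teacache` YAML keys named in this section; the coefficient values are placeholders rather than calibrated numbers, and `missing_teacache_keys` is a hypothetical helper used only for illustration, not part of TensorRT-LLM.

```python
from typing import Dict, List

# Models that, per the notes above, have no built-in TeaCache coefficient table.
# Maps model name -> keys that must be set explicitly when TeaCache is enabled.
REQUIRED_KEYS: Dict[str, List[str]] = {
    "Wan 2.2": ["coefficients", "coefficients_2"],
    "LTX-2": ["coefficients"],
}


def missing_teacache_keys(model: str, teacache: dict) -> List[str]:
    """Return the explicitly required keys that are absent (hypothetical helper)."""
    if not teacache.get("enable_teacache", False):
        return []  # nothing is required while TeaCache is off
    # Models not listed above can rely on their built-in table.
    required = REQUIRED_KEYS.get(model, [])
    return [key for key in required if teacache.get(key) is None]


# Placeholder values only; real coefficients must come from calibration.
wan21 = {"enable_teacache": True, "teacache_thresh": 0.2}
wan22 = {
    "enable_teacache": True,
    "teacache_thresh": 0.2,
    "coefficients": [1.0, 0.0, 0.5],
}

print(missing_teacache_keys("Wan 2.1", wan21))
print(missing_teacache_keys("Wan 2.2", wan22))
```

The Wan 2.2 dictionary deliberately leaves out `coefficients_2` to show the case that the pipeline rejects with a `ValueError` in this commit.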
113117

114118
### Multi-GPU Parallelism
115119

examples/visual_gen/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ python visual_gen_ltx2.py \
217217
| `--image_cond_strength` ||| 1.0 | Image conditioning strength |
218218
| `--enable_teacache` |||| False | Cache optimization |
219219
| `--teacache_thresh` |||| 0.2 | TeaCache similarity threshold |
220+
| `--teacache_coefficients` |||| *(omit)* | Optional polynomial coeffs; overrides built-in table |
221+
| `--use_ret_steps` |||| False | TeaCache retention-steps mode (WAN/FLUX tables) |
220222
| `--attention_backend` |||| VANILLA | `VANILLA`, `TRTLLM`, or `FA4` |
221223
| `--cfg_size` |||| 1 | CFG parallelism |
222224
| `--ulysses_size` |||| 1 | Sequence parallelism |

examples/visual_gen/serve/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ Before running these examples, ensure you have:
5353
```
5454
For LTX-2, you need to provide a proper text_encoder_path in `./configs/ltx2.yml`.
5555

56+
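**TeaCache:** The example YAML files set only `enable_teacache` and `teacache_thresh`. For pipelines with a **built-in** coefficient table (selected by checkpoint path matching), omit `coefficients` to use it, and add `coefficients: [ ... ]` under `teacache` only when you need to override those defaults. Wan 2.2 and LTX-2 have no built-in table, so they require explicit coefficients whenever TeaCache is enabled (both `coefficients` and `coefficients_2` for Wan 2.2).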
57+
5658
## Examples
5759

5860
Current supported & tested models:

examples/visual_gen/visual_gen_flux.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,17 @@ def parse_args():
123123
help="Use ret_steps mode for TeaCache. "
124124
"Using Retention Steps will result in faster generation speed and better generation quality.",
125125
)
126+
parser.add_argument(
127+
"--teacache_coefficients",
128+
nargs="+",
129+
type=float,
130+
default=None,
131+
metavar="FLOAT",
132+
help=(
133+
"Optional TeaCache polynomial coefficients (overrides checkpoint table). "
134+
"Example: --teacache_coefficients 1.0 0.0 0.5"
135+
),
136+
)
126137

127138
# Quantization
128139
parser.add_argument(
@@ -222,6 +233,11 @@ def build_diffusion_args(args) -> VisualGenArgs:
222233
else {}
223234
),
224235
"use_ret_steps": args.use_ret_steps,
236+
**(
237+
{"coefficients": list(args.teacache_coefficients)}
238+
if args.teacache_coefficients is not None
239+
else {}
240+
),
225241
},
226242
parallel={
227243
"dit_ulysses_size": args.ulysses_size,

examples/visual_gen/visual_gen_wan_i2v.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,29 @@ def parse_args():
9797
help="Use ret_steps mode for TeaCache. "
9898
"Using Retention Steps will result in faster generation speed and better generation quality.",
9999
)
100+
parser.add_argument(
101+
"--teacache_coefficients",
102+
nargs="+",
103+
type=float,
104+
default=None,
105+
metavar="FLOAT",
106+
help=(
107+
"Optional TeaCache polynomial coefficients (overrides checkpoint table). "
108+
"Example: --teacache_coefficients 1.0 0.0 0.5"
109+
),
110+
)
111+
parser.add_argument(
112+
"--teacache_coefficients_2",
113+
nargs="+",
114+
type=float,
115+
default=None,
116+
metavar="FLOAT",
117+
help=(
118+
"Second polynomial for Wan 2.2 low-noise transformer_2 (requires "
119+
"--teacache_coefficients for the high-noise transformer). "
120+
"Ignored for Wan 2.1."
121+
),
122+
)
100123

101124
# Quantization
102125
parser.add_argument(
@@ -182,6 +205,16 @@ def main():
182205
"enable_teacache": args.enable_teacache,
183206
"teacache_thresh": args.teacache_thresh,
184207
"use_ret_steps": args.use_ret_steps,
208+
**(
209+
{"coefficients": list(args.teacache_coefficients)}
210+
if args.teacache_coefficients is not None
211+
else {}
212+
),
213+
**(
214+
{"coefficients_2": list(args.teacache_coefficients_2)}
215+
if args.teacache_coefficients_2 is not None
216+
else {}
217+
),
185218
},
186219
parallel={
187220
"dit_cfg_size": args.cfg_size,

examples/visual_gen/visual_gen_wan_t2v.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,29 @@ def parse_args():
9191
help="Use ret_steps mode for TeaCache. "
9292
"Using Retention Steps will result in faster generation speed and better generation quality.",
9393
)
94+
parser.add_argument(
95+
"--teacache_coefficients",
96+
nargs="+",
97+
type=float,
98+
default=None,
99+
metavar="FLOAT",
100+
help=(
101+
"Optional TeaCache polynomial coefficients (overrides checkpoint table). "
102+
"Example: --teacache_coefficients 1.0 0.0 0.5"
103+
),
104+
)
105+
parser.add_argument(
106+
"--teacache_coefficients_2",
107+
nargs="+",
108+
type=float,
109+
default=None,
110+
metavar="FLOAT",
111+
help=(
112+
"Second polynomial for Wan 2.2 low-noise transformer_2 (requires "
113+
"--teacache_coefficients for the high-noise transformer). "
114+
"Ignored for Wan 2.1."
115+
),
116+
)
94117

95118
# Quantization
96119
parser.add_argument(
@@ -191,6 +214,16 @@ def main():
191214
"enable_teacache": args.enable_teacache,
192215
"teacache_thresh": args.teacache_thresh,
193216
"use_ret_steps": args.use_ret_steps,
217+
**(
218+
{"coefficients": list(args.teacache_coefficients)}
219+
if args.teacache_coefficients is not None
220+
else {}
221+
),
222+
**(
223+
{"coefficients_2": list(args.teacache_coefficients_2)}
224+
if args.teacache_coefficients_2 is not None
225+
else {}
226+
),
194227
},
195228
parallel={
196229
"dit_cfg_size": args.cfg_size,
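
The flag handling added to the Wan scripts can be tried in isolation with the sketch below. It assumes `--enable_teacache` is a `store_true` flag and reuses the 0.2 threshold default from the README table; `build_teacache_kwargs` is a hypothetical name, and the returned dictionary only stands in for the `teacache` section passed to `VisualGenArgs`.

```python
import argparse
from typing import List, Optional


def build_teacache_kwargs(argv: Optional[List[str]] = None) -> dict:
    """Parse TeaCache flags and build the teacache dict (illustrative only)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable_teacache", action="store_true")  # assumed flag type
    parser.add_argument("--teacache_thresh", type=float, default=0.2)
    parser.add_argument("--teacache_coefficients", nargs="+", type=float, default=None)
    parser.add_argument("--teacache_coefficients_2", nargs="+", type=float, default=None)
    args = parser.parse_args(argv)

    teacache = {
        "enable_teacache": args.enable_teacache,
        "teacache_thresh": args.teacache_thresh,
    }
    # Add each coefficient key only when its flag was given, so an omitted
    # flag leaves the config default (None) in place.
    if args.teacache_coefficients is not None:
        teacache["coefficients"] = list(args.teacache_coefficients)
    if args.teacache_coefficients_2 is not None:
        teacache["coefficients_2"] = list(args.teacache_coefficients_2)
    return teacache


print(build_teacache_kwargs(["--enable_teacache"]))
print(
    build_teacache_kwargs(
        [
            "--enable_teacache",
            "--teacache_coefficients", "1.0", "0.0", "0.5",
            "--teacache_coefficients_2", "2.0", "1.0",
        ]
    )
)
```

Leaving the keys out when the flags are absent keeps `coefficients` at `None`, which the config docstring defines as "use the pipeline built-in coefficient table".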

tensorrt_llm/_torch/visual_gen/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ class TeaCacheConfig(StrictBaseModel):
163163
Applied as: rescaled_distance = poly(raw_distance).
164164
None means use the pipeline built-in coefficient table (checkpoint path
165165
matching). A non-None list overrides that table.
166+
coefficients_2: Second polynomial (Wan 2.2 dual-transformer low-noise stage only).
167+
Required together with coefficients when enabling TeaCache on Wan 2.2.
166168
ret_steps: Number of warmup steps (always compute, initialized at runtime)
167169
cutoff_steps: Step to stop caching (always compute after, initialized at runtime)
168170
num_steps: Total inference steps (set at runtime)
@@ -174,6 +176,7 @@ class TeaCacheConfig(StrictBaseModel):
174176
use_ret_steps: bool = False
175177

176178
coefficients: Optional[List[float]] = None
179+
coefficients_2: Optional[List[float]] = None
177180

178181
# Runtime state fields (initialized by TeaCacheBackend.refresh)
179182
ret_steps: Optional[int] = None
@@ -191,6 +194,8 @@ def validate_teacache(self) -> "TeaCacheConfig":
191194
# Validate coefficients (when provided)
192195
if self.coefficients is not None and len(self.coefficients) == 0:
193196
raise ValueError("TeaCache coefficients list cannot be empty")
197+
if self.coefficients_2 is not None and len(self.coefficients_2) == 0:
198+
raise ValueError("TeaCache coefficients_2 list cannot be empty")
194199

195200
# Validate ret_steps if set
196201
if self.ret_steps is not None and self.ret_steps < 0:
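
The docstring's `rescaled_distance = poly(raw_distance)` can be illustrated with a short sketch. It assumes coefficients are ordered highest degree first (the `numpy.poly1d` convention; this diff does not confirm the ordering) and that rescaled distances accumulate until they reach `teacache_thresh`, at which point the transformer is recomputed and the accumulator resets. `poly` and `should_compute` are illustrative names, not TensorRT-LLM functions, and the real backend may differ in detail.

```python
from typing import List, Tuple


def poly(coefficients: List[float], x: float) -> float:
    """Evaluate a polynomial with Horner's rule (highest degree first: assumption)."""
    result = 0.0
    for c in coefficients:
        result = result * x + c
    return result


def should_compute(
    raw_distance: float,
    accumulated: float,
    coefficients: List[float],
    teacache_thresh: float,
) -> Tuple[bool, float]:
    """Return (compute_this_step, new_accumulated_distance)."""
    accumulated += poly(coefficients, raw_distance)
    if accumulated < teacache_thresh:
        return False, accumulated  # change is small: reuse the cached output
    return True, 0.0  # change is large: recompute and reset the accumulator


# Identity polynomial [1.0, 0.0] and a made-up threshold, for illustration only.
acc = 0.0
decisions = []
for raw in [0.05, 0.05, 0.05]:
    compute, acc = should_compute(raw, acc, [1.0, 0.0], 0.12)
    decisions.append(compute)
print(decisions)
```

With these made-up inputs the first two steps reuse the cache and the third recomputes, because the accumulated distance crosses the threshold only on the third step.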

tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def load_weights(self, weights: dict) -> None:
290290
def post_load_weights(self) -> None:
291291
"""Post-load setup: TeaCache registration."""
292292
super().post_load_weights()
293-
if self.transformer is not None:
293+
if self.transformer is not None and self.model_config.teacache.enable_teacache:
294294
# Register TeaCache extractor for FLUX.2 (must be after device placement)
295295
# Only set guidance_param_name for variants with guidance_embeds
296296
guidance_param = "guidance" if self.transformer.guidance_embeds else None
@@ -313,7 +313,6 @@ def post_load_weights(self) -> None:
313313
)
314314
)
315315

316-
# Enable TeaCache with FLUX.2-specific polynomial coefficients
317316
self._setup_teacache(self.transformer, FLUX2_TEACACHE_COEFFICIENTS)
318317

319318
def infer(self, req):

tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from tensorrt_llm._torch.visual_gen.output import MediaOutput
1818
from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
1919
from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
20-
from tensorrt_llm._torch.visual_gen.teacache import CacheContext
20+
from tensorrt_llm._torch.visual_gen.teacache import CacheContext, register_extractor
2121
from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
2222
from tensorrt_llm.logger import logger
2323

@@ -592,13 +592,18 @@ def post_load_weights(self) -> None:
592592
"""Finalize after weight loading: TeaCache, derived attributes."""
593593
super().post_load_weights()
594594

595-
# TODO: TeaCache disabled: LTX2_TEACACHE_COEFFICIENTS are unverified.
596-
# To re-enable, uncomment the following lines and verify coefficients.
597-
# register_extractor(
598-
# "LTXModel",
599-
# LTX2TeaCacheExtractor(self._compute_ltx2_timestep_embedding),
600-
# )
601-
# self._setup_teacache(self.transformer, coefficients=LTX2_TEACACHE_COEFFICIENTS)
595+
# LTX-2: single transformer (one DiT for video+audio); TeaCache only with explicit coefficients.
596+
if self.transformer is not None and self.model_config.teacache.enable_teacache:
597+
if self.model_config.teacache.coefficients is None:
598+
raise ValueError(
599+
"TeaCache on LTX-2 requires explicit teacache.coefficients "
600+
"(no built-in coefficient table)."
601+
)
602+
register_extractor(
603+
"LTXModel",
604+
LTX2TeaCacheExtractor(self._compute_ltx2_timestep_embedding),
605+
)
606+
self._setup_teacache(self.transformer, coefficients=None)
602607

603608
# Compression ratios from native scale factors
604609
self.vae_spatial_compression_ratio = VIDEO_SCALE_FACTORS.width

tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
from tensorrt_llm._torch.visual_gen.output import MediaOutput
1313
from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
1414
from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
15-
from tensorrt_llm._torch.visual_gen.teacache import ExtractorConfig, register_extractor_from_config
15+
from tensorrt_llm._torch.visual_gen.teacache import (
16+
ExtractorConfig,
17+
TeaCacheBackend,
18+
register_extractor_from_config,
19+
)
1620
from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
1721
from tensorrt_llm._utils import nvtx_range
1822
from tensorrt_llm.logger import logger
@@ -77,13 +81,6 @@ def __init__(self, model_config):
7781
self.boundary_ratio = getattr(model_config.pretrained_config, "boundary_ratio", None)
7882
self.is_wan22 = self.boundary_ratio is not None
7983

80-
# Validate TeaCache compatibility before allocating GPU memory
81-
if self.is_wan22 and model_config.teacache.enable_teacache:
82-
raise ValueError(
83-
"TeaCache is not supported for Wan 2.2 T2V models. "
84-
"Set enable_teacache=False in TeaCacheConfig."
85-
)
86-
8784
super().__init__(model_config)
8885

8986
def _compute_wan_timestep_embedding(self, module, timestep=None, **kwargs):
@@ -277,16 +274,41 @@ def post_load_weights(self) -> None:
277274
if not self.is_wan22:
278275
self._setup_teacache(self.transformer, coefficients=WAN_TEACACHE_COEFFICIENTS)
279276
self.transformer_cache_backend = self.cache_backend
280-
else:
281-
# TeaCache is not supported for Wan 2.2: the dual-transformer
282-
# architecture (transformer + transformer_2) requires separate
283-
# TeaCache coefficients that have not been calibrated yet.
284-
self.transformer_cache_backend = None
285277

286278
if self.transformer_2 is not None:
287279
if hasattr(self.transformer_2, "post_load_weights"):
288280
self.transformer_2.post_load_weights()
289281

282+
# Wan 2.2 TeaCache after both transformers' post_load_weights (FP8 scales, etc.)
283+
if (
284+
self.transformer is not None
285+
and self.transformer_2 is not None
286+
and self.is_wan22
287+
and self.model_config.teacache.enable_teacache
288+
):
289+
self._apply_teacache_coefficients(WAN_TEACACHE_COEFFICIENTS)
290+
tc = self.model_config.teacache
291+
if tc.coefficients is None or tc.coefficients_2 is None:
292+
raise ValueError(
293+
"Wan 2.2 TeaCache requires explicit teacache.coefficients and "
294+
"teacache.coefficients_2 (high-noise and low-noise stage polynomials). "
295+
"There is no built-in coefficient table for Wan 2.2."
296+
)
297+
cfg_high = tc.model_copy(deep=True)
298+
cfg_low = tc.model_copy(deep=True)
299+
cfg_low.coefficients = tc.coefficients_2
300+
logger.info("TeaCache: Initializing (Wan 2.2 high-noise transformer)...")
301+
self.cache_backend = TeaCacheBackend(cfg_high)
302+
self.cache_backend.enable(self.transformer)
303+
self.transformer_cache_backend = self.cache_backend
304+
logger.info("TeaCache: Initializing (Wan 2.2 low-noise transformer_2)...")
305+
self.transformer_2_cache_backend = TeaCacheBackend(cfg_low)
306+
self.transformer_2_cache_backend.enable(self.transformer_2)
307+
self._teacache_backends = [
308+
self.cache_backend,
309+
self.transformer_2_cache_backend,
310+
]
311+
290312
def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None:
291313
with torch.no_grad():
292314
self.forward(
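
The Wan 2.2 branch above derives two per-stage configs from one user config. A stdlib-only sketch of that split follows; `TeaCacheSettings` is a hypothetical stand-in for `TeaCacheConfig`, and `copy.deepcopy` plays the role of `model_copy(deep=True)`.

```python
import copy
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class TeaCacheSettings:
    """Stand-in for TeaCacheConfig (hypothetical, stdlib-only)."""

    enable_teacache: bool = False
    teacache_thresh: float = 0.2
    coefficients: Optional[List[float]] = None
    coefficients_2: Optional[List[float]] = None


def split_for_wan22(tc: TeaCacheSettings) -> Tuple[TeaCacheSettings, TeaCacheSettings]:
    """Return (high-noise config, low-noise config) built from one user config."""
    if tc.coefficients is None or tc.coefficients_2 is None:
        raise ValueError("Wan 2.2 TeaCache requires coefficients and coefficients_2")
    cfg_high = copy.deepcopy(tc)  # keeps tc.coefficients for the high-noise stage
    cfg_low = copy.deepcopy(tc)
    # The low-noise stage reads its polynomial from the same `coefficients` field.
    cfg_low.coefficients = list(tc.coefficients_2)
    return cfg_high, cfg_low


# Placeholder coefficients, not calibrated values.
user_cfg = TeaCacheSettings(
    enable_teacache=True,
    coefficients=[1.0, 0.0, 0.5],
    coefficients_2=[2.0, 1.0],
)
high, low = split_for_wan22(user_cfg)
print(high.coefficients, low.coefficients)
```

Deep-copying gives each backend its own config object, so runtime state written into one stage's config does not leak into the other's.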
