|
12 | 12 | from tensorrt_llm._torch.visual_gen.output import MediaOutput |
13 | 13 | from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline |
14 | 14 | from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline |
15 | | -from tensorrt_llm._torch.visual_gen.teacache import ExtractorConfig, register_extractor_from_config |
| 15 | +from tensorrt_llm._torch.visual_gen.teacache import ( |
| 16 | + ExtractorConfig, |
| 17 | + TeaCacheBackend, |
| 18 | + register_extractor_from_config, |
| 19 | +) |
16 | 20 | from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor |
17 | 21 | from tensorrt_llm._utils import nvtx_range |
18 | 22 | from tensorrt_llm.logger import logger |
@@ -77,13 +81,6 @@ def __init__(self, model_config): |
77 | 81 | self.boundary_ratio = getattr(model_config.pretrained_config, "boundary_ratio", None) |
78 | 82 | self.is_wan22 = self.boundary_ratio is not None |
79 | 83 |
|
80 | | - # Validate TeaCache compatibility before allocating GPU memory |
81 | | - if self.is_wan22 and model_config.teacache.enable_teacache: |
82 | | - raise ValueError( |
83 | | - "TeaCache is not supported for Wan 2.2 T2V models. " |
84 | | - "Set enable_teacache=False in TeaCacheConfig." |
85 | | - ) |
86 | | - |
87 | 84 | super().__init__(model_config) |
88 | 85 |
|
89 | 86 | def _compute_wan_timestep_embedding(self, module, timestep=None, **kwargs): |
@@ -277,16 +274,41 @@ def post_load_weights(self) -> None: |
277 | 274 | if not self.is_wan22: |
278 | 275 | self._setup_teacache(self.transformer, coefficients=WAN_TEACACHE_COEFFICIENTS) |
279 | 276 | self.transformer_cache_backend = self.cache_backend |
280 | | - else: |
281 | | - # TeaCache is not supported for Wan 2.2: the dual-transformer |
282 | | - # architecture (transformer + transformer_2) requires separate |
283 | | - # TeaCache coefficients that have not been calibrated yet. |
284 | | - self.transformer_cache_backend = None |
285 | 277 |
|
286 | 278 | if self.transformer_2 is not None: |
287 | 279 | if hasattr(self.transformer_2, "post_load_weights"): |
288 | 280 | self.transformer_2.post_load_weights() |
289 | 281 |
|
| 282 | + # Wan 2.2 TeaCache after both transformers' post_load_weights (FP8 scales, etc.) |
| 283 | + if ( |
| 284 | + self.transformer is not None |
| 285 | + and self.transformer_2 is not None |
| 286 | + and self.is_wan22 |
| 287 | + and self.model_config.teacache.enable_teacache |
| 288 | + ): |
| 289 | + self._apply_teacache_coefficients(WAN_TEACACHE_COEFFICIENTS) |
| 290 | + tc = self.model_config.teacache |
| 291 | + if tc.coefficients is None or tc.coefficients_2 is None: |
| 292 | + raise ValueError( |
| 293 | + "Wan 2.2 TeaCache requires explicit teacache.coefficients and " |
| 294 | + "teacache.coefficients_2 (high-noise and low-noise stage polynomials). " |
| 295 | + "There is no built-in coefficient table for Wan 2.2." |
| 296 | + ) |
| 297 | + cfg_high = tc.model_copy(deep=True) |
| 298 | + cfg_low = tc.model_copy(deep=True) |
| 299 | + cfg_low.coefficients = tc.coefficients_2 |
| 300 | + logger.info("TeaCache: Initializing (Wan 2.2 high-noise transformer)...") |
| 301 | + self.cache_backend = TeaCacheBackend(cfg_high) |
| 302 | + self.cache_backend.enable(self.transformer) |
| 303 | + self.transformer_cache_backend = self.cache_backend |
| 304 | + logger.info("TeaCache: Initializing (Wan 2.2 low-noise transformer_2)...") |
| 305 | + self.transformer_2_cache_backend = TeaCacheBackend(cfg_low) |
| 306 | + self.transformer_2_cache_backend.enable(self.transformer_2) |
| 307 | + self._teacache_backends = [ |
| 308 | + self.cache_backend, |
| 309 | + self.transformer_2_cache_backend, |
| 310 | + ] |
| 311 | + |
290 | 312 | def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None: |
291 | 313 | with torch.no_grad(): |
292 | 314 | self.forward( |
|
0 commit comments