diff --git a/docs/source/models/visual-generation.md b/docs/source/models/visual-generation.md index 58c2874b3045..c5caefacbb70 100644 --- a/docs/source/models/visual-generation.md +++ b/docs/source/models/visual-generation.md @@ -48,14 +48,18 @@ Models are auto-detected from the checkpoint directory. Diffusers-format models | **FLUX.1** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | | **FLUX.2** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | | **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | -| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | -| **LTX-2** | Yes | Yes | No | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | -| **Qwen-Image** [^2] | Yes | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | +| **Wan 2.2** | Yes | Yes | Yes [^2] | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | +| **LTX-2** | Yes | Yes | Yes [^3] | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | +| **Qwen-Image** [^4] | Yes | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | | **Cosmos3** | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes | [^1]: FLUX models use embedded guidance and do not have a separate negative prompt path, so CFG parallelism is not applicable. -[^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required. +[^2]: Wan 2.2 has two stage transformers; TeaCache requires explicit `teacache.coefficients` (high-noise) and `teacache.coefficients_2` (low-noise). There is no built-in coefficient table for Wan 2.2. + +[^3]: LTX-2 has no built-in TeaCache coefficient table in TRT-LLM; set `teacache.coefficients` explicitly when enabling TeaCache. + +[^4]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required. ## Quick Start diff --git a/tensorrt_llm/_torch/visual_gen/cache/teacache_accelerator.py b/tensorrt_llm/_torch/visual_gen/cache/teacache_accelerator.py index 273247e3209d..5119d275c52a 100644 --- a/tensorrt_llm/_torch/visual_gen/cache/teacache_accelerator.py +++ b/tensorrt_llm/_torch/visual_gen/cache/teacache_accelerator.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Tuple import torch.nn as nn @@ -17,52 +17,55 @@ class TeaCacheAccelerator(CacheAccelerator): - """Whole-transformer skip using existing :class:`TeaCacheBackend`.""" + """Whole-transformer step skipping via :class:`TeaCacheBackend`.""" - def __init__(self, teacache_cfg: TeaCacheConfig): + def __init__(self, pipeline: Any, teacache_cfg: TeaCacheConfig): + self._pipeline = pipeline self._cfg = teacache_cfg - self._backend: Optional[TeaCacheBackend] = None - self._module: Optional[nn.Module] = None + self._backends: List[Tuple[nn.Module, TeaCacheBackend]] = [] def wrap(self, *args: Any, **kwargs: Any) -> None: - """Enable TeaCache on ``model``. - - Coefficient matching is done on the pipeline via - :meth:`BasePipeline._apply_teacache_coefficients` before this runs. - - Keyword args: ``model`` (required). - """ - model: Optional[nn.Module] = kwargs.get("model") - if model is None: + if self._backends: return - if self._backend is not None: + transformer = getattr(self._pipeline, "transformer", None) + if transformer is None: return - logger.info("TeaCache: Initializing...") - self._backend = TeaCacheBackend(self._cfg) - self._backend.enable(model) - self._module = model + transformer_2 = getattr(self._pipeline, "transformer_2", None) + if transformer_2 is not None and self._cfg.coefficients_2 is not None: + cfg_low = self._cfg.model_copy(update={"coefficients": self._cfg.coefficients_2}) + logger.info("TeaCache: Initializing (high-noise transformer)...") + backend_high = TeaCacheBackend(self._cfg) + backend_high.enable(transformer) + logger.info("TeaCache: Initializing (low-noise transformer_2)...") + backend_low = TeaCacheBackend(cfg_low) + backend_low.enable(transformer_2) + self._backends = [(transformer, backend_high), (transformer_2, backend_low)] + else: + logger.info("TeaCache: Initializing...") + backend = TeaCacheBackend(self._cfg) + backend.enable(transformer) + self._backends = [(transformer, backend)] def unwrap(self) -> None: - if self._backend is None or self._module is None: - return - self._backend.disable(self._module) - self._backend = None - self._module = None + for module, backend in self._backends: + backend.disable(module) + self._backends = [] def refresh(self, num_inference_steps: int) -> None: - if self._backend: - self._backend.refresh(num_inference_steps) + for _, backend in self._backends: + backend.refresh(num_inference_steps) def get_stats(self) -> Dict[str, Any]: - if not self._backend: + if not self._backends: return {} - return self._backend.get_stats() or {} + if len(self._backends) == 1: + return self._backends[0][1].get_stats() or {} + return {f"transformer_{i}": b.get_stats() for i, (_, b) in enumerate(self._backends)} def is_enabled(self) -> bool: - return bool(self._backend and self._backend.is_enabled()) + return bool(self._backends and any(b.is_enabled() for _, b in self._backends)) @property - def tea_cache_backend(self) -> Optional[TeaCacheBackend]: - """The wrapped TeaCache hook owner (for introspection or advanced use).""" - return self._backend + def backends(self) -> List[Tuple[nn.Module, TeaCacheBackend]]: + return self._backends diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py index 4528a657709f..81e8672e295a 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py @@ -235,7 +235,8 @@ def post_load_weights(self) -> None: ) # TeaCache or Cache-DiT - self._setup_cache_acceleration(self.transformer, FLUX_TEACACHE_COEFFICIENTS) + self._apply_teacache_coefficients(FLUX_TEACACHE_COEFFICIENTS) + self._setup_cache_acceleration() @property def default_generation_params(self): diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py index 7850d9413e0f..7c6de849bd02 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py @@ -302,33 +302,34 @@ def load_weights(self, weights: dict) -> None: self.transformer.eval() def post_load_weights(self) -> None: - """Post-load setup: TeaCache registration.""" + """Post-load setup: cache acceleration (TeaCache or Cache-DiT).""" super().post_load_weights() if self.transformer is not None: - # Register TeaCache extractor for FLUX.2 (must be after device placement) - # Only set guidance_param_name for variants with guidance_embeds - guidance_param = "guidance" if self.transformer.guidance_embeds else None - forward_params = [ - "hidden_states", - "encoder_hidden_states", - "timestep", - "img_ids", - "txt_ids", - "guidance", - "return_dict", - ] - register_extractor_from_config( - ExtractorConfig( - model_class_name="Flux2Transformer2DModel", - timestep_embed_fn=self._compute_flux2_timestep_embedding, - guidance_param_name=guidance_param, - forward_params=forward_params, - return_dict_default=False, + if self.pipeline_config.cache_backend == "teacache": + # Register TeaCache extractor for FLUX.2 (must be after device placement) + # Only set guidance_param_name for variants with guidance_embeds + guidance_param = "guidance" if self.transformer.guidance_embeds else None + forward_params = [ + "hidden_states", + "encoder_hidden_states", + "timestep", + "img_ids", + "txt_ids", + "guidance", + "return_dict", + ] + register_extractor_from_config( + ExtractorConfig( + model_class_name="Flux2Transformer2DModel", + timestep_embed_fn=self._compute_flux2_timestep_embedding, + guidance_param_name=guidance_param, + forward_params=forward_params, + return_dict_default=False, + ) ) - ) - # TeaCache or Cache-DiT - self._setup_cache_acceleration(self.transformer, FLUX2_TEACACHE_COEFFICIENTS) + self._apply_teacache_coefficients(FLUX2_TEACACHE_COEFFICIENTS) + self._setup_cache_acceleration() @property def default_generation_params(self): diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py index ff919c739f17..106cf4923c8b 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py @@ -16,7 +16,7 @@ from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast from tensorrt_llm._torch.utils import make_weak_ref -from tensorrt_llm._torch.visual_gen.cache.teacache import CacheContext +from tensorrt_llm._torch.visual_gen.cache.teacache import CacheContext, register_extractor from tensorrt_llm._torch.visual_gen.checkpoints.prefetch import prefetch_files_to_host_cache from tensorrt_llm._torch.visual_gen.cuda_graph_runner import CUDAGraphRunner, CUDAGraphRunnerConfig from tensorrt_llm._torch.visual_gen.output import CudaPhaseTimer, PipelineOutput @@ -1015,17 +1015,22 @@ def post_load_weights(self) -> None: """Finalize after weight loading: TeaCache, Cache-DiT, derived attributes.""" super().post_load_weights() - # TODO: TeaCache disabled: LTX2_TEACACHE_COEFFICIENTS are unverified. - # To re-enable, uncomment the following lines and verify coefficients. - # register_extractor( - # "LTXModel", - # LTX2TeaCacheExtractor(self._compute_ltx2_timestep_embedding), - # ) - # self._setup_teacache(self.transformer, coefficients=LTX2_TEACACHE_COEFFICIENTS) + # LTX-2: single transformer (one DiT for video+audio); TeaCache only with explicit coefficients. + if self.transformer is not None and self.pipeline_config.cache_backend == "teacache": + if self.pipeline_config.teacache.coefficients is None: + raise ValueError( + "TeaCache on LTX-2 requires explicit teacache.coefficients " + "(no built-in coefficient table)." + ) + register_extractor( + "LTXModel", + LTX2TeaCacheExtractor(self._compute_ltx2_timestep_embedding), + ) + self._setup_cache_acceleration() # Cache-DiT if self.transformer is not None and self.pipeline_config.cache_backend == "cache_dit": - self._setup_cache_acceleration(self.transformer, coefficients=None) + self._setup_cache_acceleration() # Compression ratios from native scale factors self.vae_spatial_compression_ratio = VIDEO_SCALE_FACTORS.width diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index ea4e1b880608..9346e2c2277f 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -102,13 +102,6 @@ def __init__(self, pipeline_config): self.is_wan22_14b = self.boundary_ratio is not None self.is_wan22_5b = self.expand_timesteps - # Validate TeaCache compatibility before allocating GPU memory - if (self.is_wan22_14b or self.is_wan22_5b) and pipeline_config.cache_backend == "teacache": - raise ValueError( - "TeaCache is not supported for Wan 2.2 models. " - "Use cache_backend='none' or 'cache_dit' (not 'teacache')." - ) - # Fixed latent for reproducible benchmarking (e.g. MLPerf). # Set TRTLLM_VIDEO_FIXED_LATENT_PATH to a .pt file containing a pre-sampled # noise tensor; it will be used in place of freshly sampled random latents for @@ -331,20 +324,31 @@ def post_load_weights(self) -> None: ) if not self.is_wan22_14b: - self._setup_cache_acceleration( - self.transformer, coefficients=WAN_TEACACHE_COEFFICIENTS - ) - self.transformer_cache_backend = self.cache_accelerator + self._apply_teacache_coefficients(WAN_TEACACHE_COEFFICIENTS) + self._setup_cache_acceleration() else: if self.pipeline_config.cache_backend == "cache_dit": - self._setup_cache_acceleration(self.transformer, coefficients=None) - # TeaCache is not supported for Wan 2.2 unless using Cache-DiT. - self.transformer_cache_backend = self.cache_accelerator + self._setup_cache_acceleration() if self.transformer_2 is not None: if hasattr(self.transformer_2, "post_load_weights"): self.transformer_2.post_load_weights() + # Wan 2.2 TeaCache after both transformers' post_load_weights (FP8 scales, etc.) + if ( + self.transformer is not None + and self.transformer_2 is not None + and self.pipeline_config.cache_backend == "teacache" + ): + tc = self.pipeline_config.teacache + if tc.coefficients is None or tc.coefficients_2 is None: + raise ValueError( + "Wan 2.2 TeaCache requires explicit teacache.coefficients and " + "teacache.coefficients_2 (high-noise and low-noise stage polynomials). " + "There is no built-in coefficient table for Wan 2.2." + ) + self._setup_cache_acceleration() + def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None: with torch.no_grad(): self.forward( diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index 5d8ea711abea..0b7d5aa44819 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -96,13 +96,6 @@ def __init__(self, pipeline_config): ) self.is_wan22_14b = self.boundary_ratio is not None - # Validate TeaCache compatibility before allocating GPU memory - if self.is_wan22_14b and pipeline_config.cache_backend == "teacache": - raise ValueError( - "TeaCache is not supported for Wan 2.2 models. " - "Use cache_backend='none' or 'cache_dit' (not 'teacache')." - ) - super().__init__(pipeline_config) def _compute_wan_timestep_embedding(self, module, timestep=None, **kwargs): @@ -338,19 +331,30 @@ def post_load_weights(self) -> None: ) if not self.is_wan22_14b: - self._setup_cache_acceleration( - self.transformer, coefficients=WAN_I2V_TEACACHE_COEFFICIENTS - ) - self.transformer_cache_backend = self.cache_accelerator + self._apply_teacache_coefficients(WAN_I2V_TEACACHE_COEFFICIENTS) + self._setup_cache_acceleration() else: if self.pipeline_config.cache_backend == "cache_dit": - self._setup_cache_acceleration(self.transformer, coefficients=None) - self.transformer_cache_backend = self.cache_accelerator + self._setup_cache_acceleration() if self.transformer_2 is not None: if hasattr(self.transformer_2, "post_load_weights"): self.transformer_2.post_load_weights() + if ( + self.transformer is not None + and self.transformer_2 is not None + and self.pipeline_config.cache_backend == "teacache" + ): + tc = self.pipeline_config.teacache + if tc.coefficients is None or tc.coefficients_2 is None: + raise ValueError( + "Wan 2.2 TeaCache requires explicit teacache.coefficients and " + "teacache.coefficients_2 (high-noise and low-noise stage polynomials). " + "There is no built-in coefficient table for Wan 2.2." + ) + self._setup_cache_acceleration() + def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None: dummy_image = PIL.Image.new("RGB", (width, height)) with torch.no_grad(): diff --git a/tensorrt_llm/_torch/visual_gen/pipeline.py b/tensorrt_llm/_torch/visual_gen/pipeline.py index 3e77c2135aba..ec35e53bd95e 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline.py @@ -400,48 +400,92 @@ def post_load_weights(self) -> None: if self.transformer is not None and hasattr(self.transformer, "post_load_weights"): self.transformer.post_load_weights() - def _apply_teacache_coefficients(self, coefficients: Optional[Dict]) -> None: - """Pick TeaCache coefficients from checkpoint path; updates pipeline config in place.""" + def _apply_teacache_coefficients(self, coefficients: Optional[Dict] = None) -> None: + """Resolve TeaCache polynomial coefficients into pipeline_config.cache (TeaCacheConfig). + + Precedence: + + 1. User-specified TeaCacheConfig.coefficients — any non-None list skips built-in + variant matching. + + 2. Pipeline table — if step 1 does not apply and coefficients is a non-empty dict + (model-specific tables from the pipeline subclass), match + pretrained_config._name_or_path against keys and set coefficients (and optional + default_thresh). + + 3. If coefficients are still unresolved after step 2, _setup_cache_acceleration + raises: TeaCache must not run without resolved coefficients. + + Args: + coefficients: Optional mapping from variant key to coefficient list or nested + dict (ret_steps / standard), from the pipeline subclass. + """ + teacache_cfg = self.pipeline_config.teacache + if teacache_cfg is None: + return + if teacache_cfg.is_explicit_user_override(): + logger.info( + "TeaCache: Using user-configured coefficients " + "(skipping built-in checkpoint variant matching)" + ) + return + + teacache_explicit = teacache_cfg.model_dump(exclude_unset=True) + if not coefficients: + if teacache_cfg.coefficients is None: + raise ValueError( + "TeaCache is enabled but no polynomial coefficients were resolved. " + "Set teacache.coefficients in VisualGenArgs, or use a pipeline and " + "checkpoint whose path matches a built-in coefficient table." + ) return - teacache_cfg = self.pipeline_config.teacache - checkpoint_path = getattr( - self.pipeline_config.primary_pretrained_config, "_name_or_path", "" + + checkpoint_path = ( + getattr(self.pipeline_config.primary_pretrained_config, "_name_or_path", "") or "" ) - matched = False + for model_size, coeff_data in coefficients.items(): - if model_size.lower() in checkpoint_path.lower(): - matched = True - if isinstance(coeff_data, dict): - mode = "ret_steps" if teacache_cfg.use_ret_steps else "standard" - if mode in coeff_data: - teacache_cfg.coefficients = coeff_data[mode] - logger.info(f"TeaCache: Using {model_size} coefficients ({mode} mode)") - default_thresh = coeff_data.get("default_thresh") - if ( - default_thresh is not None - and "teacache_thresh" not in teacache_cfg.model_fields_set - ): - teacache_cfg.teacache_thresh = default_thresh - logger.info( - f"TeaCache: Using {model_size} default threshold {default_thresh}" - ) - else: - teacache_cfg.coefficients = coeff_data - logger.info(f"TeaCache: Using {model_size} coefficients") + # Match model size in path (case-insensitive, e.g., "1.3B", "14B", "dev") + path_l = checkpoint_path.lower() + key_l = model_size.lower() + if key_l not in path_l: + continue + + if isinstance(coeff_data, dict): + # Select coefficient set based on warmup mode + mode = "ret_steps" if teacache_cfg.use_ret_steps else "standard" + if mode not in coeff_data: + logger.warning( + "TeaCache: matched variant %r but table has no %r entry " + "(available keys: %s). Trying other variants.", + model_size, + mode, + list(coeff_data.keys()), + ) + continue + teacache_cfg.coefficients = coeff_data[mode] + logger.info(f"TeaCache: Using {model_size} coefficients ({mode} mode)") + # Apply model-specific default threshold if user didn't explicitly set one + default_thresh = coeff_data.get("default_thresh") + if default_thresh is not None and "teacache_thresh" not in teacache_explicit: + teacache_cfg.teacache_thresh = default_thresh + logger.info(f"TeaCache: Using {model_size} default threshold {default_thresh}") + break + else: + # Single coefficient list (no mode distinction) + teacache_cfg.coefficients = coeff_data + logger.info(f"TeaCache: Using {model_size} coefficients") break - if not matched: + else: raise ValueError( f"TeaCache: No coefficients found for checkpoint '{checkpoint_path}'. " f"Available variants: {list(coefficients.keys())}. " - f"TeaCache is not supported for this model variant." + f"Set teacache.coefficients explicitly in VisualGenArgs to use TeaCache anyway, " + f"or use a checkpoint path that contains one of the variant keys." ) - def _setup_cache_acceleration( - self, - model: Optional[nn.Module] = None, - coefficients: Optional[Dict] = None, - ) -> None: + def _setup_cache_acceleration(self) -> None: """Enable TeaCache or Cache-DiT from model_config.cache_backend.""" if getattr(self, "cache_accelerator", None) is not None: @@ -460,15 +504,9 @@ def _setup_cache_acceleration( if not use_teacache: return - BasePipeline._apply_teacache_coefficients(self, coefficients) - - if model is None: - return - - acc = TeaCacheAccelerator(cfg.teacache) - acc.wrap(model=model) - if acc.is_enabled(): - self.cache_accelerator = acc + acc = TeaCacheAccelerator(self, cfg.teacache) + acc.wrap() + self.cache_accelerator = acc def setup_parallel_vae(self): """Enable parallel-VAE decode mode and wrap the VAE on participating ranks. @@ -1140,11 +1178,20 @@ def denoise( if stats: if self.pipeline_config.cache_backend == "cache_dit": logger.info("Cache-DiT stats: %s", stats) - elif "hit_rate" in stats: - logger.info( - f"TeaCache: {stats['hit_rate']:.1%} hit rate " - f"({stats['cached']}/{stats['total']} steps)" - ) + elif self.pipeline_config.cache_backend == "teacache": + first_val = next(iter(stats.values()), None) + if isinstance(first_val, dict): + for key, s in stats.items(): + if "hit_rate" in s: + logger.info( + f"TeaCache {key}: {s['hit_rate']:.1%} hit rate " + f"({s['cached']}/{s['total']} steps)" + ) + elif "hit_rate" in stats: + logger.info( + f"TeaCache: {stats['hit_rate']:.1%} hit rate " + f"({stats['cached']}/{stats['total']} steps)" + ) else: logger.info("Cache acceleration stats: %s", stats) diff --git a/tensorrt_llm/visual_gen/args.py b/tensorrt_llm/visual_gen/args.py index 570744f6c6b9..975782643630 100644 --- a/tensorrt_llm/visual_gen/args.py +++ b/tensorrt_llm/visual_gen/args.py @@ -297,23 +297,38 @@ class TeaCacheConfig(BaseCacheConfig): teacache_thresh: float = Field(0.2, gt=0.0, status="prototype") use_ret_steps: bool = Field(False, status="prototype") - coefficients: List[float] = Field( - default_factory=lambda: [1.0, 0.0], + coefficients: Optional[List[float]] = Field( + default=None, status="prototype", description=( "Polynomial coefficients used by the TeaCache decision function. " - "Variable-length (FLUX uses 5, Wan uses 4); the pipeline overrides " - "this per-checkpoint at load time." + "None (default) uses the pipeline's built-in per-checkpoint table; " + "an explicit list overrides the table entirely." + ), + ) + + coefficients_2: Optional[List[float]] = Field( + default=None, + status="prototype", + description=( + "Second polynomial (Wan 2.2 dual-transformer low-noise stage only). " + "Required together with coefficients when enabling TeaCache on Wan 2.2." ), ) @model_validator(mode="after") def validate_teacache(self) -> "TeaCacheConfig": """Validate TeaCache configuration.""" - if len(self.coefficients) == 0: + if self.coefficients is not None and len(self.coefficients) == 0: raise ValueError("TeaCache coefficients list cannot be empty") + if self.coefficients_2 is not None and len(self.coefficients_2) == 0: + raise ValueError("TeaCache coefficients_2 list cannot be empty") return self + def is_explicit_user_override(self) -> bool: + """Return True if coefficients were set by the user and should skip built-in table matching.""" + return self.coefficients is not None + class CacheDiTConfig(BaseCacheConfig): """Configuration for Cache-DiT (DBCache, TaylorSeer, SCM). diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 74386f393832..5490ff14ade3 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -238,6 +238,9 @@ l0_b200: - unittest/_torch/visual_gen/test_wan22_ti2v_5b_pipeline.py - unittest/_torch/visual_gen/test_wan21_i2v_teacache.py - unittest/_torch/visual_gen/test_wan21_t2v_teacache.py + - unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py + - unittest/_torch/visual_gen/test_wan22_i2v_teacache.py + - unittest/_torch/visual_gen/test_wan22_t2v_teacache.py - unittest/_torch/visual_gen/test_wan_transformer.py - unittest/_torch/visual_gen/test_cosmos3_transformer.py - unittest/_torch/visual_gen/test_cosmos3_pipeline.py diff --git a/tests/unittest/_torch/visual_gen/test_teacache.py b/tests/unittest/_torch/visual_gen/test_teacache.py index a0080a6d2008..c53f77aaa135 100644 --- a/tests/unittest/_torch/visual_gen/test_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_teacache.py @@ -15,34 +15,164 @@ """Unit tests for TeaCache (CPU-only, no model weights needed).""" from types import SimpleNamespace +from typing import Optional from unittest.mock import MagicMock, patch import pytest from tensorrt_llm._torch.visual_gen.cache.teacache import TeaCacheBackend -from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig, DiffusionPipelineConfig +from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline from tensorrt_llm.visual_gen.args import TeaCacheConfig +class _PipelineConfigShim: + """Minimal pipeline_config shim that delegates all reads back to a DiffusionModelConfig. + + Shares the same TeaCacheConfig object so mutations from _apply_teacache_coefficients + are visible via both pipeline_config.teacache and model_config.teacache. + """ + + def __init__(self, model_config: DiffusionModelConfig): + self._mc = model_config + + @property + def cache_backend(self): + return self._mc.cache_backend + + @property + def teacache(self): + return self._mc.teacache + + @property + def cache_dit(self): + return self._mc.cache_dit + + @property + def primary_pretrained_config(self): + return self._mc.pretrained_config + + def model_copy(self, update=None): + if update: + for k, v in update.items(): + setattr(self._mc, k, v) + return self + + +class _PipelineTeacacheTestDouble: + """Minimal object for BasePipeline._setup_cache_acceleration / _apply_teacache_coefficients tests.""" + + def __init__(self, model_config: DiffusionModelConfig): + self.model_config = model_config + self.cache_accelerator = None + self.cache_backend = None + self.pipeline_config = _PipelineConfigShim(model_config) + + def _apply_teacache_coefficients(self, coefficients: Optional[dict] = None): + return BasePipeline._apply_teacache_coefficients(self, coefficients) + + class TestSetupTeacache: """Tests for _setup_cache_acceleration TeaCache coefficient matching and fail-early behavior.""" def _make_pipeline_mock(self, checkpoint_name, use_ret_steps=False): - pipeline = MagicMock() - pipeline.cache_accelerator = None - model_config = DiffusionModelConfig( - pretrained_config=SimpleNamespace(_name_or_path=f"/path/to/{checkpoint_name}/snapshot"), + return _PipelineTeacacheTestDouble( + DiffusionModelConfig( + pretrained_config=SimpleNamespace( + _name_or_path=f"/path/to/{checkpoint_name}/snapshot" + ), + cache=TeaCacheConfig( + teacache_thresh=0.3, + use_ret_steps=use_ret_steps, + coefficients=None, + ), + skip_create_weights_in_init=True, + ) ) - pipeline.pipeline_config = DiffusionPipelineConfig( - model_configs={"transformer": model_config}, - cache=TeaCacheConfig( - teacache_thresh=0.3, - use_ret_steps=use_ret_steps, - ), - skip_create_weights_in_init=True, + + def _make_pipeline_teacache_enable_only(self, checkpoint_name, use_ret_steps=False): + """Only use_ret_steps (+ defaults) so teacache_thresh may be omitted from explicit set.""" + return _PipelineTeacacheTestDouble( + DiffusionModelConfig( + pretrained_config=SimpleNamespace( + _name_or_path=f"/path/to/{checkpoint_name}/snapshot" + ), + cache=TeaCacheConfig( + use_ret_steps=use_ret_steps, + coefficients=None, + ), + skip_create_weights_in_init=True, + ) ) - return pipeline + + def test_setup_cache_acceleration_raises_when_no_table_and_no_user_coefficients(self): + """Fails if TeaCache is on but the pipeline passes no coefficient table.""" + pipeline = self._make_pipeline_mock("FLUX.1-dev") + with pytest.raises(ValueError, match="no polynomial coefficients were resolved"): + BasePipeline._apply_teacache_coefficients(pipeline, None) + BasePipeline._setup_cache_acceleration(pipeline) + + def test_setup_cache_acceleration_raises_when_empty_table(self): + """Same as no table: nothing to resolve from.""" + pipeline = self._make_pipeline_mock("FLUX.1-dev") + with pytest.raises(ValueError, match="no polynomial coefficients were resolved"): + BasePipeline._apply_teacache_coefficients(pipeline, {}) + BasePipeline._setup_cache_acceleration(pipeline) + + def test_matching_variant_selects_ret_steps_mode(self): + """Nested table: use_ret_steps=True selects ret_steps coefficients.""" + pipeline = self._make_pipeline_mock("FLUX.1-dev", use_ret_steps=True) + coefficients = { + "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, + } + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._apply_teacache_coefficients(pipeline, coefficients) + BasePipeline._setup_cache_acceleration(pipeline) + + assert pipeline.model_config.teacache.coefficients == [4.0, 5.0] + + def test_flat_list_table_entry(self): + """Table value may be a plain list (no standard/ret_steps nesting).""" + pipeline = self._make_pipeline_mock("FLUX.1-dev") + coefficients = {"dev": [11.0, 22.0, 33.0]} + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._apply_teacache_coefficients(pipeline, coefficients) + BasePipeline._setup_cache_acceleration(pipeline) + + assert pipeline.model_config.teacache.coefficients == [11.0, 22.0, 33.0] + + def test_default_thresh_from_table_when_user_did_not_set_teacache_thresh(self): + """default_thresh applies when teacache_thresh was not explicitly set (exclude_unset).""" + pipeline = self._make_pipeline_teacache_enable_only("FLUX.1-dev") + builtin = { + "dev": { + "standard": [1.0, 2.0], + "ret_steps": [3.0, 4.0], + "default_thresh": 0.42, + }, + } + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._apply_teacache_coefficients(pipeline, builtin) + BasePipeline._setup_cache_acceleration(pipeline) + + assert pipeline.model_config.teacache.coefficients == [1.0, 2.0] + assert pipeline.model_config.teacache.teacache_thresh == 0.42 + + def test_explicit_identity_coefficients_still_skip_table(self): + """[1.0, 0.0] is a user override: no variant lookup, no ValueError on unknown path.""" + pipeline = self._make_pipeline_mock("FLUX.1-unknown-variant") + pipeline.model_config.cache = TeaCacheConfig( + teacache_thresh=0.3, + coefficients=[1.0, 0.0], + ) + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._apply_teacache_coefficients( + pipeline, + {"dev": {"standard": [99.0, 99.0]}}, + ) + BasePipeline._setup_cache_acceleration(pipeline) + + assert pipeline.model_config.teacache.coefficients == [1.0, 0.0] def test_matching_variant_selects_coefficients(self): """Picks coefficients whose key appears in checkpoint path.""" @@ -52,7 +182,8 @@ def test_matching_variant_selects_coefficients(self): "schnell": {"standard": [10.0, 20.0]}, } with patch.object(TeaCacheBackend, "enable"): - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), coefficients) + BasePipeline._apply_teacache_coefficients(pipeline, coefficients) + BasePipeline._setup_cache_acceleration(pipeline) assert pipeline.pipeline_config.teacache.coefficients == [1.0, 2.0, 3.0] @@ -64,12 +195,329 @@ def test_no_match_raises_valueerror(self): "schnell": {"standard": [10.0, 20.0]}, } with pytest.raises(ValueError, match="No coefficients found"): - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), coefficients) + BasePipeline._apply_teacache_coefficients(pipeline, coefficients) def test_disabled_teacache_is_noop(self): """No-op when cache is None (TeaCache not selected).""" pipeline = self._make_pipeline_mock("FLUX.1-dev") pipeline.pipeline_config = pipeline.pipeline_config.model_copy(update={"cache": None}) - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), {"dev": [1.0]}) + BasePipeline._setup_cache_acceleration(pipeline) assert pipeline.cache_accelerator is None + + def test_user_configured_coefficients_skip_variant_matching(self): + """Explicit TeaCacheConfig.coefficients skips dict lookup (no ValueError).""" + pipeline = self._make_pipeline_mock("FLUX.1-unknown-variant") + pipeline.model_config.cache = TeaCacheConfig( + teacache_thresh=0.3, + coefficients=[0.25, 0.5, 0.75], + ) + builtin = { + "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, + } + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._apply_teacache_coefficients(pipeline, builtin) + BasePipeline._setup_cache_acceleration(pipeline) + + assert pipeline.model_config.teacache.coefficients == [0.25, 0.5, 0.75] + + def test_user_configured_coefficients_take_precedence_over_builtin_table(self): + """User coefficients are not overwritten when a built-in variant would also match.""" + pipeline = self._make_pipeline_mock("FLUX.1-dev") + pipeline.model_config.cache = TeaCacheConfig( + teacache_thresh=0.3, + coefficients=[9.0, 8.0, 7.0], + ) + builtin = { + "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, + } + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._apply_teacache_coefficients(pipeline, builtin) + BasePipeline._setup_cache_acceleration(pipeline) + + assert pipeline.model_config.teacache.coefficients == [9.0, 8.0, 7.0] + + def test_apply_teacache_coefficients_only(self): + """_apply_teacache_coefficients updates config without enabling backend.""" + pipeline = self._make_pipeline_mock("FLUX.1-unknown-variant") + pipeline.model_config.cache = TeaCacheConfig( + coefficients=[0.1, 0.2], + ) + BasePipeline._apply_teacache_coefficients(pipeline, {"dev": {"standard": [99.0]}}) + assert pipeline.model_config.teacache.coefficients == [0.1, 0.2] + + def test_nested_table_missing_requested_mode_warns_then_raises_if_no_fallback(self): + """Variant matches path but nested dict lacks 'standard' / 'ret_steps' entry for mode.""" + pipeline = self._make_pipeline_mock("FLUX.1-dev") + # Only ret_steps; standard mode requested (use_ret_steps=False) + coefficients = {"dev": {"ret_steps": [9.0, 8.0]}} + # tensorrt_llm.logger does not propagate to the root logger, so pytest caplog + # does not see these records; assert via the pipeline module's logger.warning. + with patch("tensorrt_llm._torch.visual_gen.pipeline.logger.warning") as mock_warning: + with pytest.raises(ValueError, match="No coefficients found"): + BasePipeline._apply_teacache_coefficients(pipeline, coefficients) + mock_warning.assert_called_once() + joined = " ".join(str(a) for a in mock_warning.call_args[0]) + assert "matched variant" in joined + + +class TestTeaCacheConfigValidation: + """TeaCacheConfig validation (no pipeline).""" + + def test_empty_coefficients_rejected(self): + with pytest.raises(ValueError, match="cannot be empty"): + TeaCacheConfig(coefficients=[]) + + def test_empty_coefficients_2_rejected(self): + with pytest.raises(ValueError, match="coefficients_2"): + TeaCacheConfig(coefficients_2=[]) + + +class TestTeaCacheAcceleratorRefresh: + """TeaCacheAccelerator.refresh delegates to all registered backends.""" + + def _make_accelerator(self, backends): + from tensorrt_llm._torch.visual_gen.cache.teacache_accelerator import TeaCacheAccelerator + + acc = TeaCacheAccelerator.__new__(TeaCacheAccelerator) + acc._backends = backends + return acc + + def test_single_backend_refresh(self): + backend = MagicMock() + acc = self._make_accelerator([(MagicMock(), backend)]) + acc.refresh(50) + backend.refresh.assert_called_once_with(50) + + def test_refreshes_two_distinct_backends(self): + b1, b2 = MagicMock(), MagicMock() + acc = self._make_accelerator([(MagicMock(), b1), (MagicMock(), b2)]) + acc.refresh(30) + b1.refresh.assert_called_once_with(30) + b2.refresh.assert_called_once_with(30) + + def test_no_backends_is_noop(self): + acc = self._make_accelerator([]) + acc.refresh(10) # should not raise + + +class TestFlux2TeacacheTable: + """FLUX.2 built-in coefficient table (dev variant).""" + + def test_flux2_dev_variant_resolves_from_checkpoint_path(self): + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux2 import ( + FLUX2_TEACACHE_COEFFICIENTS, + ) + + pipeline = _PipelineTeacacheTestDouble( + DiffusionModelConfig( + pretrained_config=SimpleNamespace( + _name_or_path="/weights/black-forest-labs/FLUX.2-dev/snapshot" + ), + cache=TeaCacheConfig(teacache_thresh=0.2), + skip_create_weights_in_init=True, + ) + ) + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._apply_teacache_coefficients(pipeline, FLUX2_TEACACHE_COEFFICIENTS) + BasePipeline._setup_cache_acceleration(pipeline) + assert pipeline.model_config.teacache.coefficients is not None + assert len(pipeline.model_config.teacache.coefficients) >= 2 + + +class TestExplicitTeaCacheCoefficientsRequired: + """LTX-2 and Wan 2.2 have no built-in TeaCache tables; teacache requires user coefficients.""" + + def test_ltx2_raises_when_teacache_enabled_without_coefficients(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + pipe = object.__new__(LTX2Pipeline) + pipe.transformer = MagicMock() + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace(), + cache=TeaCacheConfig(coefficients=None), + skip_create_weights_in_init=True, + ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with pytest.raises(ValueError, match="LTX-2 requires explicit teacache.coefficients"): + LTX2Pipeline.post_load_weights(pipe) + + def test_ltx2_succeeds_with_explicit_coefficients(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + pipe = object.__new__(LTX2Pipeline) + pipe.transformer = MagicMock() + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace(), + cache=TeaCacheConfig( + coefficients=[1.0, 2.0, 3.0], + ), + skip_create_weights_in_init=True, + ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2.register_extractor" + ): + with patch.object(TeaCacheBackend, "enable"): + LTX2Pipeline.post_load_weights(pipe) + assert pipe.cache_accelerator is not None + + def test_wan22_raises_when_teacache_enabled_without_both_coefficient_lists(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + pipe = object.__new__(WanPipeline) + pipe.transformer = MagicMock() + pipe.transformer_2 = MagicMock() + pipe.is_wan22 = True + pipe.is_wan22_14b = True + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace( + _name_or_path="/models/wan14b/snapshot", boundary_ratio=0.2 + ), + cache=TeaCacheConfig( + coefficients=None, + coefficients_2=None, + ), + skip_create_weights_in_init=True, + ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.register_extractor_from_config" + ): + with pytest.raises(ValueError, match="Wan 2.2 TeaCache requires explicit"): + WanPipeline.post_load_weights(pipe) + + def test_wan22_i2v_raises_when_teacache_enabled_without_both_coefficient_lists(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import ( + WanImageToVideoPipeline, + ) + + pipe = object.__new__(WanImageToVideoPipeline) + pipe.transformer = MagicMock() + pipe.transformer_2 = MagicMock() + pipe.is_wan22 = True + pipe.is_wan22_14b = True + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace( + _name_or_path="/models/wan720p/snapshot", boundary_ratio=0.2 + ), + cache=TeaCacheConfig( + coefficients=None, + coefficients_2=None, + ), + skip_create_weights_in_init=True, + ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v.register_extractor_from_config" + ): + with pytest.raises(ValueError, match="Wan 2.2 TeaCache requires explicit"): + WanImageToVideoPipeline.post_load_weights(pipe) + + def test_wan22_t2v_installs_two_teacache_backends_when_coefficients_provided(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + pipe = object.__new__(WanPipeline) + pipe.transformer = MagicMock() + pipe.transformer_2 = MagicMock() + pipe.is_wan22 = True + pipe.is_wan22_14b = True + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace(_name_or_path="/wan/snapshot", boundary_ratio=0.2), + cache=TeaCacheConfig( + coefficients=[1.0, 2.0], + coefficients_2=[3.0, 4.0], + ), + skip_create_weights_in_init=True, + ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) + mock_enable = MagicMock() + backend_a = MagicMock() + backend_a.enable = mock_enable + backend_b = MagicMock() + backend_b.enable = mock_enable + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.register_extractor_from_config" + ): + with patch( + "tensorrt_llm._torch.visual_gen.cache.teacache_accelerator.TeaCacheBackend" + ) as TB: + TB.side_effect = [backend_a, backend_b] + WanPipeline.post_load_weights(pipe) + assert TB.call_count == 2 + assert mock_enable.call_count == 2 + assert pipe.cache_accelerator is not None + assert len(pipe.cache_accelerator.backends) == 2 + + def test_wan22_t2v_transformer_gets_coefficients_and_transformer_2_gets_coefficients_2(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + pipe = object.__new__(WanPipeline) + pipe.transformer = MagicMock() + pipe.transformer_2 = MagicMock() + pipe.is_wan22 = True + pipe.is_wan22_14b = True + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace(_name_or_path="/wan/snapshot", boundary_ratio=0.2), + cache=TeaCacheConfig( + coefficients=[1.0, 2.0], + coefficients_2=[3.0, 4.0], + ), + skip_create_weights_in_init=True, + ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.register_extractor_from_config" + ): + with patch( + "tensorrt_llm._torch.visual_gen.cache.teacache_accelerator.TeaCacheBackend" + ) as TB: + TB.return_value = MagicMock() + WanPipeline.post_load_weights(pipe) + + assert TB.call_count == 2 + cfg_high = TB.call_args_list[0][0][0] + cfg_low = TB.call_args_list[1][0][0] + assert cfg_high.coefficients == [1.0, 2.0] + assert cfg_low.coefficients == [3.0, 4.0] + + def test_wan22_i2v_transformer_gets_coefficients_and_transformer_2_gets_coefficients_2(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import ( + WanImageToVideoPipeline, + ) + + pipe = object.__new__(WanImageToVideoPipeline) + pipe.transformer = MagicMock() + pipe.transformer_2 = MagicMock() + pipe.is_wan22 = True + pipe.is_wan22_14b = True + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace(_name_or_path="/wan/snapshot", boundary_ratio=0.2), + cache=TeaCacheConfig( + coefficients=[5.0, 6.0], + coefficients_2=[7.0, 8.0], + ), + skip_create_weights_in_init=True, + ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v.register_extractor_from_config" + ): + with patch( + "tensorrt_llm._torch.visual_gen.cache.teacache_accelerator.TeaCacheBackend" + ) as TB: + TB.return_value = MagicMock() + WanImageToVideoPipeline.post_load_weights(pipe) + + assert TB.call_count == 2 + cfg_high = TB.call_args_list[0][0][0] + cfg_low = TB.call_args_list[1][0][0] + assert cfg_high.coefficients == [5.0, 6.0] + assert cfg_low.coefficients == [7.0, 8.0] diff --git a/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py index 75cc07ade90b..798755df8c86 100644 --- a/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py @@ -171,7 +171,7 @@ def _assert_i2v_teacache( seed=INFER_SEED, ) - stats = pipeline.transformer_cache_backend.get_stats() + stats = pipeline.cache_accelerator.get_stats() print(f"\n ===== TeaCache — Wan 2.1 {model} single-stage {height}x{width} =====") print( @@ -252,5 +252,5 @@ def test_wan22_raises_if_teacache_enabled(self): model=WAN22_I2V_A14B_PATH, cache_config=TeaCacheConfig(), ) - with pytest.raises(ValueError, match="TeaCache is not supported for Wan 2\\.2"): + with pytest.raises(ValueError, match=r"Wan 2\.2 TeaCache requires explicit"): PipelineLoader(args).load(skip_warmup=True) diff --git a/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py index 8d3b5f14f178..76e19f8232e5 100644 --- a/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py @@ -155,7 +155,7 @@ def _assert_single_stage_teacache( seed=INFER_SEED, ) - stats = pipeline.transformer_cache_backend.get_stats() + stats = pipeline.cache_accelerator.get_stats() print(f"\n ===== TeaCache — Wan 2.1 {model} single-stage {height}x{width} =====") print( @@ -234,5 +234,5 @@ def test_wan22_raises_if_teacache_enabled(self): model=WAN22_A14B_PATH, cache_config=TeaCacheConfig(), ) - with pytest.raises(ValueError, match=r"TeaCache is not supported for Wan 2\.2"): + with pytest.raises(ValueError, match=r"Wan 2\.2 TeaCache requires explicit"): PipelineLoader(args).load(skip_warmup=True) diff --git a/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py new file mode 100644 index 000000000000..aa16c708b6a7 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Confirm that user-configured TeaCache coefficients affect the cache rate on Wan 2.1 1.3B. + +Loads model weights and runs two forward passes back-to-back, each with a +different user-supplied TeaCacheConfig.coefficients list, then prints and compares +the cached step counts to confirm the user override is respected. + +Run: + pytest tests/unittest/_torch/visual_gen/test_wan13b_teacache_coefficients.py -v -s + +Override checkpoint: + DIFFUSION_MODEL_PATH_WAN21_1_3B=/path/to/weights \\ + pytest tests/unittest/_torch/visual_gen/test_wan13b_teacache_coefficients.py -v -s +""" + +import gc +import os + +os.environ["TLLM_DISABLE_MPI"] = "1" + +from pathlib import Path + +import pytest +import torch + +from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader +from tensorrt_llm.visual_gen.args import TeaCacheConfig, VisualGenArgs + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +@pytest.fixture(autouse=True) +def _cleanup_gpu(): + gc.collect() + torch.cuda.empty_cache() + yield + gc.collect() + torch.cuda.empty_cache() + + +def _llm_models_root() -> Path: + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + else: + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return root + + +WAN21_1_3B_PATH = os.environ.get( + "DIFFUSION_MODEL_PATH_WAN21_1_3B", + str(_llm_models_root() / "Wan2.1-T2V-1.3B-Diffusers"), +) + +PROMPT = "a cat sitting on a windowsill" +HEIGHT, WIDTH = 480, 832 +NUM_FRAMES = 33 # (33-1)/4+1 = 9 latent frames; smallest realistic shape +NUM_STEPS = 50 # enough steps for coefficients to produce meaningful cache hits +SEED = 42 + +# Two user-supplied coefficient lists passed explicitly via TeaCacheConfig.coefficients, +# bypassing the built-in variant-lookup table entirely, testing the user-override code path. + +COEFFICIENTS_CALIBRATED = [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01, +] + +COEFFICIENTS_IDENTITY_LINEAR = [ + 1.0, + 0.0, +] + + +def _run_forward(coefficients: list, thresh: float, label: str) -> dict: + """Load the pipeline with the given user-supplied coefficients, run one forward pass.""" + if not os.path.exists(WAN21_1_3B_PATH): + pytest.skip(f"Checkpoint not found: {WAN21_1_3B_PATH}") + + args = VisualGenArgs( + model=WAN21_1_3B_PATH, + cache_config=TeaCacheConfig( + teacache_thresh=thresh, + coefficients=coefficients, + ), + ) + pipeline = PipelineLoader(args).load(skip_warmup=True) + try: + with torch.no_grad(): + pipeline.forward( + prompt=PROMPT, + negative_prompt="", + height=HEIGHT, + width=WIDTH, + num_frames=NUM_FRAMES, + num_inference_steps=NUM_STEPS, + seed=SEED, + ) + stats = pipeline.cache_accelerator.get_stats() + finally: + del pipeline + gc.collect() + torch.cuda.empty_cache() + + print( + f" {label:<30s} cached={stats['cached_steps']:>3}/{stats['total_steps']} " + f"({stats['hit_rate']:.1%} cache rate)" + ) + return stats + + +@pytest.mark.integration +@pytest.mark.wan_t2v +@pytest.mark.teacache +class TestWan13BUserCoefficientsAffectCacheRate: + """User-supplied TeaCacheConfig.coefficients override the built-in table and change cache rate.""" + + def test_different_user_coefficients_produce_different_cache_rates(self): + print(f"\n {'coefficient set':<30s} {'cached / total':>15} hit rate") + print(f" {'-' * 65}") + + stats_a = _run_forward( + COEFFICIENTS_CALIBRATED, + thresh=0.2, + label="calibrated 1.3B standard", + ) + stats_b = _run_forward( + COEFFICIENTS_IDENTITY_LINEAR, + thresh=0.2, + label="identity linear [1, 0]", + ) + + print(f" {'-' * 65}") + print( + f" difference in cached steps: " + f"{abs(stats_a['cached_steps'] - stats_b['cached_steps'])}" + ) + + assert stats_a["cached_steps"] != stats_b["cached_steps"], ( + f"Both coefficient sets produced {stats_a['cached_steps']} cached steps — " + f"expected user-configured coefficients to produce different cache rates." + ) diff --git a/tests/unittest/_torch/visual_gen/test_wan22_i2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan22_i2v_teacache.py new file mode 100644 index 000000000000..7cc439162d4a --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan22_i2v_teacache.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Wan 2.2 I2V TeaCache happy path. + +Wan 2.2 uses a dual-transformer architecture (high-noise + low-noise stages). +TeaCache requires explicit coefficients for both transformers via +TeaCacheConfig.coefficients (high-noise) and TeaCacheConfig.coefficients_2 (low-noise). + +Verifies: + - Both transformer backends are initialized + - Both backends produce cache hits after a forward pass + - Stats are returned for each transformer separately + +Run: + pytest tests/unittest/_torch/visual_gen/test_wan22_i2v_teacache.py -v -s + +Override checkpoint path: + DIFFUSION_MODEL_PATH_WAN22_I2V=/path/to/wan22 \\ + pytest tests/unittest/_torch/visual_gen/test_wan22_i2v_teacache.py -v -s +""" + +import os + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import gc +from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image + +from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader +from tensorrt_llm.visual_gen.args import TeaCacheConfig, VisualGenArgs + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +@pytest.fixture(autouse=True) +def _cleanup_gpu(): + gc.collect() + torch.cuda.empty_cache() + yield + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================ +# Path helpers +# ============================================================================ + + +def _llm_models_root() -> Path: + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + else: + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return root + + +def _checkpoint(env_var: str, default_name: str) -> str: + return os.environ.get(env_var) or str(_llm_models_root() / default_name) + + +WAN22_I2V_A14B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN22_I2V", "Wan2.2-I2V-A14B-Diffusers") + +INFER_NUM_FRAMES = 33 # (33-1)/4+1 = 9 latent frames; smallest realistic shape +INFER_NUM_STEPS = 20 # Wan 2.2 has no reference hit rate; just enough to exercise both backends +INFER_SEED = 42 + +# Placeholder coefficients for Wan 2.2 dual-transformer TeaCache. +# Wan 2.2 has no built-in coefficient table; users must supply both sets. +# These are used to exercise the dual-backend path and verify cache hits fire. +WAN22_I2V_HIGH_NOISE_COEFFICIENTS = [ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02301147, +] +WAN22_I2V_LOW_NOISE_COEFFICIENTS = [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01, +] + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _make_test_image(height: int, width: int) -> Image.Image: + img = np.zeros((height, width, 3), dtype=np.uint8) + img[:, :, 0] = np.linspace(0, 255, height, dtype=np.uint8)[:, None] + return Image.fromarray(img, mode="RGB") + + +# ============================================================================ +# Fixture +# ============================================================================ + + +@pytest.fixture +def wan22_i2v_pipeline(): + if not os.path.exists(WAN22_I2V_A14B_PATH): + pytest.skip(f"Checkpoint not found: {WAN22_I2V_A14B_PATH}") + args = VisualGenArgs( + model=WAN22_I2V_A14B_PATH, + cache_config=TeaCacheConfig( + teacache_thresh=0.15, + coefficients=WAN22_I2V_HIGH_NOISE_COEFFICIENTS, + coefficients_2=WAN22_I2V_LOW_NOISE_COEFFICIENTS, + ), + ) + pipeline = PipelineLoader(args).load(skip_warmup=True) + yield pipeline + del pipeline + torch.cuda.empty_cache() + + +# ============================================================================ +# Assertion helper +# ============================================================================ + + +def _assert_dual_stage_teacache(pipeline, height: int, width: int) -> None: + test_image = _make_test_image(height, width) + + with torch.no_grad(): + pipeline.forward( + image=test_image, + prompt="a cat sitting on a windowsill", + negative_prompt="", + height=height, + width=width, + num_frames=INFER_NUM_FRAMES, + num_inference_steps=INFER_NUM_STEPS, + seed=INFER_SEED, + ) + + stats = pipeline.cache_accelerator.get_stats() + + print(f"\n ===== TeaCache — Wan 2.2 I2V-A14B dual-stage {height}x{width} =====") + for key, s in stats.items(): + print( + f" {key}: {s['cached_steps']}/{s['total_steps']} cached ({s['hit_rate']:.1%} hit rate)" + ) + print(" ================================================================") + + assert len(stats) == 2, f"Expected stats for 2 transformers, got: {list(stats.keys())}" + total_steps_sum = sum(s["total_steps"] for s in stats.values()) + assert total_steps_sum == INFER_NUM_STEPS, ( + f"Sum of steps across both transformers {total_steps_sum} != {INFER_NUM_STEPS}" + ) + for key, s in stats.items(): + assert s["total_steps"] > 0, f"{key}: transformer ran 0 steps" + assert s["compute_steps"] + s["cached_steps"] == s["total_steps"] + + +# ============================================================================ +# Tests +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.wan_i2v +@pytest.mark.teacache +class TestWan22I2V_TeaCache: + """Wan2.2-I2V-A14B 480x832 dual-stage TeaCache.""" + + def test_wan22_i2v_teacache_forward_runs(self, wan22_i2v_pipeline): + _assert_dual_stage_teacache(wan22_i2v_pipeline, height=480, width=832) + + def test_wan22_i2v_teacache_two_backends_initialized(self, wan22_i2v_pipeline): + assert len(wan22_i2v_pipeline.cache_accelerator.backends) == 2 diff --git a/tests/unittest/_torch/visual_gen/test_wan22_t2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan22_t2v_teacache.py new file mode 100644 index 000000000000..f295fab62e88 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan22_t2v_teacache.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Wan 2.2 T2V TeaCache happy path. + +Wan 2.2 uses a dual-transformer architecture (high-noise + low-noise stages). +TeaCache requires explicit coefficients for both transformers via +TeaCacheConfig.coefficients (high-noise) and TeaCacheConfig.coefficients_2 (low-noise). + +Verifies: + - Both transformer backends are initialized + - Both backends produce cache hits after a forward pass + - Stats are returned for each transformer separately + +Run: + pytest tests/unittest/_torch/visual_gen/test_wan22_t2v_teacache.py -v -s + +Override checkpoint path: + DIFFUSION_MODEL_PATH_WAN22_T2V=/path/to/wan22 \\ + pytest tests/unittest/_torch/visual_gen/test_wan22_t2v_teacache.py -v -s +""" + +import os + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import gc +from pathlib import Path + +import pytest +import torch + +from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader +from tensorrt_llm.visual_gen.args import TeaCacheConfig, VisualGenArgs + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +@pytest.fixture(autouse=True) +def _cleanup_gpu(): + gc.collect() + torch.cuda.empty_cache() + yield + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================ +# Path helpers +# ============================================================================ + + +def _llm_models_root() -> Path: + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + else: + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return root + + +def _checkpoint(env_var: str, default_name: str) -> str: + return os.environ.get(env_var) or str(_llm_models_root() / default_name) + + +WAN22_A14B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN22_T2V", "Wan2.2-T2V-A14B-Diffusers") + +INFER_NUM_FRAMES = 33 # (33-1)/4+1 = 9 latent frames; smallest realistic shape +INFER_NUM_STEPS = 20 # Wan 2.2 has no reference hit rate; just enough to exercise both backends +INFER_SEED = 42 + +# Placeholder coefficients for Wan 2.2 dual-transformer TeaCache. +# Wan 2.2 has no built-in coefficient table; users must supply both sets. +# These are used to exercise the dual-backend path and verify cache hits fire. +WAN22_T2V_HIGH_NOISE_COEFFICIENTS = [ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02301147, +] +WAN22_T2V_LOW_NOISE_COEFFICIENTS = [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01, +] + + +# ============================================================================ +# Fixture +# ============================================================================ + + +@pytest.fixture +def wan22_t2v_pipeline(): + if not os.path.exists(WAN22_A14B_PATH): + pytest.skip(f"Checkpoint not found: {WAN22_A14B_PATH}") + args = VisualGenArgs( + model=WAN22_A14B_PATH, + cache_config=TeaCacheConfig( + teacache_thresh=0.15, + coefficients=WAN22_T2V_HIGH_NOISE_COEFFICIENTS, + coefficients_2=WAN22_T2V_LOW_NOISE_COEFFICIENTS, + ), + ) + pipeline = PipelineLoader(args).load(skip_warmup=True) + yield pipeline + del pipeline + torch.cuda.empty_cache() + + +# ============================================================================ +# Assertion helper +# ============================================================================ + + +def _assert_dual_stage_teacache(pipeline, height: int, width: int) -> None: + with torch.no_grad(): + pipeline.forward( + prompt="a cat sitting on a windowsill", + negative_prompt="", + height=height, + width=width, + num_frames=INFER_NUM_FRAMES, + num_inference_steps=INFER_NUM_STEPS, + seed=INFER_SEED, + ) + + stats = pipeline.cache_accelerator.get_stats() + + print(f"\n ===== TeaCache — Wan 2.2 T2V-A14B dual-stage {height}x{width} =====") + for key, s in stats.items(): + print( + f" {key}: {s['cached_steps']}/{s['total_steps']} cached ({s['hit_rate']:.1%} hit rate)" + ) + print(" ================================================================") + + assert len(stats) == 2, f"Expected stats for 2 transformers, got: {list(stats.keys())}" + total_steps_sum = sum(s["total_steps"] for s in stats.values()) + assert total_steps_sum == INFER_NUM_STEPS, ( + f"Sum of steps across both transformers {total_steps_sum} != {INFER_NUM_STEPS}" + ) + for key, s in stats.items(): + assert s["total_steps"] > 0, f"{key}: transformer ran 0 steps" + assert s["compute_steps"] + s["cached_steps"] == s["total_steps"] + + +# ============================================================================ +# Tests +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.wan_t2v +@pytest.mark.teacache +class TestWan22T2V_TeaCache: + """Wan2.2-T2V-A14B 480x832 dual-stage TeaCache.""" + + def test_wan22_t2v_teacache_forward_runs(self, wan22_t2v_pipeline): + _assert_dual_stage_teacache(wan22_t2v_pipeline, height=480, width=832) + + def test_wan22_t2v_teacache_two_backends_initialized(self, wan22_t2v_pipeline): + assert len(wan22_t2v_pipeline.cache_accelerator.backends) == 2