From 8c8f425476ea983a12af2a7f3d8a526235ff972e Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Fri, 17 Apr 2026 17:01:01 -0700 Subject: [PATCH 01/11] user-configured TeaCache coefficients Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- tensorrt_llm/_torch/visual_gen/pipeline.py | 116 ++++++--- .../_torch/visual_gen/test_teacache.py | 236 +++++++++++++++++- 2 files changed, 310 insertions(+), 42 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/pipeline.py b/tensorrt_llm/_torch/visual_gen/pipeline.py index 7a0c629bf5c9..ee4e10a18dbb 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline.py @@ -101,6 +101,16 @@ def _parse_profile_range(): from .config import DiffusionPipelineConfig +def _teacache_coefficients_are_explicit_user_override(teacache_cfg: Any) -> bool: + """Return True if teacache.coefficients should skip built-in variant table matching. + + Only None means auto (use checkpoint path table). Any non-empty list, including + the identity polynomial [1.0, 0.0], is treated as an explicit user override. + """ + + return teacache_cfg.coefficients is not None + + class BasePipeline(nn.Module): """ Base class for diffusion pipelines. @@ -396,41 +406,81 @@ def post_load_weights(self) -> None: if self.transformer is not None and hasattr(self.transformer, "post_load_weights"): self.transformer.post_load_weights() - def _apply_teacache_coefficients(self, coefficients: Optional[Dict]) -> None: - """Pick TeaCache coefficients from checkpoint path; updates pipeline config in place.""" + def _apply_teacache_coefficients(self, coefficients: Optional[Dict] = None) -> None: + """Resolve TeaCache polynomial coefficients into pipeline_config.cache (TeaCacheConfig). + + Precedence: + + 1. User-specified TeaCacheConfig.coefficients — any non-None list skips built-in + variant matching. + + 2. 
Pipeline table — if step 1 does not apply and coefficients is a non-empty dict + (model-specific tables from the pipeline subclass), match + pretrained_config._name_or_path against keys and set coefficients (and optional + default_thresh). + + 3. If coefficients are still unresolved after step 2, _setup_cache_acceleration + raises: TeaCache must not run without resolved coefficients. + + Args: + coefficients: Optional mapping from variant key to coefficient list or nested + dict (ret_steps / standard), from the pipeline subclass. + """ + teacache_cfg = self.pipeline_config.teacache + if teacache_cfg is None: + return + if _teacache_coefficients_are_explicit_user_override(teacache_cfg): + logger.info( + "TeaCache: Using user-configured coefficients " + "(skipping built-in checkpoint variant matching)" + ) + return + + teacache_explicit = teacache_cfg.model_dump(exclude_unset=True) + if not coefficients: return - teacache_cfg = self.pipeline_config.teacache - checkpoint_path = getattr( - self.pipeline_config.primary_pretrained_config, "_name_or_path", "" - ) - matched = False + + checkpoint_path = getattr(self.pipeline_config.primary_pretrained_config, "_name_or_path", "") or "" + for model_size, coeff_data in coefficients.items(): - if model_size.lower() in checkpoint_path.lower(): - matched = True - if isinstance(coeff_data, dict): - mode = "ret_steps" if teacache_cfg.use_ret_steps else "standard" - if mode in coeff_data: - teacache_cfg.coefficients = coeff_data[mode] - logger.info(f"TeaCache: Using {model_size} coefficients ({mode} mode)") - default_thresh = coeff_data.get("default_thresh") - if ( - default_thresh is not None - and "teacache_thresh" not in teacache_cfg.model_fields_set - ): - teacache_cfg.teacache_thresh = default_thresh - logger.info( - f"TeaCache: Using {model_size} default threshold {default_thresh}" - ) - else: - teacache_cfg.coefficients = coeff_data - logger.info(f"TeaCache: Using {model_size} coefficients") + # Match model size in path 
(case-insensitive, e.g., "1.3B", "14B", "dev") + path_l = checkpoint_path.lower() + key_l = model_size.lower() + if key_l not in path_l: + continue + + if isinstance(coeff_data, dict): + # Select coefficient set based on warmup mode + mode = "ret_steps" if teacache_cfg.use_ret_steps else "standard" + if mode not in coeff_data: + logger.warning( + "TeaCache: matched variant %r but table has no %r entry " + "(available keys: %s). Trying other variants.", + model_size, + mode, + list(coeff_data.keys()), + ) + continue + teacache_cfg.coefficients = coeff_data[mode] + logger.info(f"TeaCache: Using {model_size} coefficients ({mode} mode)") + # Apply model-specific default threshold if user didn't explicitly set one + default_thresh = coeff_data.get("default_thresh") + if default_thresh is not None and "teacache_thresh" not in teacache_explicit: + teacache_cfg.teacache_thresh = default_thresh + logger.info(f"TeaCache: Using {model_size} default threshold {default_thresh}") + break + else: + # Single coefficient list (no mode distinction) + teacache_cfg.coefficients = coeff_data + logger.info(f"TeaCache: Using {model_size} coefficients") break - if not matched: + else: raise ValueError( f"TeaCache: No coefficients found for checkpoint '{checkpoint_path}'. " f"Available variants: {list(coefficients.keys())}. " - f"TeaCache is not supported for this model variant." + f"Set teacache.coefficients explicitly in VisualGenArgs to use TeaCache anyway, " + f"or use a checkpoint path that contains one of the variant keys." ) def _setup_cache_acceleration( @@ -456,7 +506,15 @@ def _setup_cache_acceleration( if not use_teacache: return - BasePipeline._apply_teacache_coefficients(self, coefficients) + self._apply_teacache_coefficients(coefficients) + + teacache_cfg = cfg.teacache + if teacache_cfg is None or teacache_cfg.coefficients is None: + raise ValueError( + "TeaCache is enabled but no polynomial coefficients were resolved. 
" + "Set teacache.coefficients in VisualGenArgs, or use a pipeline and " + "checkpoint whose path matches a built-in coefficient table." + ) if model is None: return diff --git a/tests/unittest/_torch/visual_gen/test_teacache.py b/tests/unittest/_torch/visual_gen/test_teacache.py index a0080a6d2008..9cd0408a86d2 100644 --- a/tests/unittest/_torch/visual_gen/test_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_teacache.py @@ -15,34 +15,159 @@ """Unit tests for TeaCache (CPU-only, no model weights needed).""" from types import SimpleNamespace +from typing import Optional from unittest.mock import MagicMock, patch import pytest from tensorrt_llm._torch.visual_gen.cache.teacache import TeaCacheBackend -from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig, DiffusionPipelineConfig +from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline from tensorrt_llm.visual_gen.args import TeaCacheConfig +class _PipelineConfigShim: + """Minimal pipeline_config shim that delegates all reads back to a DiffusionModelConfig. + + Shares the same TeaCacheConfig object so mutations from _apply_teacache_coefficients + are visible via both pipeline_config.teacache and model_config.teacache. 
+ """ + + def __init__(self, model_config: DiffusionModelConfig): + self._mc = model_config + + @property + def cache_backend(self): + return self._mc.cache_backend + + @property + def teacache(self): + return self._mc.teacache + + @property + def cache_dit(self): + return self._mc.cache_dit + + @property + def primary_pretrained_config(self): + return self._mc.pretrained_config + + def model_copy(self, update=None): + if update: + for k, v in update.items(): + setattr(self._mc, k, v) + return self + + +class _PipelineTeacacheTestDouble: + """Minimal object for BasePipeline._setup_cache_acceleration / _apply_teacache_coefficients tests.""" + + def __init__(self, model_config: DiffusionModelConfig): + self.model_config = model_config + self.cache_accelerator = None + self.cache_backend = None + self.pipeline_config = _PipelineConfigShim(model_config) + + def _apply_teacache_coefficients(self, coefficients: Optional[dict] = None): + return BasePipeline._apply_teacache_coefficients(self, coefficients) + + class TestSetupTeacache: """Tests for _setup_cache_acceleration TeaCache coefficient matching and fail-early behavior.""" def _make_pipeline_mock(self, checkpoint_name, use_ret_steps=False): - pipeline = MagicMock() - pipeline.cache_accelerator = None - model_config = DiffusionModelConfig( - pretrained_config=SimpleNamespace(_name_or_path=f"/path/to/{checkpoint_name}/snapshot"), + return _PipelineTeacacheTestDouble( + DiffusionModelConfig( + pretrained_config=SimpleNamespace( + _name_or_path=f"/path/to/{checkpoint_name}/snapshot" + ), + cache=TeaCacheConfig( + teacache_thresh=0.3, + use_ret_steps=use_ret_steps, + coefficients=None, + ), + skip_create_weights_in_init=True, + ) ) - pipeline.pipeline_config = DiffusionPipelineConfig( - model_configs={"transformer": model_config}, - cache=TeaCacheConfig( - teacache_thresh=0.3, - use_ret_steps=use_ret_steps, - ), - skip_create_weights_in_init=True, + + def _make_pipeline_teacache_enable_only(self, checkpoint_name, 
use_ret_steps=False): + """Only use_ret_steps (+ defaults) so teacache_thresh may be omitted from explicit set.""" + return _PipelineTeacacheTestDouble( + DiffusionModelConfig( + pretrained_config=SimpleNamespace( + _name_or_path=f"/path/to/{checkpoint_name}/snapshot" + ), + cache=TeaCacheConfig( + use_ret_steps=use_ret_steps, + coefficients=None, + ), + skip_create_weights_in_init=True, + ) + ) + + def test_setup_cache_acceleration_raises_when_no_table_and_no_user_coefficients(self): + """Fails if TeaCache is on but the pipeline passes no coefficient table.""" + pipeline = self._make_pipeline_mock("FLUX.1-dev") + with pytest.raises(ValueError, match="no polynomial coefficients were resolved"): + BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), None) + + def test_setup_cache_acceleration_raises_when_empty_table(self): + """Same as no table: nothing to resolve from.""" + pipeline = self._make_pipeline_mock("FLUX.1-dev") + with pytest.raises(ValueError, match="no polynomial coefficients were resolved"): + BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), {}) + + def test_matching_variant_selects_ret_steps_mode(self): + """Nested table: use_ret_steps=True selects ret_steps coefficients.""" + pipeline = self._make_pipeline_mock("FLUX.1-dev", use_ret_steps=True) + coefficients = { + "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, + } + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), coefficients) + + assert pipeline.model_config.teacache.coefficients == [4.0, 5.0] + + def test_flat_list_table_entry(self): + """Table value may be a plain list (no standard/ret_steps nesting).""" + pipeline = self._make_pipeline_mock("FLUX.1-dev") + coefficients = {"dev": [11.0, 22.0, 33.0]} + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), coefficients) + + assert pipeline.model_config.teacache.coefficients == [11.0, 22.0, 
33.0] + + def test_default_thresh_from_table_when_user_did_not_set_teacache_thresh(self): + """default_thresh applies when teacache_thresh was not explicitly set (exclude_unset).""" + pipeline = self._make_pipeline_teacache_enable_only("FLUX.1-dev") + builtin = { + "dev": { + "standard": [1.0, 2.0], + "ret_steps": [3.0, 4.0], + "default_thresh": 0.42, + }, + } + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), builtin) + + assert pipeline.model_config.teacache.coefficients == [1.0, 2.0] + assert pipeline.model_config.teacache.teacache_thresh == 0.42 + + def test_explicit_identity_coefficients_still_skip_table(self): + """[1.0, 0.0] is a user override: no variant lookup, no ValueError on unknown path.""" + pipeline = self._make_pipeline_mock("FLUX.1-unknown-variant") + pipeline.model_config.cache = TeaCacheConfig( + teacache_thresh=0.3, + coefficients=[1.0, 0.0], ) - return pipeline + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._setup_cache_acceleration( + pipeline, + MagicMock(), + {"dev": {"standard": [99.0, 99.0]}}, + ) + + assert pipeline.model_config.teacache.coefficients == [1.0, 0.0] def test_matching_variant_selects_coefficients(self): """Picks coefficients whose key appears in checkpoint path.""" @@ -73,3 +198,88 @@ def test_disabled_teacache_is_noop(self): BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), {"dev": [1.0]}) assert pipeline.cache_accelerator is None + + def test_user_configured_coefficients_skip_variant_matching(self): + """Explicit TeaCacheConfig.coefficients skips dict lookup (no ValueError).""" + pipeline = self._make_pipeline_mock("FLUX.1-unknown-variant") + pipeline.model_config.cache = TeaCacheConfig( + teacache_thresh=0.3, + coefficients=[0.25, 0.5, 0.75], + ) + builtin = { + "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, + } + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._setup_cache_acceleration(pipeline, 
MagicMock(), builtin) + + assert pipeline.model_config.teacache.coefficients == [0.25, 0.5, 0.75] + + def test_user_configured_coefficients_take_precedence_over_builtin_table(self): + """User coefficients are not overwritten when a built-in variant would also match.""" + pipeline = self._make_pipeline_mock("FLUX.1-dev") + pipeline.model_config.cache = TeaCacheConfig( + teacache_thresh=0.3, + coefficients=[9.0, 8.0, 7.0], + ) + builtin = { + "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, + } + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), builtin) + + assert pipeline.model_config.teacache.coefficients == [9.0, 8.0, 7.0] + + def test_apply_teacache_coefficients_only(self): + """_apply_teacache_coefficients updates config without enabling backend.""" + pipeline = self._make_pipeline_mock("FLUX.1-unknown-variant") + pipeline.model_config.cache = TeaCacheConfig( + coefficients=[0.1, 0.2], + ) + BasePipeline._apply_teacache_coefficients(pipeline, {"dev": {"standard": [99.0]}}) + assert pipeline.model_config.teacache.coefficients == [0.1, 0.2] + + def test_nested_table_missing_requested_mode_warns_then_raises_if_no_fallback(self): + """Variant matches path but nested dict lacks 'standard' / 'ret_steps' entry for mode.""" + pipeline = self._make_pipeline_mock("FLUX.1-dev") + coefficients = {"dev": {"ret_steps": [9.0, 8.0]}} + with patch( + "tensorrt_llm._torch.visual_gen.pipeline.logger.warning" + ) as mock_warning: + with pytest.raises(ValueError, match="No coefficients found"): + BasePipeline._apply_teacache_coefficients(pipeline, coefficients) + mock_warning.assert_called_once() + joined = " ".join(str(a) for a in mock_warning.call_args[0]) + assert "matched variant" in joined + + +class TestTeaCacheConfigValidation: + """TeaCacheConfig validation (no pipeline).""" + + def test_empty_coefficients_rejected(self): + with pytest.raises(ValueError, match="cannot be empty"): + 
TeaCacheConfig(coefficients=[]) + + +class TestFlux2TeacacheTable: + """FLUX.2 built-in coefficient table (dev variant).""" + + def test_flux2_dev_variant_resolves_from_checkpoint_path(self): + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux2 import ( + FLUX2_TEACACHE_COEFFICIENTS, + ) + + pipeline = _PipelineTeacacheTestDouble( + DiffusionModelConfig( + pretrained_config=SimpleNamespace( + _name_or_path="/weights/black-forest-labs/FLUX.2-dev/snapshot" + ), + cache=TeaCacheConfig(teacache_thresh=0.2), + skip_create_weights_in_init=True, + ) + ) + with patch.object(TeaCacheBackend, "enable"): + BasePipeline._setup_cache_acceleration( + pipeline, MagicMock(), FLUX2_TEACACHE_COEFFICIENTS + ) + assert pipeline.model_config.teacache.coefficients is not None + assert len(pipeline.model_config.teacache.coefficients) >= 2 From c70de29593ea901831910b19c5fd3eecd6a6fc1f Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Fri, 17 Apr 2026 17:01:03 -0700 Subject: [PATCH 02/11] teacache for previously disabled pipelines; update examples Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- docs/source/models/visual-generation.md | 14 +- examples/visual_gen/README.md | 247 ++++++++++++++++++ examples/visual_gen/serve/README.md | 2 + .../visual_gen/models/flux/pipeline_flux2.py | 3 +- .../visual_gen/models/ltx2/pipeline_ltx2.py | 21 +- .../visual_gen/models/wan/pipeline_wan.py | 39 ++- .../visual_gen/models/wan/pipeline_wan_i2v.py | 37 ++- tensorrt_llm/_torch/visual_gen/pipeline.py | 9 + tensorrt_llm/visual_gen/args.py | 11 + .../_torch/visual_gen/test_teacache.py | 175 ++++++++++++- 10 files changed, 525 insertions(+), 33 deletions(-) diff --git a/docs/source/models/visual-generation.md b/docs/source/models/visual-generation.md index 0fe823ced41b..c77e38cff2ae 100644 --- a/docs/source/models/visual-generation.md +++ b/docs/source/models/visual-generation.md @@ -47,14 +47,18 @@ Models are 
auto-detected from the checkpoint directory. Diffusers-format models | **FLUX.1** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | | **FLUX.2** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | | **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | -| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | -| **LTX-2** | Yes | Yes | No | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | -| **Qwen-Image** [^2] | Yes | Yes | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | +| **Wan 2.2** | Yes | Yes | Yes [^2] | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | +| **LTX-2** | Yes | Yes | Yes [^3] | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | +| **Qwen-Image** [^4] | Yes | Yes | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | | **Cosmos3** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes | [^1]: FLUX models use embedded guidance and do not have a separate negative prompt path, so CFG parallelism is not applicable. -[^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required. +[^2]: Wan 2.2 has two stage transformers; TeaCache requires explicit `teacache.coefficients` (high-noise) and `teacache.coefficients_2` (low-noise). There is no built-in coefficient table for Wan 2.2. + +[^3]: LTX-2 has no built-in TeaCache coefficient table in TRT-LLM; set `teacache.coefficients` explicitly when enabling TeaCache. + +[^4]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. 
FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required. ## Quick Start @@ -164,7 +168,7 @@ cache_config: teacache_thresh: 0.2 ``` -The `teacache_thresh` parameter controls the similarity threshold. Cache-DiT is also supported via `cache_backend: cache_dit` with its own set of knobs (see `CacheDiTConfig`). +The `teacache_thresh` parameter controls the similarity threshold. For Wan 2.2, set both `coefficients` and `coefficients_2` (YAML or CLI). For LTX-2, set `coefficients` when enabling TeaCache (no built-in table). Other models (e.g. FLUX.1, FLUX.2, Wan 2.1) can omit `coefficients` to use the built-in checkpoint table. Cache-DiT is also supported via `cache_backend: cache_dit` with its own set of knobs (see `CacheDiTConfig`). ### Multi-GPU Parallelism diff --git a/examples/visual_gen/README.md b/examples/visual_gen/README.md index c38656472b35..709c0839eb02 100644 --- a/examples/visual_gen/README.md +++ b/examples/visual_gen/README.md @@ -33,3 +33,250 @@ python models/flux2.py --visual_gen_args configs/flux2-dev-fp4-1gpu.yaml Install deps from the repo root: `pip install -r requirements-dev.txt`. Output: `.png` for image models; `.mp4` for video models when FFmpeg is installed (otherwise `.avi`). 
+ +## FLUX (Text-to-Image) + +### Basic Usage + +**FLUX.1:** + +```bash +python visual_gen_flux.py \ + --model_path black-forest-labs/FLUX.1-dev \ + --prompt "A cat sitting on a windowsill" \ + --height 1024 --width 1024 \ + --guidance_scale 3.5 \ + --output_path output.png +``` + +**With FP8 quantization:** + +```bash +python visual_gen_flux.py \ + --model_path black-forest-labs/FLUX.2-dev \ + --prompt "A cat sitting on a windowsill" \ + --linear_type trtllm-fp8-per-tensor \ + --output_path output_fp8.png +``` + +**Batch mode (multiple prompts from file):** + +```bash +python visual_gen_flux.py \ + --model_path black-forest-labs/FLUX.1-dev \ + --prompts_file prompts.txt \ + --output_dir results/ --seed 42 +``` + + +## WAN (Text-to-Video) + +### Basic Usage + +**Single GPU:** + +```bash +python visual_gen_wan_t2v.py \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ + --prompt "A cute cat playing piano" \ + --height 480 --width 832 --num_frames 33 \ + --output_path output.mp4 +``` + +**With TeaCache:** +```bash +python visual_gen_wan_t2v.py \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ + --prompt "A cute cat playing piano" \ + --height 480 --width 832 --num_frames 33 \ + --enable_teacache \ + --output_path output.mp4 +``` + +### Multi-GPU Parallelism + +WAN supports two parallelism modes that can be combined: +- **CFG Parallelism**: Split positive/negative prompts across GPUs +- **Ulysses Parallelism**: Split sequence across GPUs for longer sequences + + +**Ulysses Only (2 GPUs):** +```bash +python visual_gen_wan_t2v.py \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ + --prompt "A cute cat playing piano" \ + --height 480 --width 832 --num_frames 33 \ + --attention_backend TRTLLM \ + --cfg_size 1 --ulysses_size 2 \ + --output_path output.mp4 +``` +GPU Layout: GPU 0-1 share sequence (6 heads each) + +**CFG Only (2 GPUs):** +```bash +python visual_gen_wan_t2v.py \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ + --prompt "A cute cat playing piano" \ + 
--height 480 --width 832 --num_frames 33 \ + --attention_backend TRTLLM \ + --cfg_size 2 --ulysses_size 1 \ + --output_path output.mp4 +``` +GPU Layout: GPU 0 (positive) | GPU 1 (negative) + +**CFG + Ulysses (4 GPUs):** +```bash +python visual_gen_wan_t2v.py \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ + --prompt "A cute cat playing piano" \ + --height 480 --width 832 --num_frames 33 \ + --attention_backend TRTLLM \ + --cfg_size 2 --ulysses_size 2 \ + --output_path output.mp4 +``` +GPU Layout: GPU 0-1 (positive, Ulysses) | GPU 2-3 (negative, Ulysses) + +**Large-Scale (8 GPUs):** +```bash +python visual_gen_wan_t2v.py \ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ + --prompt "A cute cat playing piano" \ + --height 480 --width 832 --num_frames 33 \ + --attention_backend TRTLLM \ + --cfg_size 2 --ulysses_size 4 \ + --output_path output.mp4 +``` + + +## WAN (Image-to-Video) + +```bash +python visual_gen_wan_i2v.py \ + --model_path Wan-AI/Wan2.1-I2V-14B-480P-Diffusers \ + --image_path input_image.jpg \ + --prompt "She turns around and smiles" \ + --height 480 --width 832 --num_frames 81 \ + --output_path output_i2v.mp4 +``` + + +## LTX2 (Text/Image-to-Video with Audio) + +LTX2 generates video **with audio** from text prompts or input images. +It uses a Gemma3 text encoder (provided separately via `--text_encoder_path`) +and supports BF16, FP8, and FP4 precision checkpoints. + +Please refer to tensorrt_llm/_torch/visual_gen/models/ltx2/LTX_2_CHECKPOINT_FORMAT.md for model checkpoint info. 
+ +### Basic Usage + +**Text-to-Video (single GPU):** +```bash +python visual_gen_ltx2.py \ + --model_path ${MODEL_ROOT}/LTX-2-checkpoint/ \ + --text_encoder_path ${MODEL_ROOT}/gemma-3-12b-it \ + --prompt "A cute cat playing piano" \ + --height 720 --width 1280 --num_frames 121 \ + --steps 40 --guidance_scale 4.0 --seed 42 \ + --output_path output_t2v.mp4 +``` + +**Image-to-Video:** +```bash +python visual_gen_ltx2.py \ + --model_path ${MODEL_ROOT}/LTX-2-checkpoint/ \ + --text_encoder_path ${MODEL_ROOT}/gemma-3-12b-it \ + --prompt "A cute cat playing piano" \ + --image ${PROJECT_ROOT}/examples/visual_gen/cat_piano.png \ + --image_cond_strength 1.0 \ + --height 720 --width 1280 --num_frames 121 \ + --steps 40 --seed 42 \ + --output_path output_i2v.mp4 +``` + +### Precision Variants + +LTX2 ships checkpoints at three precision levels. Simply point `--model_path` at the +appropriate directory: + +```bash +# FP8 +python visual_gen_ltx2.py \ + --model_path ${MODEL_ROOT}/LTX-2-checkpoint/fp8/ \ + --text_encoder_path ${MODEL_ROOT}/gemma-3-12b-it \ + --prompt "A cute cat playing piano" \ + --height 720 --width 1280 --num_frames 121 \ + --output_path output_fp8.mp4 + +# FP4 +python visual_gen_ltx2.py \ + --model_path ${MODEL_ROOT}/LTX-2-checkpoint/fp4/ \ + --text_encoder_path ${MODEL_ROOT}/gemma-3-12b-it \ + --prompt "A cute cat playing piano" \ + --height 512 --width 768 --num_frames 121 \ + --output_path output_fp4.mp4 +``` + +--- + +## Common Arguments + +| Argument | FLUX | WAN | LTX2 | Default | Description | +|----------|------|-----|------|---------|-------------| +| `--model_path` | ✓ | ✓ | — | Path to model checkpoint directory | +| `--text_encoder_path` | — | ✓ | — | Path to Gemma3 text encoder | +| `--prompt` | ✓ | ✓ | — | Text prompt for generation | +| `--negative_prompt` | — | ✓ | *(built-in)* | Negative prompt | +| `--height` | ✓ | ✓ | ✓ | 1024 / 720 | Output height | +| `--width` | ✓ | ✓ | ✓ | 1024 / 1280 | Output width | +| `--num_frames` | — | ✓ | ✓ | 81 / 
121 | Number of frames | +| `--frame_rate` | — | ✓ | 24.0 | Output frame rate (fps) | +| `--steps` | ✓ | ✓ | ✓ | 50 / 40 | Denoising steps | +| `--guidance_scale` | ✓ | ✓ | ✓ | 3.5 / 5.0 / 4.0 | Guidance strength | +| `--seed` | ✓ | ✓ | ✓ | 42 | Random seed | +| `--image` | — | ✓ | None | Input image for image-to-video | +| `--image_cond_strength` | — | ✓ | 1.0 | Image conditioning strength | +| `--enable_teacache` | ✓ | ✓ | — | False | Cache optimization | +| `--teacache_thresh` | ✓ | ✓ | — | 0.2 | TeaCache similarity threshold | +| `--teacache_coefficients` | ✓ | ✓ | — | *(omit)* | Optional polynomial coeffs; overrides built-in table | +| `--use_ret_steps` | ✓ | ✓ | — | False | TeaCache retention-steps mode (WAN/FLUX tables) | +| `--attention_backend` | ✓ | ✓ | — | VANILLA | `VANILLA`, `TRTLLM`, or `FA4` | +| `--cfg_size` | — | ✓ | — | 1 | CFG parallelism | +| `--ulysses_size` | ✓ | ✓ | — | 1 | Sequence parallelism | +| `--linear_type` | ✓ | ✓ | — | default | Quantization type | +| `--enhance_prompt` | — | ✓ | False | Gemma3 prompt enhancement | +| `--stg_scale` | — | ✓ | 0.0 | Spatiotemporal guidance scale | +| `--modality_scale` | — | ✓ | 1.0 | Cross-modal guidance scale | +| `--rescale_scale` | — | ✓ | 0.0 | Variance-preserving rescale factor | + +## Troubleshooting + +**Out of Memory:** +- Use quantization: `--linear_type trtllm-fp8-blockwise` (WAN) or `--linear_type trtllm-fp8-per-tensor` (FLUX) +- Reduce resolution or frames +- Enable TeaCache: `--enable_teacache` +- Use Ulysses parallelism with more GPUs + +**Slow Inference:** +- Enable TeaCache: `--enable_teacache` +- Use TRTLLM backend: `--attention_backend TRTLLM` +- Use multi-GPU: `--cfg_size 2` or `--ulysses_size 2` + +**Import Errors:** +- Run from repository root +- Install necessary dependencies, e.g., `pip install -r requirements-dev.txt` + +**Ulysses Errors:** +- `ulysses_size` must divide the model's head count (12 for WAN) +- Total GPUs = `cfg_size × ulysses_size` +- Sequence length must be 
divisible by `ulysses_size` + +## Output Formats + +- **FLUX**: `.png` (image) +- **WAN**: `.mp4` if FFmpeg is installed, otherwise `.avi` (video) +- **LTX2**: `.mp4` (video with audio) if FFmpeg is installed, otherwise `.avi` (video) + +## Serving + +See [`serve/README.md`](serve/README.md) for `trtllm-serve` examples including image generation (FLUX), video generation (WAN T2V/I2V), and API endpoint reference. diff --git a/examples/visual_gen/serve/README.md b/examples/visual_gen/serve/README.md index a979767b2a6a..1aaf04e5bb92 100644 --- a/examples/visual_gen/serve/README.md +++ b/examples/visual_gen/serve/README.md @@ -51,6 +51,8 @@ Before running these examples, ensure you have: ``` For LTX-2, you need to provide a proper text_encoder_path in `./configs/ltx2.yml`. + **TeaCache:** Example YAML files set `enable_teacache` and `teacache_thresh` only. Omit `coefficients` to use each pipeline’s **built-in** coefficient table (checkpoint path matching). Add `coefficients: [ ... ]` under `teacache` only when you need to override those defaults. 
+ ## Examples Current supported & tested models: diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py index 19db372cd310..aeba393c735f 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py @@ -304,7 +304,7 @@ def load_weights(self, weights: dict) -> None: def post_load_weights(self) -> None: """Post-load setup: TeaCache registration.""" super().post_load_weights() - if self.transformer is not None: + if self.transformer is not None and self.model_config.cache_backend == "teacache": # Register TeaCache extractor for FLUX.2 (must be after device placement) # Only set guidance_param_name for variants with guidance_embeds guidance_param = "guidance" if self.transformer.guidance_embeds else None @@ -327,7 +327,6 @@ def post_load_weights(self) -> None: ) ) - # TeaCache or Cache-DiT self._setup_cache_acceleration(self.transformer, FLUX2_TEACACHE_COEFFICIENTS) @property diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py index aa1fb75c360b..c62e8b5141c3 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py @@ -16,7 +16,7 @@ from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast from tensorrt_llm._torch.utils import make_weak_ref -from tensorrt_llm._torch.visual_gen.cache.teacache import CacheContext +from tensorrt_llm._torch.visual_gen.cache.teacache import CacheContext, register_extractor from tensorrt_llm._torch.visual_gen.checkpoints.prefetch import prefetch_files_to_host_cache from tensorrt_llm._torch.visual_gen.cuda_graph_runner import CUDAGraphRunner, CUDAGraphRunnerConfig from tensorrt_llm._torch.visual_gen.output import CudaPhaseTimer, PipelineOutput @@ -1012,13 +1012,18 @@ def post_load_weights(self) -> None: 
"""Finalize after weight loading: TeaCache, Cache-DiT, derived attributes.""" super().post_load_weights() - # TODO: TeaCache disabled: LTX2_TEACACHE_COEFFICIENTS are unverified. - # To re-enable, uncomment the following lines and verify coefficients. - # register_extractor( - # "LTXModel", - # LTX2TeaCacheExtractor(self._compute_ltx2_timestep_embedding), - # ) - # self._setup_teacache(self.transformer, coefficients=LTX2_TEACACHE_COEFFICIENTS) + # LTX-2: single transformer (one DiT for video+audio); TeaCache only with explicit coefficients. + if self.transformer is not None and self.pipeline_config.cache_backend == "teacache": + if self.pipeline_config.teacache.coefficients is None: + raise ValueError( + "TeaCache on LTX-2 requires explicit teacache.coefficients " + "(no built-in coefficient table)." + ) + register_extractor( + "LTXModel", + LTX2TeaCacheExtractor(self._compute_ltx2_timestep_embedding), + ) + self._setup_cache_acceleration(self.transformer, coefficients=None) # Cache-DiT if self.transformer is not None and self.pipeline_config.cache_backend == "cache_dit": diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index 73036ee8ad4a..fbfff8b9d315 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -11,6 +11,7 @@ from tensorrt_llm._torch.visual_gen.cache.teacache import ( ExtractorConfig, + TeaCacheBackend, register_extractor_from_config, ) from tensorrt_llm._torch.visual_gen.models.wan.defaults import ( @@ -101,13 +102,6 @@ def __init__(self, pipeline_config): self.is_wan22_14b = self.boundary_ratio is not None self.is_wan22_5b = self.expand_timesteps - # Validate TeaCache compatibility before allocating GPU memory - if (self.is_wan22_14b or self.is_wan22_5b) and pipeline_config.cache_backend == "teacache": - raise ValueError( - "TeaCache is not supported for Wan 2.2 models. 
" - "Use cache_backend='none' or 'cache_dit' (not 'teacache')." - ) - super().__init__(pipeline_config) def _compute_wan_timestep_embedding(self, module, timestep=None, **kwargs): @@ -322,13 +316,42 @@ def post_load_weights(self) -> None: else: if self.pipeline_config.cache_backend == "cache_dit": self._setup_cache_acceleration(self.transformer, coefficients=None) - # TeaCache is not supported for Wan 2.2 unless using Cache-DiT. self.transformer_cache_backend = self.cache_accelerator if self.transformer_2 is not None: if hasattr(self.transformer_2, "post_load_weights"): self.transformer_2.post_load_weights() + # Wan 2.2 TeaCache after both transformers' post_load_weights (FP8 scales, etc.) + if ( + self.transformer is not None + and self.transformer_2 is not None + and self.is_wan22 + and self.pipeline_config.cache_backend == "teacache" + ): + self._apply_teacache_coefficients(WAN_TEACACHE_COEFFICIENTS) + tc = self.pipeline_config.teacache + if tc.coefficients is None or tc.coefficients_2 is None: + raise ValueError( + "Wan 2.2 TeaCache requires explicit teacache.coefficients and " + "teacache.coefficients_2 (high-noise and low-noise stage polynomials). " + "There is no built-in coefficient table for Wan 2.2." 
+ ) + cfg_high = tc.model_copy(deep=True) + cfg_low = tc.model_copy(deep=True) + cfg_low.coefficients = tc.coefficients_2 + logger.info("TeaCache: Initializing (Wan 2.2 high-noise transformer)...") + self.cache_backend = TeaCacheBackend(cfg_high) + self.cache_backend.enable(self.transformer) + self.transformer_cache_backend = self.cache_backend + logger.info("TeaCache: Initializing (Wan 2.2 low-noise transformer_2)...") + self.transformer_2_cache_backend = TeaCacheBackend(cfg_low) + self.transformer_2_cache_backend.enable(self.transformer_2) + self._teacache_backends = [ + self.cache_backend, + self.transformer_2_cache_backend, + ] + def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None: with torch.no_grad(): self.forward( diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index bb27e8b017d0..eae6a22f0e2a 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -13,6 +13,7 @@ from tensorrt_llm._torch.visual_gen.cache.teacache import ( ExtractorConfig, + TeaCacheBackend, register_extractor_from_config, ) from tensorrt_llm._torch.visual_gen.models.wan.defaults import ( @@ -96,13 +97,6 @@ def __init__(self, pipeline_config): ) self.is_wan22_14b = self.boundary_ratio is not None - # Validate TeaCache compatibility before allocating GPU memory - if self.is_wan22_14b and pipeline_config.cache_backend == "teacache": - raise ValueError( - "TeaCache is not supported for Wan 2.2 models. " - "Use cache_backend='none' or 'cache_dit' (not 'teacache')." 
- ) - super().__init__(pipeline_config) def _compute_wan_timestep_embedding(self, module, timestep=None, **kwargs): @@ -348,6 +342,35 @@ def post_load_weights(self) -> None: if hasattr(self.transformer_2, "post_load_weights"): self.transformer_2.post_load_weights() + if ( + self.transformer is not None + and self.transformer_2 is not None + and self.is_wan22 + and self.pipeline_config.cache_backend == "teacache" + ): + self._apply_teacache_coefficients(WAN_I2V_TEACACHE_COEFFICIENTS) + tc = self.pipeline_config.teacache + if tc.coefficients is None or tc.coefficients_2 is None: + raise ValueError( + "Wan 2.2 TeaCache requires explicit teacache.coefficients and " + "teacache.coefficients_2 (high-noise and low-noise stage polynomials). " + "There is no built-in coefficient table for Wan 2.2." + ) + cfg_high = tc.model_copy(deep=True) + cfg_low = tc.model_copy(deep=True) + cfg_low.coefficients = tc.coefficients_2 + logger.info("TeaCache: Initializing (Wan 2.2 I2V high-noise transformer)...") + self.cache_backend = TeaCacheBackend(cfg_high) + self.cache_backend.enable(self.transformer) + self.transformer_cache_backend = self.cache_backend + logger.info("TeaCache: Initializing (Wan 2.2 I2V low-noise transformer_2)...") + self.transformer_2_cache_backend = TeaCacheBackend(cfg_low) + self.transformer_2_cache_backend.enable(self.transformer_2) + self._teacache_backends = [ + self.cache_backend, + self.transformer_2_cache_backend, + ] + def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None: dummy_image = PIL.Image.new("RGB", (width, height)) with torch.no_grad(): diff --git a/tensorrt_llm/_torch/visual_gen/pipeline.py b/tensorrt_llm/_torch/visual_gen/pipeline.py index ee4e10a18dbb..151087519ca0 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline.py @@ -138,6 +138,8 @@ def __init__(self, pipeline_config: "DiffusionPipelineConfig"): # Unified cache acceleration (TeaCache, Cache-DiT); see 
_setup_cache_acceleration self.cache_accelerator: Optional["CacheAccelerator"] = None + # Wan 2.2 manual TeaCacheBackend pair; see WanPipeline.post_load_weights + self._teacache_backends: List[Any] = [] # Components self.transformer: Optional[nn.Module] = None @@ -524,6 +526,12 @@ def _setup_cache_acceleration( if acc.is_enabled(): self.cache_accelerator = acc + def _refresh_teacache_backends(self, total_steps: int) -> None: + """Reset manual TeaCache backends (e.g. Wan 2.2 dual transformers).""" + for backend in self._teacache_backends: + if backend is not None and backend.is_enabled(): + backend.refresh(total_steps) + def setup_parallel_vae(self): """Enable parallel-VAE decode mode and wrap the VAE on participating ranks. @@ -1039,6 +1047,7 @@ def denoise( # Reset cache acceleration state for new generation (TeaCache / Cache-DiT) if getattr(self, "cache_accelerator", None) and self.cache_accelerator.is_enabled(): self.cache_accelerator.refresh(total_steps) + self._refresh_teacache_backends(total_steps) if self.rank == 0: if has_extra_streams: diff --git a/tensorrt_llm/visual_gen/args.py b/tensorrt_llm/visual_gen/args.py index 82d64408ddf0..a705b9b2e17a 100644 --- a/tensorrt_llm/visual_gen/args.py +++ b/tensorrt_llm/visual_gen/args.py @@ -315,11 +315,22 @@ class TeaCacheConfig(BaseCacheConfig): ), ) + coefficients_2: Optional[List[float]] = Field( + default=None, + status="prototype", + description=( + "Second polynomial (Wan 2.2 dual-transformer low-noise stage only). " + "Required together with coefficients when enabling TeaCache on Wan 2.2." 
+ ), + ) + @model_validator(mode="after") def validate_teacache(self) -> "TeaCacheConfig": """Validate TeaCache configuration.""" if len(self.coefficients) == 0: raise ValueError("TeaCache coefficients list cannot be empty") + if self.coefficients_2 is not None and len(self.coefficients_2) == 0: + raise ValueError("TeaCache coefficients_2 list cannot be empty") return self diff --git a/tests/unittest/_torch/visual_gen/test_teacache.py b/tests/unittest/_torch/visual_gen/test_teacache.py index 9cd0408a86d2..e0fc749108b8 100644 --- a/tests/unittest/_torch/visual_gen/test_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_teacache.py @@ -241,10 +241,11 @@ def test_apply_teacache_coefficients_only(self): def test_nested_table_missing_requested_mode_warns_then_raises_if_no_fallback(self): """Variant matches path but nested dict lacks 'standard' / 'ret_steps' entry for mode.""" pipeline = self._make_pipeline_mock("FLUX.1-dev") + # Only ret_steps; standard mode requested (use_ret_steps=False) coefficients = {"dev": {"ret_steps": [9.0, 8.0]}} - with patch( - "tensorrt_llm._torch.visual_gen.pipeline.logger.warning" - ) as mock_warning: + # tensorrt_llm.logger does not propagate to the root logger, so pytest caplog + # does not see these records; assert via the pipeline module's logger.warning. 
+ with patch("tensorrt_llm._torch.visual_gen.pipeline.logger.warning") as mock_warning: with pytest.raises(ValueError, match="No coefficients found"): BasePipeline._apply_teacache_coefficients(pipeline, coefficients) mock_warning.assert_called_once() @@ -259,6 +260,41 @@ def test_empty_coefficients_rejected(self): with pytest.raises(ValueError, match="cannot be empty"): TeaCacheConfig(coefficients=[]) + def test_empty_coefficients_2_rejected(self): + with pytest.raises(ValueError, match="coefficients_2"): + TeaCacheConfig(coefficients_2=[]) + + +class TestRefreshTeacacheBackends: + """BasePipeline._refresh_teacache_backends walks _teacache_backends only.""" + + def test_single_backend_refresh(self): + shared = MagicMock() + shared.is_enabled.return_value = True + pipe = SimpleNamespace(_teacache_backends=[shared]) + BasePipeline._refresh_teacache_backends(pipe, 50) + shared.refresh.assert_called_once_with(50) + + def test_refreshes_two_distinct_backends(self): + b1 = MagicMock() + b1.is_enabled.return_value = True + b2 = MagicMock() + b2.is_enabled.return_value = True + pipe = SimpleNamespace(_teacache_backends=[b1, b2]) + BasePipeline._refresh_teacache_backends(pipe, 30) + b1.refresh.assert_called_once_with(30) + b2.refresh.assert_called_once_with(30) + + def test_skips_disabled_backends(self): + active = MagicMock() + active.is_enabled.return_value = True + disabled = MagicMock() + disabled.is_enabled.return_value = False + pipe = SimpleNamespace(_teacache_backends=[active, disabled]) + BasePipeline._refresh_teacache_backends(pipe, 10) + active.refresh.assert_called_once_with(10) + disabled.refresh.assert_not_called() + class TestFlux2TeacacheTable: """FLUX.2 built-in coefficient table (dev variant).""" @@ -283,3 +319,136 @@ def test_flux2_dev_variant_resolves_from_checkpoint_path(self): ) assert pipeline.model_config.teacache.coefficients is not None assert len(pipeline.model_config.teacache.coefficients) >= 2 + + +class TestExplicitTeaCacheCoefficientsRequired: 
+ """LTX-2 and Wan 2.2 have no built-in TeaCache tables; teacache requires user coefficients.""" + + def test_ltx2_raises_when_teacache_enabled_without_coefficients(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + pipe = object.__new__(LTX2Pipeline) + pipe.transformer = MagicMock() + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace(), + cache=TeaCacheConfig(coefficients=None), + skip_create_weights_in_init=True, + ) + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with pytest.raises(ValueError, match="LTX-2 requires explicit teacache.coefficients"): + LTX2Pipeline.post_load_weights(pipe) + + def test_ltx2_succeeds_with_explicit_coefficients(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + pipe = object.__new__(LTX2Pipeline) + pipe.transformer = MagicMock() + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace(), + cache=TeaCacheConfig( + coefficients=[1.0, 2.0, 3.0], + ), + skip_create_weights_in_init=True, + ) + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2.register_extractor" + ): + with patch.object(TeaCacheBackend, "enable"): + LTX2Pipeline.post_load_weights(pipe) + assert pipe.cache_accelerator is not None + + def test_wan22_raises_when_teacache_enabled_without_both_coefficient_lists(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + pipe = object.__new__(WanPipeline) + pipe.transformer = MagicMock() + pipe.transformer_2 = MagicMock() + pipe.is_wan22 = True + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace( + _name_or_path="/models/wan/snapshot", boundary_ratio=0.2 + ), + cache=TeaCacheConfig( + coefficients=None, + coefficients_2=None, + ), + skip_create_weights_in_init=True, + ) + with patch.object(BasePipeline, 
"post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.register_extractor_from_config" + ): + with patch.object( + WanPipeline, + "_apply_teacache_coefficients", + lambda self, coefficients: None, + ): + with pytest.raises(ValueError, match="Wan 2.2 TeaCache requires explicit"): + WanPipeline.post_load_weights(pipe) + + def test_wan22_i2v_raises_when_teacache_enabled_without_both_coefficient_lists(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import ( + WanImageToVideoPipeline, + ) + + pipe = object.__new__(WanImageToVideoPipeline) + pipe.transformer = MagicMock() + pipe.transformer_2 = MagicMock() + pipe.is_wan22 = True + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace( + _name_or_path="/models/wan/snapshot", boundary_ratio=0.2 + ), + cache=TeaCacheConfig( + coefficients=None, + coefficients_2=None, + ), + skip_create_weights_in_init=True, + ) + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v.register_extractor_from_config" + ): + with patch.object( + WanImageToVideoPipeline, + "_apply_teacache_coefficients", + lambda self, coefficients: None, + ): + with pytest.raises(ValueError, match="Wan 2.2 TeaCache requires explicit"): + WanImageToVideoPipeline.post_load_weights(pipe) + + def test_wan22_t2v_installs_two_teacache_backends_when_coefficients_provided(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + pipe = object.__new__(WanPipeline) + pipe.transformer = MagicMock() + pipe.transformer_2 = MagicMock() + pipe.is_wan22 = True + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace(_name_or_path="/wan/snapshot", boundary_ratio=0.2), + cache=TeaCacheConfig( + coefficients=[1.0, 2.0], + coefficients_2=[3.0, 4.0], + ), + skip_create_weights_in_init=True, + ) + mock_enable = MagicMock() + 
backend_a = MagicMock() + backend_a.enable = mock_enable + backend_b = MagicMock() + backend_b.enable = mock_enable + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.register_extractor_from_config" + ): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.TeaCacheBackend" + ) as TB: + TB.side_effect = [backend_a, backend_b] + WanPipeline.post_load_weights(pipe) + assert TB.call_count == 2 + assert mock_enable.call_count == 2 + assert pipe.transformer_cache_backend is pipe.cache_backend + assert pipe.transformer_2_cache_backend is not None + assert pipe.transformer_2_cache_backend is not pipe.cache_backend From 55bbcecd53ca7e0d4670785139199ff92138e550 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:56:46 -0700 Subject: [PATCH 03/11] fix cache accelerator bugs causing failing tests Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py | 2 +- tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py | 2 +- tensorrt_llm/_torch/visual_gen/pipeline.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index fbfff8b9d315..f05dc56c00e8 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -316,7 +316,7 @@ def post_load_weights(self) -> None: else: if self.pipeline_config.cache_backend == "cache_dit": self._setup_cache_acceleration(self.transformer, coefficients=None) - self.transformer_cache_backend = self.cache_accelerator + self.transformer_cache_backend = self.cache_accelerator if self.transformer_2 is not None: if hasattr(self.transformer_2, "post_load_weights"): diff --git 
a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index eae6a22f0e2a..7b66ba2d4730 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -336,7 +336,7 @@ def post_load_weights(self) -> None: else: if self.pipeline_config.cache_backend == "cache_dit": self._setup_cache_acceleration(self.transformer, coefficients=None) - self.transformer_cache_backend = self.cache_accelerator + self.transformer_cache_backend = self.cache_accelerator if self.transformer_2 is not None: if hasattr(self.transformer_2, "post_load_weights"): diff --git a/tensorrt_llm/_torch/visual_gen/pipeline.py b/tensorrt_llm/_torch/visual_gen/pipeline.py index 151087519ca0..527f9fd2eaed 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline.py @@ -523,8 +523,7 @@ def _setup_cache_acceleration( acc = TeaCacheAccelerator(cfg.teacache) acc.wrap(model=model) - if acc.is_enabled(): - self.cache_accelerator = acc + self.cache_accelerator = acc def _refresh_teacache_backends(self, total_steps: int) -> None: """Reset manual TeaCache backends (e.g. 
Wan 2.2 dual transformers).""" From 609d1ecf9a4cdba79589edabeec283b21ef57232 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Wed, 22 Apr 2026 17:12:00 -0700 Subject: [PATCH 04/11] small fixes Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py | 5 ++--- .../_torch/visual_gen/models/wan/pipeline_wan_i2v.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index f05dc56c00e8..852336ca97a0 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -337,9 +337,8 @@ def post_load_weights(self) -> None: "teacache.coefficients_2 (high-noise and low-noise stage polynomials). " "There is no built-in coefficient table for Wan 2.2." ) - cfg_high = tc.model_copy(deep=True) - cfg_low = tc.model_copy(deep=True) - cfg_low.coefficients = tc.coefficients_2 + cfg_high = tc.model_copy() + cfg_low = tc.model_copy(update={"coefficients": tc.coefficients_2}) logger.info("TeaCache: Initializing (Wan 2.2 high-noise transformer)...") self.cache_backend = TeaCacheBackend(cfg_high) self.cache_backend.enable(self.transformer) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index 7b66ba2d4730..d84299acb4b7 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -356,9 +356,8 @@ def post_load_weights(self) -> None: "teacache.coefficients_2 (high-noise and low-noise stage polynomials). " "There is no built-in coefficient table for Wan 2.2." 
) - cfg_high = tc.model_copy(deep=True) - cfg_low = tc.model_copy(deep=True) - cfg_low.coefficients = tc.coefficients_2 + cfg_high = tc.model_copy() + cfg_low = tc.model_copy(update={"coefficients": tc.coefficients_2}) logger.info("TeaCache: Initializing (Wan 2.2 I2V high-noise transformer)...") self.cache_backend = TeaCacheBackend(cfg_high) self.cache_backend.enable(self.transformer) From 7b1542f116a390736070c592dec9331e6ab86a22 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Fri, 24 Apr 2026 17:00:18 -0700 Subject: [PATCH 05/11] small fixes/code cleanup Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../visual_gen/models/flux/pipeline_flux2.py | 45 +++++----- .../_torch/visual_gen/test_teacache.py | 82 +++++++++++++++---- 2 files changed, 91 insertions(+), 36 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py index aeba393c735f..cc7027f183b2 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py @@ -302,30 +302,31 @@ def load_weights(self, weights: dict) -> None: self.transformer.eval() def post_load_weights(self) -> None: - """Post-load setup: TeaCache registration.""" + """Post-load setup: cache acceleration (TeaCache or Cache-DiT).""" super().post_load_weights() - if self.transformer is not None and self.model_config.cache_backend == "teacache": - # Register TeaCache extractor for FLUX.2 (must be after device placement) - # Only set guidance_param_name for variants with guidance_embeds - guidance_param = "guidance" if self.transformer.guidance_embeds else None - forward_params = [ - "hidden_states", - "encoder_hidden_states", - "timestep", - "img_ids", - "txt_ids", - "guidance", - "return_dict", - ] - register_extractor_from_config( - ExtractorConfig( - model_class_name="Flux2Transformer2DModel", - 
timestep_embed_fn=self._compute_flux2_timestep_embedding, - guidance_param_name=guidance_param, - forward_params=forward_params, - return_dict_default=False, + if self.transformer is not None: + if self.model_config.cache_backend == "teacache": + # Register TeaCache extractor for FLUX.2 (must be after device placement) + # Only set guidance_param_name for variants with guidance_embeds + guidance_param = "guidance" if self.transformer.guidance_embeds else None + forward_params = [ + "hidden_states", + "encoder_hidden_states", + "timestep", + "img_ids", + "txt_ids", + "guidance", + "return_dict", + ] + register_extractor_from_config( + ExtractorConfig( + model_class_name="Flux2Transformer2DModel", + timestep_embed_fn=self._compute_flux2_timestep_embedding, + guidance_param_name=guidance_param, + forward_params=forward_params, + return_dict_default=False, + ) ) - ) self._setup_cache_acceleration(self.transformer, FLUX2_TEACACHE_COEFFICIENTS) diff --git a/tests/unittest/_torch/visual_gen/test_teacache.py b/tests/unittest/_torch/visual_gen/test_teacache.py index e0fc749108b8..105eb3fdb825 100644 --- a/tests/unittest/_torch/visual_gen/test_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_teacache.py @@ -379,13 +379,8 @@ def test_wan22_raises_when_teacache_enabled_without_both_coefficient_lists(self) with patch( "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.register_extractor_from_config" ): - with patch.object( - WanPipeline, - "_apply_teacache_coefficients", - lambda self, coefficients: None, - ): - with pytest.raises(ValueError, match="Wan 2.2 TeaCache requires explicit"): - WanPipeline.post_load_weights(pipe) + with pytest.raises(ValueError, match="Wan 2.2 TeaCache requires explicit"): + WanPipeline.post_load_weights(pipe) def test_wan22_i2v_raises_when_teacache_enabled_without_both_coefficient_lists(self): from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import ( @@ -410,13 +405,8 @@ def 
test_wan22_i2v_raises_when_teacache_enabled_without_both_coefficient_lists(s with patch( "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v.register_extractor_from_config" ): - with patch.object( - WanImageToVideoPipeline, - "_apply_teacache_coefficients", - lambda self, coefficients: None, - ): - with pytest.raises(ValueError, match="Wan 2.2 TeaCache requires explicit"): - WanImageToVideoPipeline.post_load_weights(pipe) + with pytest.raises(ValueError, match="Wan 2.2 TeaCache requires explicit"): + WanImageToVideoPipeline.post_load_weights(pipe) def test_wan22_t2v_installs_two_teacache_backends_when_coefficients_provided(self): from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline @@ -452,3 +442,67 @@ def test_wan22_t2v_installs_two_teacache_backends_when_coefficients_provided(sel assert pipe.transformer_cache_backend is pipe.cache_backend assert pipe.transformer_2_cache_backend is not None assert pipe.transformer_2_cache_backend is not pipe.cache_backend + + def test_wan22_t2v_transformer_gets_coefficients_and_transformer_2_gets_coefficients_2(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + pipe = object.__new__(WanPipeline) + pipe.transformer = MagicMock() + pipe.transformer_2 = MagicMock() + pipe.is_wan22 = True + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace(_name_or_path="/wan/snapshot", boundary_ratio=0.2), + cache=TeaCacheConfig( + coefficients=[1.0, 2.0], + coefficients_2=[3.0, 4.0], + ), + skip_create_weights_in_init=True, + ) + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.register_extractor_from_config" + ): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.TeaCacheBackend" + ) as TB: + TB.return_value = MagicMock() + WanPipeline.post_load_weights(pipe) + + assert TB.call_count == 2 + cfg_high = TB.call_args_list[0][0][0] + 
cfg_low = TB.call_args_list[1][0][0] + assert cfg_high.coefficients == [1.0, 2.0] + assert cfg_low.coefficients == [3.0, 4.0] + + def test_wan22_i2v_transformer_gets_coefficients_and_transformer_2_gets_coefficients_2(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import ( + WanImageToVideoPipeline, + ) + + pipe = object.__new__(WanImageToVideoPipeline) + pipe.transformer = MagicMock() + pipe.transformer_2 = MagicMock() + pipe.is_wan22 = True + pipe.model_config = DiffusionModelConfig( + pretrained_config=SimpleNamespace(_name_or_path="/wan/snapshot", boundary_ratio=0.2), + cache=TeaCacheConfig( + coefficients=[5.0, 6.0], + coefficients_2=[7.0, 8.0], + ), + skip_create_weights_in_init=True, + ) + with patch.object(BasePipeline, "post_load_weights", lambda self: None): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v.register_extractor_from_config" + ): + with patch( + "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v.TeaCacheBackend" + ) as TB: + TB.return_value = MagicMock() + WanImageToVideoPipeline.post_load_weights(pipe) + + assert TB.call_count == 2 + cfg_high = TB.call_args_list[0][0][0] + cfg_low = TB.call_args_list[1][0][0] + assert cfg_high.coefficients == [5.0, 6.0] + assert cfg_low.coefficients == [7.0, 8.0] From 9b13f39fb5fab36dc1d1b349a4055773c55fed31 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:23:59 -0700 Subject: [PATCH 06/11] add pipeline-level test, small rebase fixes Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../visual_gen/models/flux/pipeline_flux2.py | 2 +- tensorrt_llm/visual_gen/args.py | 10 +- .../test_lists/test-db/l0_b200.yml | 1 + .../_torch/visual_gen/test_teacache.py | 16 +- ...st_wan21_t2v_teacache_user_coefficients.py | 166 ++++++++++++++++++ 5 files changed, 187 insertions(+), 8 deletions(-) create mode 100644 
tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py index cc7027f183b2..a71ee912cef1 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py @@ -305,7 +305,7 @@ def post_load_weights(self) -> None: """Post-load setup: cache acceleration (TeaCache or Cache-DiT).""" super().post_load_weights() if self.transformer is not None: - if self.model_config.cache_backend == "teacache": + if self.pipeline_config.cache_backend == "teacache": # Register TeaCache extractor for FLUX.2 (must be after device placement) # Only set guidance_param_name for variants with guidance_embeds guidance_param = "guidance" if self.transformer.guidance_embeds else None diff --git a/tensorrt_llm/visual_gen/args.py b/tensorrt_llm/visual_gen/args.py index a705b9b2e17a..6520866e8202 100644 --- a/tensorrt_llm/visual_gen/args.py +++ b/tensorrt_llm/visual_gen/args.py @@ -305,13 +305,13 @@ class TeaCacheConfig(BaseCacheConfig): teacache_thresh: float = Field(0.2, gt=0.0, status="prototype") use_ret_steps: bool = Field(False, status="prototype") - coefficients: List[float] = Field( - default_factory=lambda: [1.0, 0.0], + coefficients: Optional[List[float]] = Field( + default=None, status="prototype", description=( "Polynomial coefficients used by the TeaCache decision function. " - "Variable-length (FLUX uses 5, Wan uses 4); the pipeline overrides " - "this per-checkpoint at load time." + "None (default) uses the pipeline's built-in per-checkpoint table; " + "an explicit list overrides the table entirely." 
), ) @@ -327,7 +327,7 @@ class TeaCacheConfig(BaseCacheConfig): @model_validator(mode="after") def validate_teacache(self) -> "TeaCacheConfig": """Validate TeaCache configuration.""" - if len(self.coefficients) == 0: + if self.coefficients is not None and len(self.coefficients) == 0: raise ValueError("TeaCache coefficients list cannot be empty") if self.coefficients_2 is not None and len(self.coefficients_2) == 0: raise ValueError("TeaCache coefficients_2 list cannot be empty") diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index e9571ad9724a..cbf124e5f6cf 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -214,6 +214,7 @@ l0_b200: - unittest/_torch/visual_gen/test_wan22_ti2v_5b_pipeline.py - unittest/_torch/visual_gen/test_wan21_i2v_teacache.py - unittest/_torch/visual_gen/test_wan21_t2v_teacache.py + - unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py - unittest/_torch/visual_gen/test_wan_transformer.py - unittest/_torch/visual_gen/test_cosmos3_transformer.py - unittest/_torch/visual_gen/test_cosmos3_pipeline.py diff --git a/tests/unittest/_torch/visual_gen/test_teacache.py b/tests/unittest/_torch/visual_gen/test_teacache.py index 105eb3fdb825..ce3c720f419e 100644 --- a/tests/unittest/_torch/visual_gen/test_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_teacache.py @@ -334,6 +334,7 @@ def test_ltx2_raises_when_teacache_enabled_without_coefficients(self): cache=TeaCacheConfig(coefficients=None), skip_create_weights_in_init=True, ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) with patch.object(BasePipeline, "post_load_weights", lambda self: None): with pytest.raises(ValueError, match="LTX-2 requires explicit teacache.coefficients"): LTX2Pipeline.post_load_weights(pipe) @@ -350,6 +351,7 @@ def test_ltx2_succeeds_with_explicit_coefficients(self): ), skip_create_weights_in_init=True, ) 
+ pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) with patch.object(BasePipeline, "post_load_weights", lambda self: None): with patch( "tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2.register_extractor" @@ -365,9 +367,10 @@ def test_wan22_raises_when_teacache_enabled_without_both_coefficient_lists(self) pipe.transformer = MagicMock() pipe.transformer_2 = MagicMock() pipe.is_wan22 = True + pipe.is_wan22_14b = True pipe.model_config = DiffusionModelConfig( pretrained_config=SimpleNamespace( - _name_or_path="/models/wan/snapshot", boundary_ratio=0.2 + _name_or_path="/models/wan14b/snapshot", boundary_ratio=0.2 ), cache=TeaCacheConfig( coefficients=None, @@ -375,6 +378,7 @@ def test_wan22_raises_when_teacache_enabled_without_both_coefficient_lists(self) ), skip_create_weights_in_init=True, ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) with patch.object(BasePipeline, "post_load_weights", lambda self: None): with patch( "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.register_extractor_from_config" @@ -391,9 +395,10 @@ def test_wan22_i2v_raises_when_teacache_enabled_without_both_coefficient_lists(s pipe.transformer = MagicMock() pipe.transformer_2 = MagicMock() pipe.is_wan22 = True + pipe.is_wan22_14b = True pipe.model_config = DiffusionModelConfig( pretrained_config=SimpleNamespace( - _name_or_path="/models/wan/snapshot", boundary_ratio=0.2 + _name_or_path="/models/wan720p/snapshot", boundary_ratio=0.2 ), cache=TeaCacheConfig( coefficients=None, @@ -401,6 +406,7 @@ def test_wan22_i2v_raises_when_teacache_enabled_without_both_coefficient_lists(s ), skip_create_weights_in_init=True, ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) with patch.object(BasePipeline, "post_load_weights", lambda self: None): with patch( "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v.register_extractor_from_config" @@ -415,6 +421,7 @@ def test_wan22_t2v_installs_two_teacache_backends_when_coefficients_provided(sel 
pipe.transformer = MagicMock() pipe.transformer_2 = MagicMock() pipe.is_wan22 = True + pipe.is_wan22_14b = True pipe.model_config = DiffusionModelConfig( pretrained_config=SimpleNamespace(_name_or_path="/wan/snapshot", boundary_ratio=0.2), cache=TeaCacheConfig( @@ -423,6 +430,7 @@ def test_wan22_t2v_installs_two_teacache_backends_when_coefficients_provided(sel ), skip_create_weights_in_init=True, ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) mock_enable = MagicMock() backend_a = MagicMock() backend_a.enable = mock_enable @@ -450,6 +458,7 @@ def test_wan22_t2v_transformer_gets_coefficients_and_transformer_2_gets_coeffici pipe.transformer = MagicMock() pipe.transformer_2 = MagicMock() pipe.is_wan22 = True + pipe.is_wan22_14b = True pipe.model_config = DiffusionModelConfig( pretrained_config=SimpleNamespace(_name_or_path="/wan/snapshot", boundary_ratio=0.2), cache=TeaCacheConfig( @@ -458,6 +467,7 @@ def test_wan22_t2v_transformer_gets_coefficients_and_transformer_2_gets_coeffici ), skip_create_weights_in_init=True, ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) with patch.object(BasePipeline, "post_load_weights", lambda self: None): with patch( "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.register_extractor_from_config" @@ -483,6 +493,7 @@ def test_wan22_i2v_transformer_gets_coefficients_and_transformer_2_gets_coeffici pipe.transformer = MagicMock() pipe.transformer_2 = MagicMock() pipe.is_wan22 = True + pipe.is_wan22_14b = True pipe.model_config = DiffusionModelConfig( pretrained_config=SimpleNamespace(_name_or_path="/wan/snapshot", boundary_ratio=0.2), cache=TeaCacheConfig( @@ -491,6 +502,7 @@ def test_wan22_i2v_transformer_gets_coefficients_and_transformer_2_gets_coeffici ), skip_create_weights_in_init=True, ) + pipe.pipeline_config = _PipelineConfigShim(pipe.model_config) with patch.object(BasePipeline, "post_load_weights", lambda self: None): with patch( 
"tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v.register_extractor_from_config" diff --git a/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py new file mode 100644 index 000000000000..2dd1904dfac4 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Confirm that user-configured TeaCache coefficients affect the cache rate on Wan 2.1 1.3B. + +Loads model weights and runs two forward passes back-to-back, each with a +different user-supplied TeaCacheConfig.coefficients list, then prints and compares +the cached step counts to confirm the user override is respected. 
+ +Run: + pytest tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py -v -s + +Override checkpoint: + DIFFUSION_MODEL_PATH_WAN21_1_3B=/path/to/weights \\ + pytest tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py -v -s +""" + +import gc +import os + +os.environ["TLLM_DISABLE_MPI"] = "1" + +from pathlib import Path + +import pytest +import torch + +from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader +from tensorrt_llm.visual_gen.args import TeaCacheConfig, VisualGenArgs + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +@pytest.fixture(autouse=True) +def _cleanup_gpu(): + gc.collect() + torch.cuda.empty_cache() + yield + gc.collect() + torch.cuda.empty_cache() + + +def _llm_models_root() -> Path: + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + else: + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return root + + +WAN21_1_3B_PATH = os.environ.get( + "DIFFUSION_MODEL_PATH_WAN21_1_3B", + str(_llm_models_root() / "Wan2.1-T2V-1.3B-Diffusers"), +) + +PROMPT = "a cat sitting on a windowsill" +HEIGHT, WIDTH = 480, 832 +NUM_FRAMES = 33 # (33-1)/4+1 = 9 latent frames; smallest realistic shape +NUM_STEPS = 50 # enough steps for coefficients to produce meaningful cache hits +SEED = 42 + +# Two user-supplied coefficient lists passed explicitly via TeaCacheConfig.coefficients, +# bypassing the built-in variant-lookup table entirely, testing the user-override code path.
+ +COEFFICIENTS_CALIBRATED = [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01, +] + +COEFFICIENTS_IDENTITY_LINEAR = [ + 1.0, + 0.0, +] + + +def _run_forward(coefficients: list, thresh: float, label: str) -> dict: + """Load the pipeline with the given user-supplied coefficients, run one forward pass.""" + if not os.path.exists(WAN21_1_3B_PATH): + pytest.skip(f"Checkpoint not found: {WAN21_1_3B_PATH}") + + args = VisualGenArgs( + model=WAN21_1_3B_PATH, + cache_config=TeaCacheConfig( + teacache_thresh=thresh, + coefficients=coefficients, + ), + ) + pipeline = PipelineLoader(args).load(skip_warmup=True) + try: + with torch.no_grad(): + pipeline.forward( + prompt=PROMPT, + negative_prompt="", + height=HEIGHT, + width=WIDTH, + num_frames=NUM_FRAMES, + num_inference_steps=NUM_STEPS, + seed=SEED, + ) + stats = pipeline.transformer_cache_backend.get_stats() + finally: + del pipeline + gc.collect() + torch.cuda.empty_cache() + + print( + f" {label:<30s} cached={stats['cached_steps']:>3}/{stats['total_steps']} " + f"({stats['hit_rate']:.1%} cache rate)" + ) + return stats + + +@pytest.mark.integration +@pytest.mark.wan_t2v +@pytest.mark.teacache +class TestWan13BUserCoefficientsAffectCacheRate: + """User-supplied TeaCacheConfig.coefficients override the built-in table and change cache rate.""" + + def test_different_user_coefficients_produce_different_cache_rates(self): + print(f"\n {'coefficient set':<30s} {'cached / total':>15} hit rate") + print(f" {'-' * 65}") + + stats_a = _run_forward( + COEFFICIENTS_CALIBRATED, + thresh=0.2, + label="calibrated 1.3B standard", + ) + stats_b = _run_forward( + COEFFICIENTS_IDENTITY_LINEAR, + thresh=0.2, + label="identity linear [1, 0]", + ) + + print(f" {'-' * 65}") + print( + f" difference in cached steps: " + f"{abs(stats_a['cached_steps'] - stats_b['cached_steps'])}" + ) + + assert stats_a["cached_steps"] != stats_b["cached_steps"], ( + f"Both coefficient sets produced {stats_a['cached_steps']} 
cached steps — " + f"expected user-configured coefficients to produce different cache rates." + ) From 3d39a4092d96e05a0baf75af08c19c3baa42c01d Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:31:44 -0700 Subject: [PATCH 07/11] update README Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- examples/visual_gen/README.md | 247 ---------------------------- examples/visual_gen/serve/README.md | 1 - 2 files changed, 248 deletions(-) diff --git a/examples/visual_gen/README.md b/examples/visual_gen/README.md index 709c0839eb02..c38656472b35 100644 --- a/examples/visual_gen/README.md +++ b/examples/visual_gen/README.md @@ -33,250 +33,3 @@ python models/flux2.py --visual_gen_args configs/flux2-dev-fp4-1gpu.yaml Install deps from the repo root: `pip install -r requirements-dev.txt`. Output: `.png` for image models; `.mp4` for video models when FFmpeg is installed (otherwise `.avi`). - -## FLUX (Text-to-Image) - -### Basic Usage - -**FLUX.1:** - -```bash -python visual_gen_flux.py \ - --model_path black-forest-labs/FLUX.1-dev \ - --prompt "A cat sitting on a windowsill" \ - --height 1024 --width 1024 \ - --guidance_scale 3.5 \ - --output_path output.png -``` - -**With FP8 quantization:** - -```bash -python visual_gen_flux.py \ - --model_path black-forest-labs/FLUX.2-dev \ - --prompt "A cat sitting on a windowsill" \ - --linear_type trtllm-fp8-per-tensor \ - --output_path output_fp8.png -``` - -**Batch mode (multiple prompts from file):** - -```bash -python visual_gen_flux.py \ - --model_path black-forest-labs/FLUX.1-dev \ - --prompts_file prompts.txt \ - --output_dir results/ --seed 42 -``` - - -## WAN (Text-to-Video) - -### Basic Usage - -**Single GPU:** - -```bash -python visual_gen_wan_t2v.py \ - --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --prompt "A cute cat playing piano" \ - --height 480 --width 832 --num_frames 33 \ - --output_path output.mp4 -``` - -**With 
TeaCache:** -```bash -python visual_gen_wan_t2v.py \ - --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --prompt "A cute cat playing piano" \ - --height 480 --width 832 --num_frames 33 \ - --enable_teacache \ - --output_path output.mp4 -``` - -### Multi-GPU Parallelism - -WAN supports two parallelism modes that can be combined: -- **CFG Parallelism**: Split positive/negative prompts across GPUs -- **Ulysses Parallelism**: Split sequence across GPUs for longer sequences - - -**Ulysses Only (2 GPUs):** -```bash -python visual_gen_wan_t2v.py \ - --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --prompt "A cute cat playing piano" \ - --height 480 --width 832 --num_frames 33 \ - --attention_backend TRTLLM \ - --cfg_size 1 --ulysses_size 2 \ - --output_path output.mp4 -``` -GPU Layout: GPU 0-1 share sequence (6 heads each) - -**CFG Only (2 GPUs):** -```bash -python visual_gen_wan_t2v.py \ - --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --prompt "A cute cat playing piano" \ - --height 480 --width 832 --num_frames 33 \ - --attention_backend TRTLLM \ - --cfg_size 2 --ulysses_size 1 \ - --output_path output.mp4 -``` -GPU Layout: GPU 0 (positive) | GPU 1 (negative) - -**CFG + Ulysses (4 GPUs):** -```bash -python visual_gen_wan_t2v.py \ - --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --prompt "A cute cat playing piano" \ - --height 480 --width 832 --num_frames 33 \ - --attention_backend TRTLLM \ - --cfg_size 2 --ulysses_size 2 \ - --output_path output.mp4 -``` -GPU Layout: GPU 0-1 (positive, Ulysses) | GPU 2-3 (negative, Ulysses) - -**Large-Scale (8 GPUs):** -```bash -python visual_gen_wan_t2v.py \ - --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --prompt "A cute cat playing piano" \ - --height 480 --width 832 --num_frames 33 \ - --attention_backend TRTLLM \ - --cfg_size 2 --ulysses_size 4 \ - --output_path output.mp4 -``` - - -## WAN (Image-to-Video) - -```bash -python visual_gen_wan_i2v.py \ - --model_path Wan-AI/Wan2.1-I2V-14B-480P-Diffusers \ - --image_path 
input_image.jpg \ - --prompt "She turns around and smiles" \ - --height 480 --width 832 --num_frames 81 \ - --output_path output_i2v.mp4 -``` - - -## LTX2 (Text/Image-to-Video with Audio) - -LTX2 generates video **with audio** from text prompts or input images. -It uses a Gemma3 text encoder (provided separately via `--text_encoder_path`) -and supports BF16, FP8, and FP4 precision checkpoints. - -Please refer to tensorrt_llm/_torch/visual_gen/models/ltx2/LTX_2_CHECKPOINT_FORMAT.md for model checkpoint info. - -### Basic Usage - -**Text-to-Video (single GPU):** -```bash -python visual_gen_ltx2.py \ - --model_path ${MODEL_ROOT}/LTX-2-checkpoint/ \ - --text_encoder_path ${MODEL_ROOT}/gemma-3-12b-it \ - --prompt "A cute cat playing piano" \ - --height 720 --width 1280 --num_frames 121 \ - --steps 40 --guidance_scale 4.0 --seed 42 \ - --output_path output_t2v.mp4 -``` - -**Image-to-Video:** -```bash -python visual_gen_ltx2.py \ - --model_path ${MODEL_ROOT}/LTX-2-checkpoint/ \ - --text_encoder_path ${MODEL_ROOT}/gemma-3-12b-it \ - --prompt "A cute cat playing piano" \ - --image ${PROJECT_ROOT}/examples/visual_gen/cat_piano.png \ - --image_cond_strength 1.0 \ - --height 720 --width 1280 --num_frames 121 \ - --steps 40 --seed 42 \ - --output_path output_i2v.mp4 -``` - -### Precision Variants - -LTX2 ships checkpoints at three precision levels. 
Simply point `--model_path` at the -appropriate directory: - -```bash -# FP8 -python visual_gen_ltx2.py \ - --model_path ${MODEL_ROOT}/LTX-2-checkpoint/fp8/ \ - --text_encoder_path ${MODEL_ROOT}/gemma-3-12b-it \ - --prompt "A cute cat playing piano" \ - --height 720 --width 1280 --num_frames 121 \ - --output_path output_fp8.mp4 - -# FP4 -python visual_gen_ltx2.py \ - --model_path ${MODEL_ROOT}/LTX-2-checkpoint/fp4/ \ - --text_encoder_path ${MODEL_ROOT}/gemma-3-12b-it \ - --prompt "A cute cat playing piano" \ - --height 512 --width 768 --num_frames 121 \ - --output_path output_fp4.mp4 -``` - ---- - -## Common Arguments - -| Argument | FLUX | WAN | LTX2 | Default | Description | -|----------|------|-----|------|---------|-------------| -| `--model_path` | ✓ | ✓ | — | Path to model checkpoint directory | -| `--text_encoder_path` | — | ✓ | — | Path to Gemma3 text encoder | -| `--prompt` | ✓ | ✓ | — | Text prompt for generation | -| `--negative_prompt` | — | ✓ | *(built-in)* | Negative prompt | -| `--height` | ✓ | ✓ | ✓ | 1024 / 720 | Output height | -| `--width` | ✓ | ✓ | ✓ | 1024 / 1280 | Output width | -| `--num_frames` | — | ✓ | ✓ | 81 / 121 | Number of frames | -| `--frame_rate` | — | ✓ | 24.0 | Output frame rate (fps) | -| `--steps` | ✓ | ✓ | ✓ | 50 / 40 | Denoising steps | -| `--guidance_scale` | ✓ | ✓ | ✓ | 3.5 / 5.0 / 4.0 | Guidance strength | -| `--seed` | ✓ | ✓ | ✓ | 42 | Random seed | -| `--image` | — | ✓ | None | Input image for image-to-video | -| `--image_cond_strength` | — | ✓ | 1.0 | Image conditioning strength | -| `--enable_teacache` | ✓ | ✓ | — | False | Cache optimization | -| `--teacache_thresh` | ✓ | ✓ | — | 0.2 | TeaCache similarity threshold | -| `--teacache_coefficients` | ✓ | ✓ | — | *(omit)* | Optional polynomial coeffs; overrides built-in table | -| `--use_ret_steps` | ✓ | ✓ | — | False | TeaCache retention-steps mode (WAN/FLUX tables) | -| `--attention_backend` | ✓ | ✓ | — | VANILLA | `VANILLA`, `TRTLLM`, or `FA4` | -| `--cfg_size` | — | ✓ 
| — | 1 | CFG parallelism | -| `--ulysses_size` | ✓ | ✓ | — | 1 | Sequence parallelism | -| `--linear_type` | ✓ | ✓ | — | default | Quantization type | -| `--enhance_prompt` | — | ✓ | False | Gemma3 prompt enhancement | -| `--stg_scale` | — | ✓ | 0.0 | Spatiotemporal guidance scale | -| `--modality_scale` | — | ✓ | 1.0 | Cross-modal guidance scale | -| `--rescale_scale` | — | ✓ | 0.0 | Variance-preserving rescale factor | - -## Troubleshooting - -**Out of Memory:** -- Use quantization: `--linear_type trtllm-fp8-blockwise` (WAN) or `--linear_type trtllm-fp8-per-tensor` (FLUX) -- Reduce resolution or frames -- Enable TeaCache: `--enable_teacache` -- Use Ulysses parallelism with more GPUs - -**Slow Inference:** -- Enable TeaCache: `--enable_teacache` -- Use TRTLLM backend: `--attention_backend TRTLLM` -- Use multi-GPU: `--cfg_size 2` or `--ulysses_size 2` - -**Import Errors:** -- Run from repository root -- Install necessary dependencies, e.g., `pip install -r requirements-dev.txt` - -**Ulysses Errors:** -- `ulysses_size` must divide the model's head count (12 for WAN) -- Total GPUs = `cfg_size × ulysses_size` -- Sequence length must be divisible by `ulysses_size` - -## Output Formats - -- **FLUX**: `.png` (image) -- **WAN**: `.mp4` if FFmpeg is installed, otherwise `.avi` (video) -- **LTX2**: `.mp4` (video with audio) if FFmpeg is installed, otherwise `.avi` (video) - -## Serving - -See [`serve/README.md`](serve/README.md) for `trtllm-serve` examples including image generation (FLUX), video generation (WAN T2V/I2V), and API endpoint reference. diff --git a/examples/visual_gen/serve/README.md b/examples/visual_gen/serve/README.md index 1aaf04e5bb92..1b83cb9c431a 100644 --- a/examples/visual_gen/serve/README.md +++ b/examples/visual_gen/serve/README.md @@ -51,7 +51,6 @@ Before running these examples, ensure you have: ``` For LTX-2, you need to provide a proper text_encoder_path in `./configs/ltx2.yml`. 
- **TeaCache:** Example YAML files set `enable_teacache` and `teacache_thresh` only. Omit `coefficients` to use each pipeline’s **built-in** coefficient table (checkpoint path matching). Add `coefficients: [ ... ]` under `teacache` only when you need to override those defaults. ## Examples From e5a7e5a649793e9b1447f688bbc6c77abfd31422 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:36:21 -0700 Subject: [PATCH 08/11] remove newline Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- examples/visual_gen/serve/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/visual_gen/serve/README.md b/examples/visual_gen/serve/README.md index 1b83cb9c431a..a979767b2a6a 100644 --- a/examples/visual_gen/serve/README.md +++ b/examples/visual_gen/serve/README.md @@ -51,7 +51,6 @@ Before running these examples, ensure you have: ``` For LTX-2, you need to provide a proper text_encoder_path in `./configs/ltx2.yml`. 
- ## Examples Current supported & tested models: From d7a061cafc5a9fa0a4dc0fb4a4b93d4ba59eeda4 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Thu, 11 Jun 2026 10:26:24 -0700 Subject: [PATCH 09/11] remove deprecated is_wan22 references Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py | 1 - tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py | 1 - tensorrt_llm/_torch/visual_gen/pipeline.py | 4 +++- tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py | 2 +- tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index 852336ca97a0..e8b0bec18a29 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -326,7 +326,6 @@ def post_load_weights(self) -> None: if ( self.transformer is not None and self.transformer_2 is not None - and self.is_wan22 and self.pipeline_config.cache_backend == "teacache" ): self._apply_teacache_coefficients(WAN_TEACACHE_COEFFICIENTS) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index d84299acb4b7..c3549baa6f60 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -345,7 +345,6 @@ def post_load_weights(self) -> None: if ( self.transformer is not None and self.transformer_2 is not None - and self.is_wan22 and self.pipeline_config.cache_backend == "teacache" ): self._apply_teacache_coefficients(WAN_I2V_TEACACHE_COEFFICIENTS) diff --git a/tensorrt_llm/_torch/visual_gen/pipeline.py b/tensorrt_llm/_torch/visual_gen/pipeline.py index 
527f9fd2eaed..febed93ba647 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline.py @@ -443,7 +443,9 @@ def _apply_teacache_coefficients(self, coefficients: Optional[Dict] = None) -> N if not coefficients: return - checkpoint_path = getattr(self.pipeline_config.primary_pretrained_config, "_name_or_path", "") or "" + checkpoint_path = ( + getattr(self.pipeline_config.primary_pretrained_config, "_name_or_path", "") or "" + ) for model_size, coeff_data in coefficients.items(): # Match model size in path (case-insensitive, e.g., "1.3B", "14B", "dev") diff --git a/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py index 75cc07ade90b..700c896dec42 100644 --- a/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py @@ -252,5 +252,5 @@ def test_wan22_raises_if_teacache_enabled(self): model=WAN22_I2V_A14B_PATH, cache_config=TeaCacheConfig(), ) - with pytest.raises(ValueError, match="TeaCache is not supported for Wan 2\\.2"): + with pytest.raises(ValueError, match=r"Wan 2\.2 TeaCache requires explicit"): PipelineLoader(args).load(skip_warmup=True) diff --git a/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py index 8d3b5f14f178..4854928ee86c 100644 --- a/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py @@ -234,5 +234,5 @@ def test_wan22_raises_if_teacache_enabled(self): model=WAN22_A14B_PATH, cache_config=TeaCacheConfig(), ) - with pytest.raises(ValueError, match=r"TeaCache is not supported for Wan 2\.2"): + with pytest.raises(ValueError, match=r"Wan 2\.2 TeaCache requires explicit"): PipelineLoader(args).load(skip_warmup=True) From 52815c0915bb995d1698fb97eef774a2422d11c3 Mon Sep 17 00:00:00 2001 From: Olivia Stoner 
<245287810+o-stoner@users.noreply.github.com> Date: Mon, 15 Jun 2026 09:26:33 -0700 Subject: [PATCH 10/11] address code comments + small refactor Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../visual_gen/cache/teacache_accelerator.py | 67 +++--- .../visual_gen/models/flux/pipeline_flux.py | 3 +- .../visual_gen/models/flux/pipeline_flux2.py | 3 +- .../visual_gen/models/ltx2/pipeline_ltx2.py | 4 +- .../visual_gen/models/wan/pipeline_wan.py | 25 +-- .../visual_gen/models/wan/pipeline_wan_i2v.py | 25 +-- tensorrt_llm/_torch/visual_gen/pipeline.py | 69 +++--- tensorrt_llm/visual_gen/args.py | 4 + .../_torch/visual_gen/test_teacache.py | 89 ++++---- .../visual_gen/test_wan21_i2v_teacache.py | 2 +- .../visual_gen/test_wan21_t2v_teacache.py | 2 +- ...st_wan21_t2v_teacache_user_coefficients.py | 2 +- .../visual_gen/test_wan22_i2v_teacache.py | 200 ++++++++++++++++++ .../visual_gen/test_wan22_t2v_teacache.py | 184 ++++++++++++++++ 14 files changed, 510 insertions(+), 169 deletions(-) create mode 100644 tests/unittest/_torch/visual_gen/test_wan22_i2v_teacache.py create mode 100644 tests/unittest/_torch/visual_gen/test_wan22_t2v_teacache.py diff --git a/tensorrt_llm/_torch/visual_gen/cache/teacache_accelerator.py b/tensorrt_llm/_torch/visual_gen/cache/teacache_accelerator.py index 273247e3209d..5119d275c52a 100644 --- a/tensorrt_llm/_torch/visual_gen/cache/teacache_accelerator.py +++ b/tensorrt_llm/_torch/visual_gen/cache/teacache_accelerator.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Tuple import torch.nn as nn @@ -17,52 +17,55 @@ class TeaCacheAccelerator(CacheAccelerator): - """Whole-transformer skip using existing :class:`TeaCacheBackend`.""" + """Whole-transformer step skipping via :class:`TeaCacheBackend`.""" - def __init__(self, teacache_cfg: TeaCacheConfig): + def __init__(self, pipeline: Any, teacache_cfg: TeaCacheConfig): + self._pipeline = 
pipeline self._cfg = teacache_cfg - self._backend: Optional[TeaCacheBackend] = None - self._module: Optional[nn.Module] = None + self._backends: List[Tuple[nn.Module, TeaCacheBackend]] = [] def wrap(self, *args: Any, **kwargs: Any) -> None: - """Enable TeaCache on ``model``. - - Coefficient matching is done on the pipeline via - :meth:`BasePipeline._apply_teacache_coefficients` before this runs. - - Keyword args: ``model`` (required). - """ - model: Optional[nn.Module] = kwargs.get("model") - if model is None: + if self._backends: return - if self._backend is not None: + transformer = getattr(self._pipeline, "transformer", None) + if transformer is None: return - logger.info("TeaCache: Initializing...") - self._backend = TeaCacheBackend(self._cfg) - self._backend.enable(model) - self._module = model + transformer_2 = getattr(self._pipeline, "transformer_2", None) + if transformer_2 is not None and self._cfg.coefficients_2 is not None: + cfg_low = self._cfg.model_copy(update={"coefficients": self._cfg.coefficients_2}) + logger.info("TeaCache: Initializing (high-noise transformer)...") + backend_high = TeaCacheBackend(self._cfg) + backend_high.enable(transformer) + logger.info("TeaCache: Initializing (low-noise transformer_2)...") + backend_low = TeaCacheBackend(cfg_low) + backend_low.enable(transformer_2) + self._backends = [(transformer, backend_high), (transformer_2, backend_low)] + else: + logger.info("TeaCache: Initializing...") + backend = TeaCacheBackend(self._cfg) + backend.enable(transformer) + self._backends = [(transformer, backend)] def unwrap(self) -> None: - if self._backend is None or self._module is None: - return - self._backend.disable(self._module) - self._backend = None - self._module = None + for module, backend in self._backends: + backend.disable(module) + self._backends = [] def refresh(self, num_inference_steps: int) -> None: - if self._backend: - self._backend.refresh(num_inference_steps) + for _, backend in self._backends: + 
backend.refresh(num_inference_steps) def get_stats(self) -> Dict[str, Any]: - if not self._backend: + if not self._backends: return {} - return self._backend.get_stats() or {} + if len(self._backends) == 1: + return self._backends[0][1].get_stats() or {} + return {f"transformer_{i}": b.get_stats() for i, (_, b) in enumerate(self._backends)} def is_enabled(self) -> bool: - return bool(self._backend and self._backend.is_enabled()) + return bool(self._backends and any(b.is_enabled() for _, b in self._backends)) @property - def tea_cache_backend(self) -> Optional[TeaCacheBackend]: - """The wrapped TeaCache hook owner (for introspection or advanced use).""" - return self._backend + def backends(self) -> List[Tuple[nn.Module, TeaCacheBackend]]: + return self._backends diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py index cf2dd468a6df..0ad1566f0fd5 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py @@ -235,7 +235,8 @@ def post_load_weights(self) -> None: ) # TeaCache or Cache-DiT - self._setup_cache_acceleration(self.transformer, FLUX_TEACACHE_COEFFICIENTS) + self._apply_teacache_coefficients(FLUX_TEACACHE_COEFFICIENTS) + self._setup_cache_acceleration() @property def default_generation_params(self): diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py index a71ee912cef1..21105dc344fb 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py @@ -328,7 +328,8 @@ def post_load_weights(self) -> None: ) ) - self._setup_cache_acceleration(self.transformer, FLUX2_TEACACHE_COEFFICIENTS) + self._apply_teacache_coefficients(FLUX2_TEACACHE_COEFFICIENTS) + self._setup_cache_acceleration() @property def default_generation_params(self): diff --git 
a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py index c62e8b5141c3..d0330c79930e 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py @@ -1023,11 +1023,11 @@ def post_load_weights(self) -> None: "LTXModel", LTX2TeaCacheExtractor(self._compute_ltx2_timestep_embedding), ) - self._setup_cache_acceleration(self.transformer, coefficients=None) + self._setup_cache_acceleration() # Cache-DiT if self.transformer is not None and self.pipeline_config.cache_backend == "cache_dit": - self._setup_cache_acceleration(self.transformer, coefficients=None) + self._setup_cache_acceleration() # Compression ratios from native scale factors self.vae_spatial_compression_ratio = VIDEO_SCALE_FACTORS.width diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index e8b0bec18a29..0fb5658c308a 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -11,7 +11,6 @@ from tensorrt_llm._torch.visual_gen.cache.teacache import ( ExtractorConfig, - TeaCacheBackend, register_extractor_from_config, ) from tensorrt_llm._torch.visual_gen.models.wan.defaults import ( @@ -309,14 +308,11 @@ def post_load_weights(self) -> None: ) if not self.is_wan22_14b: - self._setup_cache_acceleration( - self.transformer, coefficients=WAN_TEACACHE_COEFFICIENTS - ) - self.transformer_cache_backend = self.cache_accelerator + self._apply_teacache_coefficients(WAN_TEACACHE_COEFFICIENTS) + self._setup_cache_acceleration() else: if self.pipeline_config.cache_backend == "cache_dit": - self._setup_cache_acceleration(self.transformer, coefficients=None) - self.transformer_cache_backend = self.cache_accelerator + self._setup_cache_acceleration() if self.transformer_2 is not None: if hasattr(self.transformer_2, 
"post_load_weights"): @@ -328,7 +324,6 @@ def post_load_weights(self) -> None: and self.transformer_2 is not None and self.pipeline_config.cache_backend == "teacache" ): - self._apply_teacache_coefficients(WAN_TEACACHE_COEFFICIENTS) tc = self.pipeline_config.teacache if tc.coefficients is None or tc.coefficients_2 is None: raise ValueError( @@ -336,19 +331,7 @@ def post_load_weights(self) -> None: "teacache.coefficients_2 (high-noise and low-noise stage polynomials). " "There is no built-in coefficient table for Wan 2.2." ) - cfg_high = tc.model_copy() - cfg_low = tc.model_copy(update={"coefficients": tc.coefficients_2}) - logger.info("TeaCache: Initializing (Wan 2.2 high-noise transformer)...") - self.cache_backend = TeaCacheBackend(cfg_high) - self.cache_backend.enable(self.transformer) - self.transformer_cache_backend = self.cache_backend - logger.info("TeaCache: Initializing (Wan 2.2 low-noise transformer_2)...") - self.transformer_2_cache_backend = TeaCacheBackend(cfg_low) - self.transformer_2_cache_backend.enable(self.transformer_2) - self._teacache_backends = [ - self.cache_backend, - self.transformer_2_cache_backend, - ] + self._setup_cache_acceleration() def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None: with torch.no_grad(): diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index c3549baa6f60..d00da652e5a9 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -13,7 +13,6 @@ from tensorrt_llm._torch.visual_gen.cache.teacache import ( ExtractorConfig, - TeaCacheBackend, register_extractor_from_config, ) from tensorrt_llm._torch.visual_gen.models.wan.defaults import ( @@ -329,14 +328,11 @@ def post_load_weights(self) -> None: ) if not self.is_wan22_14b: - self._setup_cache_acceleration( - self.transformer, 
coefficients=WAN_I2V_TEACACHE_COEFFICIENTS - ) - self.transformer_cache_backend = self.cache_accelerator + self._apply_teacache_coefficients(WAN_I2V_TEACACHE_COEFFICIENTS) + self._setup_cache_acceleration() else: if self.pipeline_config.cache_backend == "cache_dit": - self._setup_cache_acceleration(self.transformer, coefficients=None) - self.transformer_cache_backend = self.cache_accelerator + self._setup_cache_acceleration() if self.transformer_2 is not None: if hasattr(self.transformer_2, "post_load_weights"): @@ -347,7 +343,6 @@ def post_load_weights(self) -> None: and self.transformer_2 is not None and self.pipeline_config.cache_backend == "teacache" ): - self._apply_teacache_coefficients(WAN_I2V_TEACACHE_COEFFICIENTS) tc = self.pipeline_config.teacache if tc.coefficients is None or tc.coefficients_2 is None: raise ValueError( @@ -355,19 +350,7 @@ def post_load_weights(self) -> None: "teacache.coefficients_2 (high-noise and low-noise stage polynomials). " "There is no built-in coefficient table for Wan 2.2." 
) - cfg_high = tc.model_copy() - cfg_low = tc.model_copy(update={"coefficients": tc.coefficients_2}) - logger.info("TeaCache: Initializing (Wan 2.2 I2V high-noise transformer)...") - self.cache_backend = TeaCacheBackend(cfg_high) - self.cache_backend.enable(self.transformer) - self.transformer_cache_backend = self.cache_backend - logger.info("TeaCache: Initializing (Wan 2.2 I2V low-noise transformer_2)...") - self.transformer_2_cache_backend = TeaCacheBackend(cfg_low) - self.transformer_2_cache_backend.enable(self.transformer_2) - self._teacache_backends = [ - self.cache_backend, - self.transformer_2_cache_backend, - ] + self._setup_cache_acceleration() def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None: dummy_image = PIL.Image.new("RGB", (width, height)) diff --git a/tensorrt_llm/_torch/visual_gen/pipeline.py b/tensorrt_llm/_torch/visual_gen/pipeline.py index febed93ba647..576944fb03d1 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline.py @@ -101,16 +101,6 @@ def _parse_profile_range(): from .config import DiffusionPipelineConfig -def _teacache_coefficients_are_explicit_user_override(teacache_cfg: Any) -> bool: - """Return True if teacache.coefficients should skip built-in variant table matching. - - Only None means auto (use checkpoint path table). Any non-empty list, including - the identity polynomial [1.0, 0.0], is treated as an explicit user override. - """ - - return teacache_cfg.coefficients is not None - - class BasePipeline(nn.Module): """ Base class for diffusion pipelines. 
@@ -138,8 +128,6 @@ def __init__(self, pipeline_config: "DiffusionPipelineConfig"): # Unified cache acceleration (TeaCache, Cache-DiT); see _setup_cache_acceleration self.cache_accelerator: Optional["CacheAccelerator"] = None - # Wan 2.2 manual TeaCacheBackend pair; see WanPipeline.post_load_weights - self._teacache_backends: List[Any] = [] # Components self.transformer: Optional[nn.Module] = None @@ -431,7 +419,7 @@ def _apply_teacache_coefficients(self, coefficients: Optional[Dict] = None) -> N teacache_cfg = self.pipeline_config.teacache if teacache_cfg is None: return - if _teacache_coefficients_are_explicit_user_override(teacache_cfg): + if teacache_cfg.is_explicit_user_override(): logger.info( "TeaCache: Using user-configured coefficients " "(skipping built-in checkpoint variant matching)" @@ -441,6 +429,12 @@ def _apply_teacache_coefficients(self, coefficients: Optional[Dict] = None) -> N teacache_explicit = teacache_cfg.model_dump(exclude_unset=True) if not coefficients: + if teacache_cfg.coefficients is None: + raise ValueError( + "TeaCache is enabled but no polynomial coefficients were resolved. " + "Set teacache.coefficients in VisualGenArgs, or use a pipeline and " + "checkpoint whose path matches a built-in coefficient table." + ) return checkpoint_path = ( @@ -487,11 +481,7 @@ def _apply_teacache_coefficients(self, coefficients: Optional[Dict] = None) -> N f"or use a checkpoint path that contains one of the variant keys." 
) - def _setup_cache_acceleration( - self, - model: Optional[nn.Module] = None, - coefficients: Optional[Dict] = None, - ) -> None: + def _setup_cache_acceleration(self) -> None: """Enable TeaCache or Cache-DiT from model_config.cache_backend.""" if getattr(self, "cache_accelerator", None) is not None: @@ -510,29 +500,10 @@ def _setup_cache_acceleration( if not use_teacache: return - self._apply_teacache_coefficients(coefficients) - - teacache_cfg = cfg.teacache - if teacache_cfg is None or teacache_cfg.coefficients is None: - raise ValueError( - "TeaCache is enabled but no polynomial coefficients were resolved. " - "Set teacache.coefficients in VisualGenArgs, or use a pipeline and " - "checkpoint whose path matches a built-in coefficient table." - ) - - if model is None: - return - - acc = TeaCacheAccelerator(cfg.teacache) - acc.wrap(model=model) + acc = TeaCacheAccelerator(self, cfg.teacache) + acc.wrap() self.cache_accelerator = acc - def _refresh_teacache_backends(self, total_steps: int) -> None: - """Reset manual TeaCache backends (e.g. Wan 2.2 dual transformers).""" - for backend in self._teacache_backends: - if backend is not None and backend.is_enabled(): - backend.refresh(total_steps) - def setup_parallel_vae(self): """Enable parallel-VAE decode mode and wrap the VAE on participating ranks. 
@@ -1048,7 +1019,6 @@ def denoise( # Reset cache acceleration state for new generation (TeaCache / Cache-DiT) if getattr(self, "cache_accelerator", None) and self.cache_accelerator.is_enabled(): self.cache_accelerator.refresh(total_steps) - self._refresh_teacache_backends(total_steps) if self.rank == 0: if has_extra_streams: @@ -1175,11 +1145,20 @@ def denoise( if stats: if self.pipeline_config.cache_backend == "cache_dit": logger.info("Cache-DiT stats: %s", stats) - elif "hit_rate" in stats: - logger.info( - f"TeaCache: {stats['hit_rate']:.1%} hit rate " - f"({stats['cached']}/{stats['total']} steps)" - ) + elif self.pipeline_config.cache_backend == "teacache": + first_val = next(iter(stats.values()), None) + if isinstance(first_val, dict): + for key, s in stats.items(): + if "hit_rate" in s: + logger.info( + f"TeaCache {key}: {s['hit_rate']:.1%} hit rate " + f"({s['cached']}/{s['total']} steps)" + ) + elif "hit_rate" in stats: + logger.info( + f"TeaCache: {stats['hit_rate']:.1%} hit rate " + f"({stats['cached']}/{stats['total']} steps)" + ) else: logger.info("Cache acceleration stats: %s", stats) diff --git a/tensorrt_llm/visual_gen/args.py b/tensorrt_llm/visual_gen/args.py index 6520866e8202..d037df5bd4cc 100644 --- a/tensorrt_llm/visual_gen/args.py +++ b/tensorrt_llm/visual_gen/args.py @@ -333,6 +333,10 @@ def validate_teacache(self) -> "TeaCacheConfig": raise ValueError("TeaCache coefficients_2 list cannot be empty") return self + def is_explicit_user_override(self) -> bool: + """Return True if coefficients were set by the user and should skip built-in table matching.""" + return self.coefficients is not None + class CacheDiTConfig(BaseCacheConfig): """Configuration for Cache-DiT (DBCache, TaylorSeer, SCM). 
diff --git a/tests/unittest/_torch/visual_gen/test_teacache.py b/tests/unittest/_torch/visual_gen/test_teacache.py index ce3c720f419e..c53f77aaa135 100644 --- a/tests/unittest/_torch/visual_gen/test_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_teacache.py @@ -109,13 +109,15 @@ def test_setup_cache_acceleration_raises_when_no_table_and_no_user_coefficients( """Fails if TeaCache is on but the pipeline passes no coefficient table.""" pipeline = self._make_pipeline_mock("FLUX.1-dev") with pytest.raises(ValueError, match="no polynomial coefficients were resolved"): - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), None) + BasePipeline._apply_teacache_coefficients(pipeline, None) + BasePipeline._setup_cache_acceleration(pipeline) def test_setup_cache_acceleration_raises_when_empty_table(self): """Same as no table: nothing to resolve from.""" pipeline = self._make_pipeline_mock("FLUX.1-dev") with pytest.raises(ValueError, match="no polynomial coefficients were resolved"): - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), {}) + BasePipeline._apply_teacache_coefficients(pipeline, {}) + BasePipeline._setup_cache_acceleration(pipeline) def test_matching_variant_selects_ret_steps_mode(self): """Nested table: use_ret_steps=True selects ret_steps coefficients.""" @@ -124,7 +126,8 @@ def test_matching_variant_selects_ret_steps_mode(self): "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, } with patch.object(TeaCacheBackend, "enable"): - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), coefficients) + BasePipeline._apply_teacache_coefficients(pipeline, coefficients) + BasePipeline._setup_cache_acceleration(pipeline) assert pipeline.model_config.teacache.coefficients == [4.0, 5.0] @@ -133,7 +136,8 @@ def test_flat_list_table_entry(self): pipeline = self._make_pipeline_mock("FLUX.1-dev") coefficients = {"dev": [11.0, 22.0, 33.0]} with patch.object(TeaCacheBackend, "enable"): - 
BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), coefficients) + BasePipeline._apply_teacache_coefficients(pipeline, coefficients) + BasePipeline._setup_cache_acceleration(pipeline) assert pipeline.model_config.teacache.coefficients == [11.0, 22.0, 33.0] @@ -148,7 +152,8 @@ def test_default_thresh_from_table_when_user_did_not_set_teacache_thresh(self): }, } with patch.object(TeaCacheBackend, "enable"): - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), builtin) + BasePipeline._apply_teacache_coefficients(pipeline, builtin) + BasePipeline._setup_cache_acceleration(pipeline) assert pipeline.model_config.teacache.coefficients == [1.0, 2.0] assert pipeline.model_config.teacache.teacache_thresh == 0.42 @@ -161,11 +166,11 @@ def test_explicit_identity_coefficients_still_skip_table(self): coefficients=[1.0, 0.0], ) with patch.object(TeaCacheBackend, "enable"): - BasePipeline._setup_cache_acceleration( + BasePipeline._apply_teacache_coefficients( pipeline, - MagicMock(), {"dev": {"standard": [99.0, 99.0]}}, ) + BasePipeline._setup_cache_acceleration(pipeline) assert pipeline.model_config.teacache.coefficients == [1.0, 0.0] @@ -177,7 +182,8 @@ def test_matching_variant_selects_coefficients(self): "schnell": {"standard": [10.0, 20.0]}, } with patch.object(TeaCacheBackend, "enable"): - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), coefficients) + BasePipeline._apply_teacache_coefficients(pipeline, coefficients) + BasePipeline._setup_cache_acceleration(pipeline) assert pipeline.pipeline_config.teacache.coefficients == [1.0, 2.0, 3.0] @@ -189,14 +195,14 @@ def test_no_match_raises_valueerror(self): "schnell": {"standard": [10.0, 20.0]}, } with pytest.raises(ValueError, match="No coefficients found"): - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), coefficients) + BasePipeline._apply_teacache_coefficients(pipeline, coefficients) def test_disabled_teacache_is_noop(self): """No-op when cache is None (TeaCache not 
selected).""" pipeline = self._make_pipeline_mock("FLUX.1-dev") pipeline.pipeline_config = pipeline.pipeline_config.model_copy(update={"cache": None}) - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), {"dev": [1.0]}) + BasePipeline._setup_cache_acceleration(pipeline) assert pipeline.cache_accelerator is None def test_user_configured_coefficients_skip_variant_matching(self): @@ -210,7 +216,8 @@ def test_user_configured_coefficients_skip_variant_matching(self): "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, } with patch.object(TeaCacheBackend, "enable"): - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), builtin) + BasePipeline._apply_teacache_coefficients(pipeline, builtin) + BasePipeline._setup_cache_acceleration(pipeline) assert pipeline.model_config.teacache.coefficients == [0.25, 0.5, 0.75] @@ -225,7 +232,8 @@ def test_user_configured_coefficients_take_precedence_over_builtin_table(self): "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, } with patch.object(TeaCacheBackend, "enable"): - BasePipeline._setup_cache_acceleration(pipeline, MagicMock(), builtin) + BasePipeline._apply_teacache_coefficients(pipeline, builtin) + BasePipeline._setup_cache_acceleration(pipeline) assert pipeline.model_config.teacache.coefficients == [9.0, 8.0, 7.0] @@ -265,35 +273,32 @@ def test_empty_coefficients_2_rejected(self): TeaCacheConfig(coefficients_2=[]) -class TestRefreshTeacacheBackends: - """BasePipeline._refresh_teacache_backends walks _teacache_backends only.""" +class TestTeaCacheAcceleratorRefresh: + """TeaCacheAccelerator.refresh delegates to all registered backends.""" + + def _make_accelerator(self, backends): + from tensorrt_llm._torch.visual_gen.cache.teacache_accelerator import TeaCacheAccelerator + + acc = TeaCacheAccelerator.__new__(TeaCacheAccelerator) + acc._backends = backends + return acc def test_single_backend_refresh(self): - shared = MagicMock() - shared.is_enabled.return_value = True - pipe = 
SimpleNamespace(_teacache_backends=[shared]) - BasePipeline._refresh_teacache_backends(pipe, 50) - shared.refresh.assert_called_once_with(50) + backend = MagicMock() + acc = self._make_accelerator([(MagicMock(), backend)]) + acc.refresh(50) + backend.refresh.assert_called_once_with(50) def test_refreshes_two_distinct_backends(self): - b1 = MagicMock() - b1.is_enabled.return_value = True - b2 = MagicMock() - b2.is_enabled.return_value = True - pipe = SimpleNamespace(_teacache_backends=[b1, b2]) - BasePipeline._refresh_teacache_backends(pipe, 30) + b1, b2 = MagicMock(), MagicMock() + acc = self._make_accelerator([(MagicMock(), b1), (MagicMock(), b2)]) + acc.refresh(30) b1.refresh.assert_called_once_with(30) b2.refresh.assert_called_once_with(30) - def test_skips_disabled_backends(self): - active = MagicMock() - active.is_enabled.return_value = True - disabled = MagicMock() - disabled.is_enabled.return_value = False - pipe = SimpleNamespace(_teacache_backends=[active, disabled]) - BasePipeline._refresh_teacache_backends(pipe, 10) - active.refresh.assert_called_once_with(10) - disabled.refresh.assert_not_called() + def test_no_backends_is_noop(self): + acc = self._make_accelerator([]) + acc.refresh(10) # should not raise class TestFlux2TeacacheTable: @@ -314,9 +319,8 @@ def test_flux2_dev_variant_resolves_from_checkpoint_path(self): ) ) with patch.object(TeaCacheBackend, "enable"): - BasePipeline._setup_cache_acceleration( - pipeline, MagicMock(), FLUX2_TEACACHE_COEFFICIENTS - ) + BasePipeline._apply_teacache_coefficients(pipeline, FLUX2_TEACACHE_COEFFICIENTS) + BasePipeline._setup_cache_acceleration(pipeline) assert pipeline.model_config.teacache.coefficients is not None assert len(pipeline.model_config.teacache.coefficients) >= 2 @@ -441,15 +445,14 @@ def test_wan22_t2v_installs_two_teacache_backends_when_coefficients_provided(sel "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.register_extractor_from_config" ): with patch( - 
"tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.TeaCacheBackend" + "tensorrt_llm._torch.visual_gen.cache.teacache_accelerator.TeaCacheBackend" ) as TB: TB.side_effect = [backend_a, backend_b] WanPipeline.post_load_weights(pipe) assert TB.call_count == 2 assert mock_enable.call_count == 2 - assert pipe.transformer_cache_backend is pipe.cache_backend - assert pipe.transformer_2_cache_backend is not None - assert pipe.transformer_2_cache_backend is not pipe.cache_backend + assert pipe.cache_accelerator is not None + assert len(pipe.cache_accelerator.backends) == 2 def test_wan22_t2v_transformer_gets_coefficients_and_transformer_2_gets_coefficients_2(self): from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline @@ -473,7 +476,7 @@ def test_wan22_t2v_transformer_gets_coefficients_and_transformer_2_gets_coeffici "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.register_extractor_from_config" ): with patch( - "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan.TeaCacheBackend" + "tensorrt_llm._torch.visual_gen.cache.teacache_accelerator.TeaCacheBackend" ) as TB: TB.return_value = MagicMock() WanPipeline.post_load_weights(pipe) @@ -508,7 +511,7 @@ def test_wan22_i2v_transformer_gets_coefficients_and_transformer_2_gets_coeffici "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v.register_extractor_from_config" ): with patch( - "tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v.TeaCacheBackend" + "tensorrt_llm._torch.visual_gen.cache.teacache_accelerator.TeaCacheBackend" ) as TB: TB.return_value = MagicMock() WanImageToVideoPipeline.post_load_weights(pipe) diff --git a/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py index 700c896dec42..798755df8c86 100644 --- a/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_wan21_i2v_teacache.py @@ -171,7 +171,7 @@ def _assert_i2v_teacache( 
seed=INFER_SEED, ) - stats = pipeline.transformer_cache_backend.get_stats() + stats = pipeline.cache_accelerator.get_stats() print(f"\n ===== TeaCache — Wan 2.1 {model} single-stage {height}x{width} =====") print( diff --git a/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py index 4854928ee86c..76e19f8232e5 100644 --- a/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py +++ b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache.py @@ -155,7 +155,7 @@ def _assert_single_stage_teacache( seed=INFER_SEED, ) - stats = pipeline.transformer_cache_backend.get_stats() + stats = pipeline.cache_accelerator.get_stats() print(f"\n ===== TeaCache — Wan 2.1 {model} single-stage {height}x{width} =====") print( diff --git a/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py index 2dd1904dfac4..aa16c708b6a7 100644 --- a/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py +++ b/tests/unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py @@ -120,7 +120,7 @@ def _run_forward(coefficients: list, thresh: float, label: str) -> dict: num_inference_steps=NUM_STEPS, seed=SEED, ) - stats = pipeline.transformer_cache_backend.get_stats() + stats = pipeline.cache_accelerator.get_stats() finally: del pipeline gc.collect() diff --git a/tests/unittest/_torch/visual_gen/test_wan22_i2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan22_i2v_teacache.py new file mode 100644 index 000000000000..7cc439162d4a --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan22_i2v_teacache.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Wan 2.2 I2V TeaCache happy path. + +Wan 2.2 uses a dual-transformer architecture (high-noise + low-noise stages). +TeaCache requires explicit coefficients for both transformers via +TeaCacheConfig.coefficients (high-noise) and TeaCacheConfig.coefficients_2 (low-noise). + +Verifies: + - Both transformer backends are initialized + - Both backends produce cache hits after a forward pass + - Stats are returned for each transformer separately + +Run: + pytest tests/unittest/_torch/visual_gen/test_wan22_i2v_teacache.py -v -s + +Override checkpoint path: + DIFFUSION_MODEL_PATH_WAN22_I2V=/path/to/wan22 \\ + pytest tests/unittest/_torch/visual_gen/test_wan22_i2v_teacache.py -v -s +""" + +import os + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import gc +from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image + +from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader +from tensorrt_llm.visual_gen.args import TeaCacheConfig, VisualGenArgs + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +@pytest.fixture(autouse=True) +def _cleanup_gpu(): + gc.collect() + torch.cuda.empty_cache() + yield + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================ +# Path helpers +# 
============================================================================ + + +def _llm_models_root() -> Path: + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + else: + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return root + + +def _checkpoint(env_var: str, default_name: str) -> str: + return os.environ.get(env_var) or str(_llm_models_root() / default_name) + + +WAN22_I2V_A14B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN22_I2V", "Wan2.2-I2V-A14B-Diffusers") + +INFER_NUM_FRAMES = 33 # (33-1)/4+1 = 9 latent frames; smallest realistic shape +INFER_NUM_STEPS = 20 # Wan 2.2 has no reference hit rate; just enough to exercise both backends +INFER_SEED = 42 + +# Placeholder coefficients for Wan 2.2 dual-transformer TeaCache. +# Wan 2.2 has no built-in coefficient table; users must supply both sets. +# These are used to exercise the dual-backend path and verify cache hits fire. 
+WAN22_I2V_HIGH_NOISE_COEFFICIENTS = [ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02301147, +] +WAN22_I2V_LOW_NOISE_COEFFICIENTS = [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01, +] + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _make_test_image(height: int, width: int) -> Image.Image: + img = np.zeros((height, width, 3), dtype=np.uint8) + img[:, :, 0] = np.linspace(0, 255, height, dtype=np.uint8)[:, None] + return Image.fromarray(img, mode="RGB") + + +# ============================================================================ +# Fixture +# ============================================================================ + + +@pytest.fixture +def wan22_i2v_pipeline(): + if not os.path.exists(WAN22_I2V_A14B_PATH): + pytest.skip(f"Checkpoint not found: {WAN22_I2V_A14B_PATH}") + args = VisualGenArgs( + model=WAN22_I2V_A14B_PATH, + cache_config=TeaCacheConfig( + teacache_thresh=0.15, + coefficients=WAN22_I2V_HIGH_NOISE_COEFFICIENTS, + coefficients_2=WAN22_I2V_LOW_NOISE_COEFFICIENTS, + ), + ) + pipeline = PipelineLoader(args).load(skip_warmup=True) + yield pipeline + del pipeline + torch.cuda.empty_cache() + + +# ============================================================================ +# Assertion helper +# ============================================================================ + + +def _assert_dual_stage_teacache(pipeline, height: int, width: int) -> None: + test_image = _make_test_image(height, width) + + with torch.no_grad(): + pipeline.forward( + image=test_image, + prompt="a cat sitting on a windowsill", + negative_prompt="", + height=height, + width=width, + num_frames=INFER_NUM_FRAMES, + num_inference_steps=INFER_NUM_STEPS, + seed=INFER_SEED, + ) + + stats = pipeline.cache_accelerator.get_stats() + + print(f"\n ===== TeaCache — Wan 2.2 I2V-A14B 
dual-stage {height}x{width} =====") + for key, s in stats.items(): + print( + f" {key}: {s['cached_steps']}/{s['total_steps']} cached ({s['hit_rate']:.1%} hit rate)" + ) + print(" ================================================================") + + assert len(stats) == 2, f"Expected stats for 2 transformers, got: {list(stats.keys())}" + total_steps_sum = sum(s["total_steps"] for s in stats.values()) + assert total_steps_sum == INFER_NUM_STEPS, ( + f"Sum of steps across both transformers {total_steps_sum} != {INFER_NUM_STEPS}" + ) + for key, s in stats.items(): + assert s["total_steps"] > 0, f"{key}: transformer ran 0 steps" + assert s["compute_steps"] + s["cached_steps"] == s["total_steps"] + + +# ============================================================================ +# Tests +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.wan_i2v +@pytest.mark.teacache +class TestWan22I2V_TeaCache: + """Wan2.2-I2V-A14B 480x832 dual-stage TeaCache.""" + + def test_wan22_i2v_teacache_forward_runs(self, wan22_i2v_pipeline): + _assert_dual_stage_teacache(wan22_i2v_pipeline, height=480, width=832) + + def test_wan22_i2v_teacache_two_backends_initialized(self, wan22_i2v_pipeline): + assert len(wan22_i2v_pipeline.cache_accelerator.backends) == 2 diff --git a/tests/unittest/_torch/visual_gen/test_wan22_t2v_teacache.py b/tests/unittest/_torch/visual_gen/test_wan22_t2v_teacache.py new file mode 100644 index 000000000000..f295fab62e88 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan22_t2v_teacache.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Wan 2.2 T2V TeaCache happy path. + +Wan 2.2 uses a dual-transformer architecture (high-noise + low-noise stages). +TeaCache requires explicit coefficients for both transformers via +TeaCacheConfig.coefficients (high-noise) and TeaCacheConfig.coefficients_2 (low-noise). + +Verifies: + - Both transformer backends are initialized + - Both backends produce cache hits after a forward pass + - Stats are returned for each transformer separately + +Run: + pytest tests/unittest/_torch/visual_gen/test_wan22_t2v_teacache.py -v -s + +Override checkpoint path: + DIFFUSION_MODEL_PATH_WAN22_T2V=/path/to/wan22 \\ + pytest tests/unittest/_torch/visual_gen/test_wan22_t2v_teacache.py -v -s +""" + +import os + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import gc +from pathlib import Path + +import pytest +import torch + +from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader +from tensorrt_llm.visual_gen.args import TeaCacheConfig, VisualGenArgs + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +@pytest.fixture(autouse=True) +def _cleanup_gpu(): + gc.collect() + torch.cuda.empty_cache() + yield + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================ +# Path helpers +# ============================================================================ + + +def _llm_models_root() -> Path: + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + else: + root = 
Path("/home/scratch.trt_llm_data_ci/llm-models/") + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return root + + +def _checkpoint(env_var: str, default_name: str) -> str: + return os.environ.get(env_var) or str(_llm_models_root() / default_name) + + +WAN22_A14B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN22_T2V", "Wan2.2-T2V-A14B-Diffusers") + +INFER_NUM_FRAMES = 33 # (33-1)/4+1 = 9 latent frames; smallest realistic shape +INFER_NUM_STEPS = 20 # Wan 2.2 has no reference hit rate; just enough to exercise both backends +INFER_SEED = 42 + +# Placeholder coefficients for Wan 2.2 dual-transformer TeaCache. +# Wan 2.2 has no built-in coefficient table; users must supply both sets. +# These are used to exercise the dual-backend path and verify cache hits fire. +WAN22_T2V_HIGH_NOISE_COEFFICIENTS = [ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02301147, +] +WAN22_T2V_LOW_NOISE_COEFFICIENTS = [ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01, +] + + +# ============================================================================ +# Fixture +# ============================================================================ + + +@pytest.fixture +def wan22_t2v_pipeline(): + if not os.path.exists(WAN22_A14B_PATH): + pytest.skip(f"Checkpoint not found: {WAN22_A14B_PATH}") + args = VisualGenArgs( + model=WAN22_A14B_PATH, + cache_config=TeaCacheConfig( + teacache_thresh=0.15, + coefficients=WAN22_T2V_HIGH_NOISE_COEFFICIENTS, + coefficients_2=WAN22_T2V_LOW_NOISE_COEFFICIENTS, + ), + ) + pipeline = PipelineLoader(args).load(skip_warmup=True) + yield pipeline + del pipeline + torch.cuda.empty_cache() + + +# ============================================================================ +# Assertion helper +# 
============================================================================ + + +def _assert_dual_stage_teacache(pipeline, height: int, width: int) -> None: + with torch.no_grad(): + pipeline.forward( + prompt="a cat sitting on a windowsill", + negative_prompt="", + height=height, + width=width, + num_frames=INFER_NUM_FRAMES, + num_inference_steps=INFER_NUM_STEPS, + seed=INFER_SEED, + ) + + stats = pipeline.cache_accelerator.get_stats() + + print(f"\n ===== TeaCache — Wan 2.2 T2V-A14B dual-stage {height}x{width} =====") + for key, s in stats.items(): + print( + f" {key}: {s['cached_steps']}/{s['total_steps']} cached ({s['hit_rate']:.1%} hit rate)" + ) + print(" ================================================================") + + assert len(stats) == 2, f"Expected stats for 2 transformers, got: {list(stats.keys())}" + total_steps_sum = sum(s["total_steps"] for s in stats.values()) + assert total_steps_sum == INFER_NUM_STEPS, ( + f"Sum of steps across both transformers {total_steps_sum} != {INFER_NUM_STEPS}" + ) + for key, s in stats.items(): + assert s["total_steps"] > 0, f"{key}: transformer ran 0 steps" + assert s["compute_steps"] + s["cached_steps"] == s["total_steps"] + + +# ============================================================================ +# Tests +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.wan_t2v +@pytest.mark.teacache +class TestWan22T2V_TeaCache: + """Wan2.2-T2V-A14B 480x832 dual-stage TeaCache.""" + + def test_wan22_t2v_teacache_forward_runs(self, wan22_t2v_pipeline): + _assert_dual_stage_teacache(wan22_t2v_pipeline, height=480, width=832) + + def test_wan22_t2v_teacache_two_backends_initialized(self, wan22_t2v_pipeline): + assert len(wan22_t2v_pipeline.cache_accelerator.backends) == 2 From 287a52ea87e1c2faaa800aee767cc2a0a054f7dd Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Mon, 15 Jun 2026 09:51:10 -0700 
Subject: [PATCH 11/11] two-transformer teacache tests Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- tests/integration/test_lists/test-db/l0_b200.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index cbf124e5f6cf..307194a1de14 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -215,6 +215,8 @@ l0_b200: - unittest/_torch/visual_gen/test_wan21_i2v_teacache.py - unittest/_torch/visual_gen/test_wan21_t2v_teacache.py - unittest/_torch/visual_gen/test_wan21_t2v_teacache_user_coefficients.py + - unittest/_torch/visual_gen/test_wan22_i2v_teacache.py + - unittest/_torch/visual_gen/test_wan22_t2v_teacache.py - unittest/_torch/visual_gen/test_wan_transformer.py - unittest/_torch/visual_gen/test_cosmos3_transformer.py - unittest/_torch/visual_gen/test_cosmos3_pipeline.py