Skip to content

Commit 5554be2

Browse files
committed
user-configured TeaCache coefficients
Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>
1 parent 90c1cb7 commit 5554be2

3 files changed

Lines changed: 256 additions & 50 deletions

File tree

tensorrt_llm/_torch/visual_gen/config.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,9 @@ class TeaCacheConfig(StrictBaseModel):
160160
teacache_thresh: Distance threshold for cache decisions (lower = more caching)
161161
use_ret_steps: Use aggressive warmup mode (5 steps) vs minimal (1 step)
162162
coefficients: Polynomial coefficients for rescaling embedding distances
163-
Applied as: rescaled_distance = poly(raw_distance)
163+
Applied as: rescaled_distance = poly(raw_distance).
164+
None means use the pipeline built-in coefficient table (checkpoint path
165+
matching). A non-None list overrides that table.
164166
ret_steps: Number of warmup steps (always compute, initialized at runtime)
165167
cutoff_steps: Step to stop caching (always compute after, initialized at runtime)
166168
num_steps: Total inference steps (set at runtime)
@@ -171,7 +173,7 @@ class TeaCacheConfig(StrictBaseModel):
171173
teacache_thresh: float = PydanticField(0.2, gt=0.0)
172174
use_ret_steps: bool = False
173175

174-
coefficients: List[float] = PydanticField(default_factory=lambda: [1.0, 0.0])
176+
coefficients: Optional[List[float]] = None
175177

176178
# Runtime state fields (initialized by TeaCacheBackend.refresh)
177179
ret_steps: Optional[int] = None
@@ -186,8 +188,8 @@ class TeaCacheConfig(StrictBaseModel):
186188
@model_validator(mode="after")
187189
def validate_teacache(self) -> "TeaCacheConfig":
188190
"""Validate TeaCache configuration."""
189-
# Validate coefficients
190-
if len(self.coefficients) == 0:
191+
# Validate coefficients (when provided)
192+
if self.coefficients is not None and len(self.coefficients) == 0:
191193
raise ValueError("TeaCache coefficients list cannot be empty")
192194

193195
# Validate ret_steps if set

tensorrt_llm/_torch/visual_gen/pipeline.py

Lines changed: 93 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@
2020
from .config import DiffusionModelConfig
2121

2222

23+
def _teacache_coefficients_are_explicit_user_override(teacache_cfg: Any) -> bool:
24+
"""Return True if teacache.coefficients should skip built-in variant table matching.
25+
26+
Only None means auto (use checkpoint path table). Any non-empty list, including
27+
the identity polynomial [1.0, 0.0], is treated as an explicit user override.
28+
"""
29+
30+
return teacache_cfg.coefficients is not None
31+
32+
2333
class BasePipeline(nn.Module):
2434
"""
2535
Base class for diffusion pipelines.
@@ -250,10 +260,85 @@ def post_load_weights(self) -> None:
250260
if self.transformer is not None and hasattr(self.transformer, "post_load_weights"):
251261
self.transformer.post_load_weights()
252262

263+
def _apply_teacache_coefficients(self, coefficients: Optional[Dict] = None) -> None:
    """Fill ``model_config.teacache.coefficients`` from the user config or a pipeline table.

    Resolution order:

    1. A non-None ``TeaCacheConfig.coefficients`` is treated as a user override:
       nothing is modified and the table is never consulted.
    2. Otherwise, when ``coefficients`` is a non-empty mapping, each key is looked
       up case-insensitively as a substring of ``pretrained_config._name_or_path``.
       The first usable match provides the coefficient list and, for nested
       entries, an optional ``default_thresh``.
    3. When ``coefficients`` is None or empty the config is left untouched;
       ``_setup_teacache`` is the caller that rejects a still-None value.

    Args:
        coefficients: Optional mapping from variant key to either a coefficient
            list or a nested dict keyed by ``"standard"`` / ``"ret_steps"`` (plus
            an optional ``"default_thresh"``), supplied by the pipeline subclass.

    Raises:
        ValueError: If a non-empty table is given but no entry yields coefficients
            for the checkpoint path.
    """
    cfg = self.model_config.teacache

    if _teacache_coefficients_are_explicit_user_override(cfg):
        logger.info(
            "TeaCache: Using user-configured coefficients "
            "(skipping built-in checkpoint variant matching)"
        )
        return

    # Snapshot of the fields the user set explicitly, taken before this method
    # assigns anything, so a table-provided threshold never overrides a user one.
    user_set_fields = cfg.model_dump(exclude_unset=True)

    if not coefficients:
        return

    checkpoint_path = getattr(self.model_config.pretrained_config, "_name_or_path", "") or ""
    # The path does not change between iterations, so lowercase it once.
    path_lower = checkpoint_path.lower()

    resolved = False
    for model_size, entry in coefficients.items():
        # Variant keys such as "1.3B", "14B" or "dev" are matched as
        # case-insensitive substrings of the checkpoint path.
        if model_size.lower() not in path_lower:
            continue

        if not isinstance(entry, dict):
            # Flat entry: a single coefficient list with no warmup-mode split.
            cfg.coefficients = entry
            logger.info(f"TeaCache: Using {model_size} coefficients")
            resolved = True
            break

        # Nested entry: pick the coefficient set for the configured warmup mode.
        mode = "ret_steps" if cfg.use_ret_steps else "standard"
        if mode not in entry:
            logger.warning(
                "TeaCache: matched variant %r but table has no %r entry "
                "(available keys: %s). Trying other variants.",
                model_size,
                mode,
                list(entry.keys()),
            )
            continue

        cfg.coefficients = entry[mode]
        logger.info(f"TeaCache: Using {model_size} coefficients ({mode} mode)")

        # The table's threshold applies only if the user did not set one.
        default_thresh = entry.get("default_thresh")
        if default_thresh is not None and "teacache_thresh" not in user_set_fields:
            cfg.teacache_thresh = default_thresh
            logger.info(f"TeaCache: Using {model_size} default threshold {default_thresh}")
        resolved = True
        break

    if not resolved:
        raise ValueError(
            f"TeaCache: No coefficients found for checkpoint '{checkpoint_path}'. "
            f"Available variants: {list(coefficients.keys())}. "
            f"Set teacache.coefficients explicitly in VisualGenArgs to use TeaCache anyway, "
            f"or use a checkpoint path that contains one of the variant keys."
        )
337+
253338
def _setup_teacache(self, model, coefficients: Optional[Dict] = None):
254339
"""Setup TeaCache optimization for the transformer model.
255340
256-
TeaCache caches transformer block outputs when timestep embeddings change slowly,
341+
TeaCache caches transformer block outputs,
257342
reducing computation during the denoising loop.
258343
259344
Args:
@@ -268,42 +353,14 @@ def _setup_teacache(self, model, coefficients: Optional[Dict] = None):
268353
if not teacache_cfg.enable_teacache:
269354
return
270355

271-
# Apply model-specific polynomial coefficients
272-
# Coefficients are used to rescale embedding distances for cache decisions
273-
if coefficients:
274-
checkpoint_path = (
275-
getattr(self.model_config.pretrained_config, "_name_or_path", "") or ""
356+
self._apply_teacache_coefficients(coefficients)
357+
358+
if teacache_cfg.coefficients is None:
359+
raise ValueError(
360+
"TeaCache is enabled but no polynomial coefficients were resolved. "
361+
"Set teacache.coefficients in VisualGenArgs, or use a pipeline and "
362+
"checkpoint whose path matches a built-in coefficient table."
276363
)
277-
for model_size, coeff_data in coefficients.items():
278-
# Match model size in path (case-insensitive, e.g., "1.3B", "14B", "dev")
279-
if model_size.lower() in checkpoint_path.lower():
280-
if isinstance(coeff_data, dict):
281-
# Select coefficient set based on warmup mode
282-
mode = "ret_steps" if teacache_cfg.use_ret_steps else "standard"
283-
if mode in coeff_data:
284-
teacache_cfg.coefficients = coeff_data[mode]
285-
logger.info(f"TeaCache: Using {model_size} coefficients ({mode} mode)")
286-
# Apply model-specific default threshold if user didn't explicitly set one
287-
default_thresh = coeff_data.get("default_thresh")
288-
if (
289-
default_thresh is not None
290-
and "teacache_thresh" not in teacache_cfg.model_fields_set
291-
):
292-
teacache_cfg.teacache_thresh = default_thresh
293-
logger.info(
294-
f"TeaCache: Using {model_size} default threshold {default_thresh}"
295-
)
296-
else:
297-
# Single coefficient list (no mode distinction)
298-
teacache_cfg.coefficients = coeff_data
299-
logger.info(f"TeaCache: Using {model_size} coefficients")
300-
break
301-
else:
302-
raise ValueError(
303-
f"TeaCache: No coefficients found for checkpoint '{checkpoint_path}'. "
304-
f"Available variants: {list(coefficients.keys())}. "
305-
f"TeaCache is not supported for this model variant."
306-
)
307364

308365
# Initialize and enable TeaCache backend
309366
logger.info("TeaCache: Initializing...")

tests/unittest/_torch/visual_gen/test_teacache.py

Lines changed: 157 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Unit tests for TeaCache (CPU-only, no model weights needed)."""
1616

1717
from types import SimpleNamespace
18+
from typing import Optional
1819
from unittest.mock import MagicMock, patch
1920

2021
import pytest
@@ -24,21 +25,116 @@
2425
from tensorrt_llm._torch.visual_gen.teacache import TeaCacheBackend
2526

2627

28+
class _PipelineTeacacheTestDouble:
    """Lightweight stand-in for a pipeline in TeaCache setup tests.

    Carries only the ``model_config`` and ``cache_backend`` attributes and forwards
    ``_apply_teacache_coefficients`` to the real ``BasePipeline`` implementation,
    so ``BasePipeline._setup_teacache`` can be called with this object as ``self``.
    """

    def __init__(self, model_config: DiffusionModelConfig):
        # No backend is attached until TeaCache setup creates one.
        self.cache_backend = None
        self.model_config = model_config

    def _apply_teacache_coefficients(self, coefficients: Optional[dict] = None):
        """Delegate to ``BasePipeline._apply_teacache_coefficients`` with this object as self."""
        resolve = BasePipeline._apply_teacache_coefficients
        return resolve(self, coefficients)
37+
38+
2739
class TestSetupTeacache:
2840
"""Tests for _setup_teacache coefficient matching and fail-early behavior."""
2941

3042
def _make_pipeline_mock(self, checkpoint_name, use_ret_steps=False):
31-
pipeline = MagicMock()
32-
pipeline.model_config = DiffusionModelConfig(
33-
pretrained_config=SimpleNamespace(_name_or_path=f"/path/to/{checkpoint_name}/snapshot"),
34-
teacache=TeaCacheConfig(
35-
enable_teacache=True,
36-
teacache_thresh=0.3,
37-
use_ret_steps=use_ret_steps,
38-
),
39-
skip_create_weights_in_init=True,
43+
return _PipelineTeacacheTestDouble(
44+
DiffusionModelConfig(
45+
pretrained_config=SimpleNamespace(
46+
_name_or_path=f"/path/to/{checkpoint_name}/snapshot"
47+
),
48+
teacache=TeaCacheConfig(
49+
enable_teacache=True,
50+
teacache_thresh=0.3,
51+
use_ret_steps=use_ret_steps,
52+
coefficients=None, # auto: resolve from pipeline variant table
53+
),
54+
skip_create_weights_in_init=True,
55+
)
4056
)
41-
return pipeline
57+
58+
def _make_pipeline_teacache_enable_only(self, checkpoint_name, use_ret_steps=False):
    """Build a test double whose TeaCache config leaves ``teacache_thresh`` unset.

    Only ``enable_teacache``, ``use_ret_steps`` and ``coefficients=None`` are passed,
    so ``teacache_thresh`` keeps its default and is omitted by ``exclude_unset``.
    """
    pretrained = SimpleNamespace(_name_or_path=f"/path/to/{checkpoint_name}/snapshot")
    teacache = TeaCacheConfig(
        enable_teacache=True,
        use_ret_steps=use_ret_steps,
        coefficients=None,
    )
    model_config = DiffusionModelConfig(
        pretrained_config=pretrained,
        teacache=teacache,
        skip_create_weights_in_init=True,
    )
    return _PipelineTeacacheTestDouble(model_config)
73+
74+
def test_setup_teacache_raises_when_no_table_and_no_user_coefficients(self):
    """Enabled TeaCache with no coefficient table and no user coefficients must raise."""
    double = self._make_pipeline_mock("FLUX.1-dev")
    expected_message = "no polynomial coefficients were resolved"
    with pytest.raises(ValueError, match=expected_message):
        BasePipeline._setup_teacache(double, MagicMock(), None)
79+
80+
def test_setup_teacache_raises_when_empty_table(self):
    """An empty table behaves like a missing one: setup raises."""
    double = self._make_pipeline_mock("FLUX.1-dev")
    expected_message = "no polynomial coefficients were resolved"
    with pytest.raises(ValueError, match=expected_message):
        BasePipeline._setup_teacache(double, MagicMock(), {})
85+
86+
def test_matching_variant_selects_ret_steps_mode(self):
    """With a nested table and use_ret_steps=True, the ret_steps list is chosen."""
    double = self._make_pipeline_mock("FLUX.1-dev", use_ret_steps=True)
    table = {"dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}}

    with patch.object(TeaCacheBackend, "enable"):
        BasePipeline._setup_teacache(double, MagicMock(), table)

    resolved = double.model_config.teacache.coefficients
    assert resolved == [4.0, 5.0]
96+
97+
def test_flat_list_table_entry(self):
    """A table value that is a plain list is used directly as the coefficients."""
    double = self._make_pipeline_mock("FLUX.1-dev")
    table = {"dev": [11.0, 22.0, 33.0]}

    with patch.object(TeaCacheBackend, "enable"):
        BasePipeline._setup_teacache(double, MagicMock(), table)

    resolved = double.model_config.teacache.coefficients
    assert resolved == [11.0, 22.0, 33.0]
105+
106+
def test_default_thresh_from_table_when_user_did_not_set_teacache_thresh(self):
    """The table's default_thresh is applied when teacache_thresh was never set explicitly."""
    double = self._make_pipeline_teacache_enable_only("FLUX.1-dev")
    dev_entry = {
        "standard": [1.0, 2.0],
        "ret_steps": [3.0, 4.0],
        "default_thresh": 0.42,
    }
    table = {"dev": dev_entry}

    with patch.object(TeaCacheBackend, "enable"):
        BasePipeline._setup_teacache(double, MagicMock(), table)

    assert double.model_config.teacache.coefficients == [1.0, 2.0]
    assert double.model_config.teacache.teacache_thresh == 0.42
121+
122+
def test_explicit_identity_coefficients_still_skip_table(self):
    """Explicit [1.0, 0.0] is a user override: no table lookup, no error on an unknown path."""
    double = self._make_pipeline_mock("FLUX.1-unknown-variant")
    identity_cfg = TeaCacheConfig(
        enable_teacache=True,
        teacache_thresh=0.3,
        coefficients=[1.0, 0.0],
    )
    double.model_config.teacache = identity_cfg

    with patch.object(TeaCacheBackend, "enable"):
        BasePipeline._setup_teacache(
            double,
            MagicMock(),
            {"dev": {"standard": [99.0, 99.0]}},
        )

    assert double.model_config.teacache.coefficients == [1.0, 0.0]
42138

43139
def test_matching_variant_selects_coefficients(self):
44140
"""Picks coefficients whose key appears in checkpoint path."""
@@ -69,3 +165,54 @@ def test_disabled_teacache_is_noop(self):
69165

70166
BasePipeline._setup_teacache(pipeline, MagicMock(), {"dev": [1.0]})
71167
assert pipeline.cache_backend is None
168+
169+
def test_user_configured_coefficients_skip_variant_matching(self):
    """User-supplied coefficients bypass the table, so an unmatched path does not raise."""
    double = self._make_pipeline_mock("FLUX.1-unknown-variant")
    user_cfg = TeaCacheConfig(
        enable_teacache=True,
        teacache_thresh=0.3,
        coefficients=[0.25, 0.5, 0.75],
    )
    double.model_config.teacache = user_cfg
    table = {"dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}}

    with patch.object(TeaCacheBackend, "enable"):
        BasePipeline._setup_teacache(double, MagicMock(), table)

    assert double.model_config.teacache.coefficients == [0.25, 0.5, 0.75]
184+
185+
def test_user_configured_coefficients_take_precedence_over_builtin_table(self):
    """User coefficients survive even when a table variant matches the checkpoint path."""
    double = self._make_pipeline_mock("FLUX.1-dev")
    user_cfg = TeaCacheConfig(
        enable_teacache=True,
        teacache_thresh=0.3,
        coefficients=[9.0, 8.0, 7.0],
    )
    double.model_config.teacache = user_cfg
    table = {"dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}}

    with patch.object(TeaCacheBackend, "enable"):
        BasePipeline._setup_teacache(double, MagicMock(), table)

    assert double.model_config.teacache.coefficients == [9.0, 8.0, 7.0]
200+
201+
def test_apply_teacache_coefficients_only(self):
    """Calling _apply_teacache_coefficients directly keeps user coefficients unchanged."""
    double = self._make_pipeline_mock("FLUX.1-unknown-variant")
    user_cfg = TeaCacheConfig(
        enable_teacache=True,
        coefficients=[0.1, 0.2],
    )
    double.model_config.teacache = user_cfg

    # Variant matching would raise for this unmatched path if it were attempted.
    unmatched_table = {"dev": {"standard": [99.0]}}
    BasePipeline._apply_teacache_coefficients(double, unmatched_table)

    assert double.model_config.teacache.coefficients == [0.1, 0.2]
211+
212+
213+
class TestTeaCacheConfigValidation:
    """Validation checks on TeaCacheConfig alone, with no pipeline involved."""

    def test_empty_coefficients_rejected(self):
        """Constructing the config with an empty coefficients list raises ValueError."""
        expected_message = "cannot be empty"
        with pytest.raises(ValueError, match=expected_message):
            TeaCacheConfig(enable_teacache=True, coefficients=[])

0 commit comments

Comments
 (0)