|
15 | 15 | """Unit tests for TeaCache (CPU-only, no model weights needed).""" |
16 | 16 |
|
17 | 17 | from types import SimpleNamespace |
| 18 | +from typing import Optional |
18 | 19 | from unittest.mock import MagicMock, patch |
19 | 20 |
|
20 | 21 | import pytest |
|
24 | 25 | from tensorrt_llm._torch.visual_gen.teacache import TeaCacheBackend |
25 | 26 |
|
26 | 27 |
|
| 28 | +class _PipelineTeacacheTestDouble: |
| 29 | + """Minimal object for BasePipeline._setup_teacache / _apply_teacache_coefficients tests.""" |
| 30 | + |
| 31 | + def __init__(self, model_config: DiffusionModelConfig): |
| 32 | + self.model_config = model_config |
| 33 | + self.cache_backend = None |
| 34 | + |
| 35 | + def _apply_teacache_coefficients(self, coefficients: Optional[dict] = None): |
| 36 | + return BasePipeline._apply_teacache_coefficients(self, coefficients) |
| 37 | + |
| 38 | + |
27 | 39 | class TestSetupTeacache: |
28 | 40 | """Tests for _setup_teacache coefficient matching and fail-early behavior.""" |
29 | 41 |
|
30 | 42 | def _make_pipeline_mock(self, checkpoint_name, use_ret_steps=False): |
31 | | - pipeline = MagicMock() |
32 | | - pipeline.model_config = DiffusionModelConfig( |
33 | | - pretrained_config=SimpleNamespace(_name_or_path=f"/path/to/{checkpoint_name}/snapshot"), |
34 | | - teacache=TeaCacheConfig( |
35 | | - enable_teacache=True, |
36 | | - teacache_thresh=0.3, |
37 | | - use_ret_steps=use_ret_steps, |
38 | | - ), |
39 | | - skip_create_weights_in_init=True, |
| 43 | + return _PipelineTeacacheTestDouble( |
| 44 | + DiffusionModelConfig( |
| 45 | + pretrained_config=SimpleNamespace( |
| 46 | + _name_or_path=f"/path/to/{checkpoint_name}/snapshot" |
| 47 | + ), |
| 48 | + teacache=TeaCacheConfig( |
| 49 | + enable_teacache=True, |
| 50 | + teacache_thresh=0.3, |
| 51 | + use_ret_steps=use_ret_steps, |
| 52 | + coefficients=None, # auto: resolve from pipeline variant table |
| 53 | + ), |
| 54 | + skip_create_weights_in_init=True, |
| 55 | + ) |
40 | 56 | ) |
41 | | - return pipeline |
| 57 | + |
| 58 | + def _make_pipeline_teacache_enable_only(self, checkpoint_name, use_ret_steps=False): |
| 59 | + """enable_teacache only so teacache_thresh stays unset (exclude_unset omits it).""" |
| 60 | + return _PipelineTeacacheTestDouble( |
| 61 | + DiffusionModelConfig( |
| 62 | + pretrained_config=SimpleNamespace( |
| 63 | + _name_or_path=f"/path/to/{checkpoint_name}/snapshot" |
| 64 | + ), |
| 65 | + teacache=TeaCacheConfig( |
| 66 | + enable_teacache=True, |
| 67 | + use_ret_steps=use_ret_steps, |
| 68 | + coefficients=None, |
| 69 | + ), |
| 70 | + skip_create_weights_in_init=True, |
| 71 | + ) |
| 72 | + ) |
| 73 | + |
| 74 | + def test_setup_teacache_raises_when_no_table_and_no_user_coefficients(self): |
| 75 | + """Fails if TeaCache is on but the pipeline passes no coefficient table.""" |
| 76 | + pipeline = self._make_pipeline_mock("FLUX.1-dev") |
| 77 | + with pytest.raises(ValueError, match="no polynomial coefficients were resolved"): |
| 78 | + BasePipeline._setup_teacache(pipeline, MagicMock(), None) |
| 79 | + |
| 80 | + def test_setup_teacache_raises_when_empty_table(self): |
| 81 | + """Same as no table: nothing to resolve from.""" |
| 82 | + pipeline = self._make_pipeline_mock("FLUX.1-dev") |
| 83 | + with pytest.raises(ValueError, match="no polynomial coefficients were resolved"): |
| 84 | + BasePipeline._setup_teacache(pipeline, MagicMock(), {}) |
| 85 | + |
| 86 | + def test_matching_variant_selects_ret_steps_mode(self): |
| 87 | + """Nested table: use_ret_steps=True selects ret_steps coefficients.""" |
| 88 | + pipeline = self._make_pipeline_mock("FLUX.1-dev", use_ret_steps=True) |
| 89 | + coefficients = { |
| 90 | + "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, |
| 91 | + } |
| 92 | + with patch.object(TeaCacheBackend, "enable"): |
| 93 | + BasePipeline._setup_teacache(pipeline, MagicMock(), coefficients) |
| 94 | + |
| 95 | + assert pipeline.model_config.teacache.coefficients == [4.0, 5.0] |
| 96 | + |
| 97 | + def test_flat_list_table_entry(self): |
| 98 | + """Table value may be a plain list (no standard/ret_steps nesting).""" |
| 99 | + pipeline = self._make_pipeline_mock("FLUX.1-dev") |
| 100 | + coefficients = {"dev": [11.0, 22.0, 33.0]} |
| 101 | + with patch.object(TeaCacheBackend, "enable"): |
| 102 | + BasePipeline._setup_teacache(pipeline, MagicMock(), coefficients) |
| 103 | + |
| 104 | + assert pipeline.model_config.teacache.coefficients == [11.0, 22.0, 33.0] |
| 105 | + |
| 106 | + def test_default_thresh_from_table_when_user_did_not_set_teacache_thresh(self): |
| 107 | + """default_thresh applies when teacache_thresh was not explicitly set (exclude_unset).""" |
| 108 | + pipeline = self._make_pipeline_teacache_enable_only("FLUX.1-dev") |
| 109 | + builtin = { |
| 110 | + "dev": { |
| 111 | + "standard": [1.0, 2.0], |
| 112 | + "ret_steps": [3.0, 4.0], |
| 113 | + "default_thresh": 0.42, |
| 114 | + }, |
| 115 | + } |
| 116 | + with patch.object(TeaCacheBackend, "enable"): |
| 117 | + BasePipeline._setup_teacache(pipeline, MagicMock(), builtin) |
| 118 | + |
| 119 | + assert pipeline.model_config.teacache.coefficients == [1.0, 2.0] |
| 120 | + assert pipeline.model_config.teacache.teacache_thresh == 0.42 |
| 121 | + |
| 122 | + def test_explicit_identity_coefficients_still_skip_table(self): |
| 123 | + """[1.0, 0.0] is a user override: no variant lookup, no ValueError on unknown path.""" |
| 124 | + pipeline = self._make_pipeline_mock("FLUX.1-unknown-variant") |
| 125 | + pipeline.model_config.teacache = TeaCacheConfig( |
| 126 | + enable_teacache=True, |
| 127 | + teacache_thresh=0.3, |
| 128 | + coefficients=[1.0, 0.0], |
| 129 | + ) |
| 130 | + with patch.object(TeaCacheBackend, "enable"): |
| 131 | + BasePipeline._setup_teacache( |
| 132 | + pipeline, |
| 133 | + MagicMock(), |
| 134 | + {"dev": {"standard": [99.0, 99.0]}}, |
| 135 | + ) |
| 136 | + |
| 137 | + assert pipeline.model_config.teacache.coefficients == [1.0, 0.0] |
42 | 138 |
|
43 | 139 | def test_matching_variant_selects_coefficients(self): |
44 | 140 | """Picks coefficients whose key appears in checkpoint path.""" |
@@ -69,3 +165,54 @@ def test_disabled_teacache_is_noop(self): |
69 | 165 |
|
70 | 166 | BasePipeline._setup_teacache(pipeline, MagicMock(), {"dev": [1.0]}) |
71 | 167 | assert pipeline.cache_backend is None |
| 168 | + |
| 169 | + def test_user_configured_coefficients_skip_variant_matching(self): |
| 170 | + """Explicit TeaCacheConfig.coefficients skips dict lookup (no ValueError).""" |
| 171 | + pipeline = self._make_pipeline_mock("FLUX.1-unknown-variant") |
| 172 | + pipeline.model_config.teacache = TeaCacheConfig( |
| 173 | + enable_teacache=True, |
| 174 | + teacache_thresh=0.3, |
| 175 | + coefficients=[0.25, 0.5, 0.75], |
| 176 | + ) |
| 177 | + builtin = { |
| 178 | + "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, |
| 179 | + } |
| 180 | + with patch.object(TeaCacheBackend, "enable"): |
| 181 | + BasePipeline._setup_teacache(pipeline, MagicMock(), builtin) |
| 182 | + |
| 183 | + assert pipeline.model_config.teacache.coefficients == [0.25, 0.5, 0.75] |
| 184 | + |
| 185 | + def test_user_configured_coefficients_take_precedence_over_builtin_table(self): |
| 186 | + """User coefficients are not overwritten when a built-in variant would also match.""" |
| 187 | + pipeline = self._make_pipeline_mock("FLUX.1-dev") |
| 188 | + pipeline.model_config.teacache = TeaCacheConfig( |
| 189 | + enable_teacache=True, |
| 190 | + teacache_thresh=0.3, |
| 191 | + coefficients=[9.0, 8.0, 7.0], |
| 192 | + ) |
| 193 | + builtin = { |
| 194 | + "dev": {"standard": [1.0, 2.0, 3.0], "ret_steps": [4.0, 5.0]}, |
| 195 | + } |
| 196 | + with patch.object(TeaCacheBackend, "enable"): |
| 197 | + BasePipeline._setup_teacache(pipeline, MagicMock(), builtin) |
| 198 | + |
| 199 | + assert pipeline.model_config.teacache.coefficients == [9.0, 8.0, 7.0] |
| 200 | + |
| 201 | + def test_apply_teacache_coefficients_only(self): |
| 202 | + """_apply_teacache_coefficients updates config without enabling backend.""" |
| 203 | + pipeline = self._make_pipeline_mock("FLUX.1-unknown-variant") |
| 204 | + pipeline.model_config.teacache = TeaCacheConfig( |
| 205 | + enable_teacache=True, |
| 206 | + coefficients=[0.1, 0.2], |
| 207 | + ) |
| 208 | + # Would raise if variant matching ran |
| 209 | + BasePipeline._apply_teacache_coefficients(pipeline, {"dev": {"standard": [99.0]}}) |
| 210 | + assert pipeline.model_config.teacache.coefficients == [0.1, 0.2] |
| 211 | + |
| 212 | + |
| 213 | +class TestTeaCacheConfigValidation: |
| 214 | + """TeaCacheConfig validation (no pipeline).""" |
| 215 | + |
| 216 | + def test_empty_coefficients_rejected(self): |
| 217 | + with pytest.raises(ValueError, match="cannot be empty"): |
| 218 | + TeaCacheConfig(enable_teacache=True, coefficients=[]) |
0 commit comments