|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -import unittest |
16 | | - |
17 | 15 | import numpy as np |
| 16 | +import pytest |
18 | 17 | import torch |
19 | 18 |
|
20 | 19 | from diffusers import MagCacheConfig, apply_mag_cache |
@@ -70,175 +69,178 @@ def forward(self, hidden_states, encoder_hidden_states=None): |
70 | 69 | return hidden_states, encoder_hidden_states |
71 | 70 |
|
72 | 71 |
|
73 | | -class MagCacheTests(unittest.TestCase): |
74 | | - def setUp(self): |
75 | | - # Register standard dummy block |
76 | | - TransformerBlockRegistry.register( |
77 | | - DummyBlock, |
78 | | - TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None), |
79 | | - ) |
80 | | - # Register tuple block (Flux style) |
81 | | - TransformerBlockRegistry.register( |
82 | | - TupleOutputBlock, |
83 | | - TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1), |
84 | | - ) |
85 | | - |
86 | | - def _set_context(self, model, context_name): |
87 | | - """Helper to set context on all hooks in the model.""" |
88 | | - for module in model.modules(): |
89 | | - if hasattr(module, "_diffusers_hook"): |
90 | | - module._diffusers_hook._set_context(context_name) |
91 | | - |
92 | | - def _get_calibration_data(self, model): |
93 | | - for module in model.modules(): |
94 | | - if hasattr(module, "_diffusers_hook"): |
95 | | - hook = module._diffusers_hook.get_hook("mag_cache_block_hook") |
96 | | - if hook: |
97 | | - return hook.state_manager.get_state().calibration_ratios |
98 | | - return [] |
99 | | - |
100 | | - def test_mag_cache_validation(self): |
101 | | - """Test that missing mag_ratios raises ValueError.""" |
102 | | - with self.assertRaises(ValueError): |
103 | | - MagCacheConfig(num_inference_steps=10, calibrate=False) |
104 | | - |
105 | | - def test_mag_cache_skipping_logic(self): |
106 | | - """ |
107 | | - Tests that MagCache correctly calculates residuals and skips blocks when conditions are met. |
108 | | - """ |
109 | | - model = DummyTransformer() |
110 | | - |
111 | | - # Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip |
112 | | - ratios = np.array([1.0, 1.0]) |
113 | | - |
114 | | - config = MagCacheConfig( |
115 | | - threshold=100.0, |
116 | | - num_inference_steps=2, |
117 | | - retention_ratio=0.0, # Enable immediate skipping |
118 | | - max_skip_steps=5, |
119 | | - mag_ratios=ratios, |
120 | | - ) |
121 | | - |
122 | | - apply_mag_cache(model, config) |
123 | | - self._set_context(model, "test_context") |
124 | | - |
125 | | - # Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each) |
126 | | - # HeadInput=10. Output=40. Residual=30. |
127 | | - input_t0 = torch.tensor([[[10.0]]]) |
128 | | - output_t0 = model(input_t0) |
129 | | - self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed") |
130 | | - |
131 | | - # Step 1: Input 11.0. |
132 | | - # If Skipped: Output = Input(11) + Residual(30) = 41.0 |
133 | | - # If Computed: Output = 11 * 4 = 44.0 |
134 | | - input_t1 = torch.tensor([[[11.0]]]) |
135 | | - output_t1 = model(input_t1) |
136 | | - |
137 | | - self.assertTrue( |
138 | | - torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}" |
139 | | - ) |
140 | | - |
141 | | - def test_mag_cache_retention(self): |
142 | | - """Test that retention_ratio prevents skipping even if error is low.""" |
143 | | - model = DummyTransformer() |
144 | | - # Ratios that imply 0 error, so it *would* skip if retention allowed it |
145 | | - ratios = np.array([1.0, 1.0]) |
146 | | - |
147 | | - config = MagCacheConfig( |
148 | | - threshold=100.0, |
149 | | - num_inference_steps=2, |
150 | | - retention_ratio=1.0, # Force retention for ALL steps |
151 | | - mag_ratios=ratios, |
152 | | - ) |
153 | | - |
154 | | - apply_mag_cache(model, config) |
155 | | - self._set_context(model, "test_context") |
156 | | - |
157 | | - # Step 0 |
158 | | - model(torch.tensor([[[10.0]]])) |
159 | | - |
160 | | - # Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention |
161 | | - input_t1 = torch.tensor([[[11.0]]]) |
162 | | - output_t1 = model(input_t1) |
163 | | - |
164 | | - self.assertTrue( |
165 | | - torch.allclose(output_t1, torch.tensor([[[44.0]]])), |
166 | | - f"Expected Compute (44.0) due to retention, got {output_t1.item()}", |
167 | | - ) |
168 | | - |
169 | | - def test_mag_cache_tuple_outputs(self): |
170 | | - """Test compatibility with models returning (hidden, encoder_hidden) like Flux.""" |
171 | | - model = TupleTransformer() |
172 | | - ratios = np.array([1.0, 1.0]) |
173 | | - |
174 | | - config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios) |
175 | | - |
176 | | - apply_mag_cache(model, config) |
177 | | - self._set_context(model, "test_context") |
178 | | - |
179 | | - # Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x) |
180 | | - # Residual = 10.0 |
181 | | - input_t0 = torch.tensor([[[10.0]]]) |
182 | | - enc_t0 = torch.tensor([[[1.0]]]) |
183 | | - out_0, _ = model(input_t0, encoder_hidden_states=enc_t0) |
184 | | - self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]]))) |
185 | | - |
186 | | - # Step 1: Skip. Input 11.0. |
187 | | - # Skipped Output = 11 + 10 = 21.0 |
188 | | - input_t1 = torch.tensor([[[11.0]]]) |
189 | | - out_1, _ = model(input_t1, encoder_hidden_states=enc_t0) |
190 | | - |
191 | | - self.assertTrue( |
192 | | - torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}" |
193 | | - ) |
194 | | - |
195 | | - def test_mag_cache_reset(self): |
196 | | - """Test that state resets correctly after num_inference_steps.""" |
197 | | - model = DummyTransformer() |
198 | | - config = MagCacheConfig( |
199 | | - threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0]) |
200 | | - ) |
201 | | - apply_mag_cache(model, config) |
202 | | - self._set_context(model, "test_context") |
203 | | - |
204 | | - input_t = torch.ones(1, 1, 1) |
205 | | - |
206 | | - model(input_t) # Step 0 |
207 | | - model(input_t) # Step 1 (Skipped) |
208 | | - |
209 | | - # Step 2 (Reset -> Step 0) -> Should Compute |
210 | | - # Input 2.0 -> Output 8.0 |
211 | | - input_t2 = torch.tensor([[[2.0]]]) |
212 | | - output_t2 = model(input_t2) |
213 | | - |
214 | | - self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly") |
215 | | - |
216 | | - def test_mag_cache_calibration(self): |
217 | | - """Test that calibration mode records ratios.""" |
218 | | - model = DummyTransformer() |
219 | | - config = MagCacheConfig(num_inference_steps=2, calibrate=True) |
220 | | - apply_mag_cache(model, config) |
221 | | - self._set_context(model, "test_context") |
222 | | - |
223 | | - # Step 0 |
224 | | - # HeadInput = 10. Output = 40. Residual = 30. |
225 | | - # Ratio 0 is placeholder 1.0 |
226 | | - model(torch.tensor([[[10.0]]])) |
227 | | - |
228 | | - # Check intermediate state |
229 | | - ratios = self._get_calibration_data(model) |
230 | | - self.assertEqual(len(ratios), 1) |
231 | | - self.assertEqual(ratios[0], 1.0) |
232 | | - |
233 | | - # Step 1 |
234 | | - # HeadInput = 10. Output = 40. Residual = 30. |
235 | | - # PrevResidual = 30. CurrResidual = 30. |
236 | | - # Ratio = 30/30 = 1.0 |
237 | | - model(torch.tensor([[[10.0]]])) |
238 | | - |
239 | | - # Verify it computes fully (no skip) |
240 | | - # If it skipped, output would be 41.0. It should be 40.0 |
241 | | - # Actually in test setup, input is same (10.0) so output 40.0. |
242 | | - # Let's ensure list is empty after reset (end of step 1) |
243 | | - ratios_after = self._get_calibration_data(model) |
244 | | - self.assertEqual(ratios_after, []) |
| 72 | +@pytest.fixture(autouse=True) |
| 73 | +def register_dummy_blocks(): |
| 74 | + # Register standard dummy block |
| 75 | + TransformerBlockRegistry.register( |
| 76 | + DummyBlock, |
| 77 | + TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None), |
| 78 | + ) |
| 79 | + # Register tuple block (Flux style) |
| 80 | + TransformerBlockRegistry.register( |
| 81 | + TupleOutputBlock, |
| 82 | + TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1), |
| 83 | + ) |
| 84 | + |
| 85 | + |
| 86 | +def _set_context(model, context_name): |
| 87 | + """Helper to set context on all hooks in the model.""" |
| 88 | + for module in model.modules(): |
| 89 | + if hasattr(module, "_diffusers_hook"): |
| 90 | + module._diffusers_hook._set_context(context_name) |
| 91 | + |
| 92 | + |
| 93 | +def _get_calibration_data(model): |
| 94 | + for module in model.modules(): |
| 95 | + if hasattr(module, "_diffusers_hook"): |
| 96 | + hook = module._diffusers_hook.get_hook("mag_cache_block_hook") |
| 97 | + if hook: |
| 98 | + return hook.state_manager.get_state().calibration_ratios |
| 99 | + return [] |
| 100 | + |
| 101 | + |
| 102 | +def test_mag_cache_validation(): |
| 103 | + """Test that missing mag_ratios raises ValueError.""" |
| 104 | + with pytest.raises(ValueError): |
| 105 | + MagCacheConfig(num_inference_steps=10, calibrate=False) |
| 106 | + |
| 107 | + |
| 108 | +def test_mag_cache_skipping_logic(): |
| 109 | + """ |
| 110 | + Tests that MagCache correctly calculates residuals and skips blocks when conditions are met. |
| 111 | + """ |
| 112 | + model = DummyTransformer() |
| 113 | + |
| 114 | + # Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip |
| 115 | + ratios = np.array([1.0, 1.0]) |
| 116 | + |
| 117 | + config = MagCacheConfig( |
| 118 | + threshold=100.0, |
| 119 | + num_inference_steps=2, |
| 120 | + retention_ratio=0.0, # Enable immediate skipping |
| 121 | + max_skip_steps=5, |
| 122 | + mag_ratios=ratios, |
| 123 | + ) |
| 124 | + |
| 125 | + apply_mag_cache(model, config) |
| 126 | + _set_context(model, "test_context") |
| 127 | + |
| 128 | + # Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each) |
| 129 | + # HeadInput=10. Output=40. Residual=30. |
| 130 | + input_t0 = torch.tensor([[[10.0]]]) |
| 131 | + output_t0 = model(input_t0) |
| 132 | + assert torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed" |
| 133 | + |
| 134 | + # Step 1: Input 11.0. |
| 135 | + # If Skipped: Output = Input(11) + Residual(30) = 41.0 |
| 136 | + # If Computed: Output = 11 * 4 = 44.0 |
| 137 | + input_t1 = torch.tensor([[[11.0]]]) |
| 138 | + output_t1 = model(input_t1) |
| 139 | + |
| 140 | + assert torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}" |
| 141 | + |
| 142 | + |
| 143 | +def test_mag_cache_retention(): |
| 144 | + """Test that retention_ratio prevents skipping even if error is low.""" |
| 145 | + model = DummyTransformer() |
| 146 | + # Ratios that imply 0 error, so it *would* skip if retention allowed it |
| 147 | + ratios = np.array([1.0, 1.0]) |
| 148 | + |
| 149 | + config = MagCacheConfig( |
| 150 | + threshold=100.0, |
| 151 | + num_inference_steps=2, |
| 152 | + retention_ratio=1.0, # Force retention for ALL steps |
| 153 | + mag_ratios=ratios, |
| 154 | + ) |
| 155 | + |
| 156 | + apply_mag_cache(model, config) |
| 157 | + _set_context(model, "test_context") |
| 158 | + |
| 159 | + # Step 0 |
| 160 | + model(torch.tensor([[[10.0]]])) |
| 161 | + |
| 162 | + # Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention |
| 163 | + input_t1 = torch.tensor([[[11.0]]]) |
| 164 | + output_t1 = model(input_t1) |
| 165 | + |
| 166 | + assert torch.allclose(output_t1, torch.tensor([[[44.0]]])), ( |
| 167 | + f"Expected Compute (44.0) due to retention, got {output_t1.item()}" |
| 168 | + ) |
| 169 | + |
| 170 | + |
| 171 | +def test_mag_cache_tuple_outputs(): |
| 172 | + """Test compatibility with models returning (hidden, encoder_hidden) like Flux.""" |
| 173 | + model = TupleTransformer() |
| 174 | + ratios = np.array([1.0, 1.0]) |
| 175 | + |
| 176 | + config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios) |
| 177 | + |
| 178 | + apply_mag_cache(model, config) |
| 179 | + _set_context(model, "test_context") |
| 180 | + |
| 181 | + # Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x) |
| 182 | + # Residual = 10.0 |
| 183 | + input_t0 = torch.tensor([[[10.0]]]) |
| 184 | + enc_t0 = torch.tensor([[[1.0]]]) |
| 185 | + out_0, _ = model(input_t0, encoder_hidden_states=enc_t0) |
| 186 | + assert torch.allclose(out_0, torch.tensor([[[20.0]]])) |
| 187 | + |
| 188 | + # Step 1: Skip. Input 11.0. |
| 189 | + # Skipped Output = 11 + 10 = 21.0 |
| 190 | + input_t1 = torch.tensor([[[11.0]]]) |
| 191 | + out_1, _ = model(input_t1, encoder_hidden_states=enc_t0) |
| 192 | + |
| 193 | + assert torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}" |
| 194 | + |
| 195 | + |
| 196 | +def test_mag_cache_reset(): |
| 197 | + """Test that state resets correctly after num_inference_steps.""" |
| 198 | + model = DummyTransformer() |
| 199 | + config = MagCacheConfig( |
| 200 | + threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0]) |
| 201 | + ) |
| 202 | + apply_mag_cache(model, config) |
| 203 | + _set_context(model, "test_context") |
| 204 | + |
| 205 | + input_t = torch.ones(1, 1, 1) |
| 206 | + |
| 207 | + model(input_t) # Step 0 |
| 208 | + model(input_t) # Step 1 (Skipped) |
| 209 | + |
| 210 | + # Step 2 (Reset -> Step 0) -> Should Compute |
| 211 | + # Input 2.0 -> Output 8.0 |
| 212 | + input_t2 = torch.tensor([[[2.0]]]) |
| 213 | + output_t2 = model(input_t2) |
| 214 | + |
| 215 | + assert torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly" |
| 216 | + |
| 217 | + |
| 218 | +def test_mag_cache_calibration(): |
| 219 | + """Test that calibration mode records ratios.""" |
| 220 | + model = DummyTransformer() |
| 221 | + config = MagCacheConfig(num_inference_steps=2, calibrate=True) |
| 222 | + apply_mag_cache(model, config) |
| 223 | + _set_context(model, "test_context") |
| 224 | + |
| 225 | + # Step 0 |
| 226 | + # HeadInput = 10. Output = 40. Residual = 30. |
| 227 | + # Ratio 0 is placeholder 1.0 |
| 228 | + model(torch.tensor([[[10.0]]])) |
| 229 | + |
| 230 | + # Check intermediate state |
| 231 | + ratios = _get_calibration_data(model) |
| 232 | + assert len(ratios) == 1 |
| 233 | + assert ratios[0] == 1.0 |
| 234 | + |
| 235 | + # Step 1 |
| 236 | + # HeadInput = 10. Output = 40. Residual = 30. |
| 237 | + # PrevResidual = 30. CurrResidual = 30. |
| 238 | + # Ratio = 30/30 = 1.0 |
| 239 | + model(torch.tensor([[[10.0]]])) |
| 240 | + |
| 241 | + # Verify it computes fully (no skip) |
| 242 | + # If it skipped, output would be 41.0. It should be 40.0 |
| 243 | + # Actually in test setup, input is same (10.0) so output 40.0. |
| 244 | + # Let's ensure list is empty after reset (end of step 1) |
| 245 | + ratios_after = _get_calibration_data(model) |
| 246 | + assert ratios_after == [] |
0 commit comments