Skip to content

Commit ed07118

Browse files
sayakpaul and dg845 authored
[tests] refactor caching tests. (#13235)
* refactor magcache tests.
* include taylorseer in the caching mixin.
* up
* add back magcache and migrate to pytest

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent 2fa9b93 commit ed07118

4 files changed

Lines changed: 392 additions & 175 deletions

File tree

tests/hooks/test_mag_cache.py

Lines changed: 176 additions & 174 deletions
Original file line number | Diff line number | Diff line change
@@ -12,9 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import unittest
16-
1715
import numpy as np
16+
import pytest
1817
import torch
1918

2019
from diffusers import MagCacheConfig, apply_mag_cache
@@ -70,175 +69,178 @@ def forward(self, hidden_states, encoder_hidden_states=None):
7069
return hidden_states, encoder_hidden_states
7170

7271

73-
class MagCacheTests(unittest.TestCase):
74-
def setUp(self):
75-
# Register standard dummy block
76-
TransformerBlockRegistry.register(
77-
DummyBlock,
78-
TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None),
79-
)
80-
# Register tuple block (Flux style)
81-
TransformerBlockRegistry.register(
82-
TupleOutputBlock,
83-
TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1),
84-
)
85-
86-
def _set_context(self, model, context_name):
87-
"""Helper to set context on all hooks in the model."""
88-
for module in model.modules():
89-
if hasattr(module, "_diffusers_hook"):
90-
module._diffusers_hook._set_context(context_name)
91-
92-
def _get_calibration_data(self, model):
93-
for module in model.modules():
94-
if hasattr(module, "_diffusers_hook"):
95-
hook = module._diffusers_hook.get_hook("mag_cache_block_hook")
96-
if hook:
97-
return hook.state_manager.get_state().calibration_ratios
98-
return []
99-
100-
def test_mag_cache_validation(self):
101-
"""Test that missing mag_ratios raises ValueError."""
102-
with self.assertRaises(ValueError):
103-
MagCacheConfig(num_inference_steps=10, calibrate=False)
104-
105-
def test_mag_cache_skipping_logic(self):
106-
"""
107-
Tests that MagCache correctly calculates residuals and skips blocks when conditions are met.
108-
"""
109-
model = DummyTransformer()
110-
111-
# Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip
112-
ratios = np.array([1.0, 1.0])
113-
114-
config = MagCacheConfig(
115-
threshold=100.0,
116-
num_inference_steps=2,
117-
retention_ratio=0.0, # Enable immediate skipping
118-
max_skip_steps=5,
119-
mag_ratios=ratios,
120-
)
121-
122-
apply_mag_cache(model, config)
123-
self._set_context(model, "test_context")
124-
125-
# Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each)
126-
# HeadInput=10. Output=40. Residual=30.
127-
input_t0 = torch.tensor([[[10.0]]])
128-
output_t0 = model(input_t0)
129-
self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed")
130-
131-
# Step 1: Input 11.0.
132-
# If Skipped: Output = Input(11) + Residual(30) = 41.0
133-
# If Computed: Output = 11 * 4 = 44.0
134-
input_t1 = torch.tensor([[[11.0]]])
135-
output_t1 = model(input_t1)
136-
137-
self.assertTrue(
138-
torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}"
139-
)
140-
141-
def test_mag_cache_retention(self):
142-
"""Test that retention_ratio prevents skipping even if error is low."""
143-
model = DummyTransformer()
144-
# Ratios that imply 0 error, so it *would* skip if retention allowed it
145-
ratios = np.array([1.0, 1.0])
146-
147-
config = MagCacheConfig(
148-
threshold=100.0,
149-
num_inference_steps=2,
150-
retention_ratio=1.0, # Force retention for ALL steps
151-
mag_ratios=ratios,
152-
)
153-
154-
apply_mag_cache(model, config)
155-
self._set_context(model, "test_context")
156-
157-
# Step 0
158-
model(torch.tensor([[[10.0]]]))
159-
160-
# Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention
161-
input_t1 = torch.tensor([[[11.0]]])
162-
output_t1 = model(input_t1)
163-
164-
self.assertTrue(
165-
torch.allclose(output_t1, torch.tensor([[[44.0]]])),
166-
f"Expected Compute (44.0) due to retention, got {output_t1.item()}",
167-
)
168-
169-
def test_mag_cache_tuple_outputs(self):
170-
"""Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
171-
model = TupleTransformer()
172-
ratios = np.array([1.0, 1.0])
173-
174-
config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios)
175-
176-
apply_mag_cache(model, config)
177-
self._set_context(model, "test_context")
178-
179-
# Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x)
180-
# Residual = 10.0
181-
input_t0 = torch.tensor([[[10.0]]])
182-
enc_t0 = torch.tensor([[[1.0]]])
183-
out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
184-
self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]])))
185-
186-
# Step 1: Skip. Input 11.0.
187-
# Skipped Output = 11 + 10 = 21.0
188-
input_t1 = torch.tensor([[[11.0]]])
189-
out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
190-
191-
self.assertTrue(
192-
torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
193-
)
194-
195-
def test_mag_cache_reset(self):
196-
"""Test that state resets correctly after num_inference_steps."""
197-
model = DummyTransformer()
198-
config = MagCacheConfig(
199-
threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
200-
)
201-
apply_mag_cache(model, config)
202-
self._set_context(model, "test_context")
203-
204-
input_t = torch.ones(1, 1, 1)
205-
206-
model(input_t) # Step 0
207-
model(input_t) # Step 1 (Skipped)
208-
209-
# Step 2 (Reset -> Step 0) -> Should Compute
210-
# Input 2.0 -> Output 8.0
211-
input_t2 = torch.tensor([[[2.0]]])
212-
output_t2 = model(input_t2)
213-
214-
self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly")
215-
216-
def test_mag_cache_calibration(self):
217-
"""Test that calibration mode records ratios."""
218-
model = DummyTransformer()
219-
config = MagCacheConfig(num_inference_steps=2, calibrate=True)
220-
apply_mag_cache(model, config)
221-
self._set_context(model, "test_context")
222-
223-
# Step 0
224-
# HeadInput = 10. Output = 40. Residual = 30.
225-
# Ratio 0 is placeholder 1.0
226-
model(torch.tensor([[[10.0]]]))
227-
228-
# Check intermediate state
229-
ratios = self._get_calibration_data(model)
230-
self.assertEqual(len(ratios), 1)
231-
self.assertEqual(ratios[0], 1.0)
232-
233-
# Step 1
234-
# HeadInput = 10. Output = 40. Residual = 30.
235-
# PrevResidual = 30. CurrResidual = 30.
236-
# Ratio = 30/30 = 1.0
237-
model(torch.tensor([[[10.0]]]))
238-
239-
# Verify it computes fully (no skip)
240-
# If it skipped, output would be 41.0. It should be 40.0
241-
# Actually in test setup, input is same (10.0) so output 40.0.
242-
# Let's ensure list is empty after reset (end of step 1)
243-
ratios_after = self._get_calibration_data(model)
244-
self.assertEqual(ratios_after, [])
72+
@pytest.fixture(autouse=True)
73+
def register_dummy_blocks():
74+
# Register standard dummy block
75+
TransformerBlockRegistry.register(
76+
DummyBlock,
77+
TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None),
78+
)
79+
# Register tuple block (Flux style)
80+
TransformerBlockRegistry.register(
81+
TupleOutputBlock,
82+
TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1),
83+
)
84+
85+
86+
def _set_context(model, context_name):
87+
"""Helper to set context on all hooks in the model."""
88+
for module in model.modules():
89+
if hasattr(module, "_diffusers_hook"):
90+
module._diffusers_hook._set_context(context_name)
91+
92+
93+
def _get_calibration_data(model):
94+
for module in model.modules():
95+
if hasattr(module, "_diffusers_hook"):
96+
hook = module._diffusers_hook.get_hook("mag_cache_block_hook")
97+
if hook:
98+
return hook.state_manager.get_state().calibration_ratios
99+
return []
100+
101+
102+
def test_mag_cache_validation():
103+
"""Test that missing mag_ratios raises ValueError."""
104+
with pytest.raises(ValueError):
105+
MagCacheConfig(num_inference_steps=10, calibrate=False)
106+
107+
108+
def test_mag_cache_skipping_logic():
109+
"""
110+
Tests that MagCache correctly calculates residuals and skips blocks when conditions are met.
111+
"""
112+
model = DummyTransformer()
113+
114+
# Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip
115+
ratios = np.array([1.0, 1.0])
116+
117+
config = MagCacheConfig(
118+
threshold=100.0,
119+
num_inference_steps=2,
120+
retention_ratio=0.0, # Enable immediate skipping
121+
max_skip_steps=5,
122+
mag_ratios=ratios,
123+
)
124+
125+
apply_mag_cache(model, config)
126+
_set_context(model, "test_context")
127+
128+
# Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each)
129+
# HeadInput=10. Output=40. Residual=30.
130+
input_t0 = torch.tensor([[[10.0]]])
131+
output_t0 = model(input_t0)
132+
assert torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed"
133+
134+
# Step 1: Input 11.0.
135+
# If Skipped: Output = Input(11) + Residual(30) = 41.0
136+
# If Computed: Output = 11 * 4 = 44.0
137+
input_t1 = torch.tensor([[[11.0]]])
138+
output_t1 = model(input_t1)
139+
140+
assert torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}"
141+
142+
143+
def test_mag_cache_retention():
144+
"""Test that retention_ratio prevents skipping even if error is low."""
145+
model = DummyTransformer()
146+
# Ratios that imply 0 error, so it *would* skip if retention allowed it
147+
ratios = np.array([1.0, 1.0])
148+
149+
config = MagCacheConfig(
150+
threshold=100.0,
151+
num_inference_steps=2,
152+
retention_ratio=1.0, # Force retention for ALL steps
153+
mag_ratios=ratios,
154+
)
155+
156+
apply_mag_cache(model, config)
157+
_set_context(model, "test_context")
158+
159+
# Step 0
160+
model(torch.tensor([[[10.0]]]))
161+
162+
# Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention
163+
input_t1 = torch.tensor([[[11.0]]])
164+
output_t1 = model(input_t1)
165+
166+
assert torch.allclose(output_t1, torch.tensor([[[44.0]]])), (
167+
f"Expected Compute (44.0) due to retention, got {output_t1.item()}"
168+
)
169+
170+
171+
def test_mag_cache_tuple_outputs():
172+
"""Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
173+
model = TupleTransformer()
174+
ratios = np.array([1.0, 1.0])
175+
176+
config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios)
177+
178+
apply_mag_cache(model, config)
179+
_set_context(model, "test_context")
180+
181+
# Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x)
182+
# Residual = 10.0
183+
input_t0 = torch.tensor([[[10.0]]])
184+
enc_t0 = torch.tensor([[[1.0]]])
185+
out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
186+
assert torch.allclose(out_0, torch.tensor([[[20.0]]]))
187+
188+
# Step 1: Skip. Input 11.0.
189+
# Skipped Output = 11 + 10 = 21.0
190+
input_t1 = torch.tensor([[[11.0]]])
191+
out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
192+
193+
assert torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
194+
195+
196+
def test_mag_cache_reset():
197+
"""Test that state resets correctly after num_inference_steps."""
198+
model = DummyTransformer()
199+
config = MagCacheConfig(
200+
threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
201+
)
202+
apply_mag_cache(model, config)
203+
_set_context(model, "test_context")
204+
205+
input_t = torch.ones(1, 1, 1)
206+
207+
model(input_t) # Step 0
208+
model(input_t) # Step 1 (Skipped)
209+
210+
# Step 2 (Reset -> Step 0) -> Should Compute
211+
# Input 2.0 -> Output 8.0
212+
input_t2 = torch.tensor([[[2.0]]])
213+
output_t2 = model(input_t2)
214+
215+
assert torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly"
216+
217+
218+
def test_mag_cache_calibration():
219+
"""Test that calibration mode records ratios."""
220+
model = DummyTransformer()
221+
config = MagCacheConfig(num_inference_steps=2, calibrate=True)
222+
apply_mag_cache(model, config)
223+
_set_context(model, "test_context")
224+
225+
# Step 0
226+
# HeadInput = 10. Output = 40. Residual = 30.
227+
# Ratio 0 is placeholder 1.0
228+
model(torch.tensor([[[10.0]]]))
229+
230+
# Check intermediate state
231+
ratios = _get_calibration_data(model)
232+
assert len(ratios) == 1
233+
assert ratios[0] == 1.0
234+
235+
# Step 1
236+
# HeadInput = 10. Output = 40. Residual = 30.
237+
# PrevResidual = 30. CurrResidual = 30.
238+
# Ratio = 30/30 = 1.0
239+
model(torch.tensor([[[10.0]]]))
240+
241+
# Verify it computes fully (no skip)
242+
# If it skipped, output would be 41.0. It should be 40.0
243+
# Actually in test setup, input is same (10.0) so output 40.0.
244+
# Let's ensure list is empty after reset (end of step 1)
245+
ratios_after = _get_calibration_data(model)
246+
assert ratios_after == []

tests/models/testing_utils/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,12 @@
55
FasterCacheTesterMixin,
66
FirstBlockCacheConfigMixin,
77
FirstBlockCacheTesterMixin,
8+
MagCacheConfigMixin,
9+
MagCacheTesterMixin,
810
PyramidAttentionBroadcastConfigMixin,
911
PyramidAttentionBroadcastTesterMixin,
12+
TaylorSeerCacheConfigMixin,
13+
TaylorSeerCacheTesterMixin,
1014
)
1115
from .common import BaseModelTesterConfig, ModelTesterMixin
1216
from .compile import TorchCompileTesterMixin
@@ -52,6 +56,8 @@
5256
"FasterCacheTesterMixin",
5357
"FirstBlockCacheConfigMixin",
5458
"FirstBlockCacheTesterMixin",
59+
"MagCacheConfigMixin",
60+
"MagCacheTesterMixin",
5561
"GGUFCompileTesterMixin",
5662
"GGUFConfigMixin",
5763
"GGUFTesterMixin",
@@ -67,6 +73,8 @@
6773
"ModelTesterMixin",
6874
"PyramidAttentionBroadcastConfigMixin",
6975
"PyramidAttentionBroadcastTesterMixin",
76+
"TaylorSeerCacheConfigMixin",
77+
"TaylorSeerCacheTesterMixin",
7078
"QuantizationCompileTesterMixin",
7179
"QuantizationTesterMixin",
7280
"QuantoCompileTesterMixin",

0 commit comments

Comments
 (0)