Skip to content

Commit 430c557

Browse files
AlanPonnacha, github-actions[bot], sayakpaul
authored
Add support for Magcache (#12744)
* add magcache * formatting * add magcache support with calibration mode * add imports * improvements * Apply style fixes * fix kandinsky errors * add tests and documentation * Apply style fixes * improvements * Apply style fixes * make fix-copies. * minor fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 1b8fc6c commit 430c557

File tree

11 files changed

+883
-2
lines changed

11 files changed

+883
-2
lines changed

docs/source/en/optimization/cache.md

Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,3 +111,57 @@ config = TaylorSeerCacheConfig(
111111
)
112112
pipe.transformer.enable_cache(config)
113113
```
114+
115+
## MagCache
116+
117+
[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It observes that the magnitude of updates (Output - Input) decays predictably over the diffusion process. By accumulating an "error budget" based on pre-computed magnitude ratios, it dynamically decides when to skip computation and reuse the previous residual.
118+
119+
MagCache relies on **Magnitude Ratios** (`mag_ratios`), which describe this decay curve. These ratios are specific to the model checkpoint and scheduler.
120+
121+
### Usage
122+
123+
To use MagCache, you typically follow a two-step process: **Calibration** and **Inference**.
124+
125+
1. **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console.
126+
2. **Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration.
127+
128+
```python
129+
import torch
130+
from diffusers import FluxPipeline, MagCacheConfig
131+
132+
pipe = FluxPipeline.from_pretrained(
133+
"black-forest-labs/FLUX.1-schnell",
134+
torch_dtype=torch.bfloat16
135+
).to("cuda")
136+
137+
# 1. Calibration Step
138+
# Run full inference to measure model behavior.
139+
calib_config = MagCacheConfig(calibrate=True, num_inference_steps=4)
140+
pipe.transformer.enable_cache(calib_config)
141+
142+
# Run a prompt to trigger calibration
143+
pipe("A cat playing chess", num_inference_steps=4)
144+
# Logs will print something like: "MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]"
145+
146+
# 2. Inference Step
147+
# Apply the specific ratios obtained from calibration for optimized speed.
148+
# Note: For Flux models, you can also import defaults:
149+
# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS
150+
mag_config = MagCacheConfig(
151+
mag_ratios=[1.0, 1.37, 0.97, 0.87],
152+
num_inference_steps=4
153+
)
154+
155+
pipe.transformer.enable_cache(mag_config)
156+
157+
image = pipe("A cat playing chess", num_inference_steps=4).images[0]
158+
```
159+
160+
> [!NOTE]
161+
> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20). The implementation uses interpolation to map the curve to the current number of inference steps.
162+
163+
> [!TIP]
164+
> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional).
165+
166+
> [!TIP]
167+
> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification.

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -168,12 +168,14 @@
168168
"FirstBlockCacheConfig",
169169
"HookRegistry",
170170
"LayerSkipConfig",
171+
"MagCacheConfig",
171172
"PyramidAttentionBroadcastConfig",
172173
"SmoothedEnergyGuidanceConfig",
173174
"TaylorSeerCacheConfig",
174175
"apply_faster_cache",
175176
"apply_first_block_cache",
176177
"apply_layer_skip",
178+
"apply_mag_cache",
177179
"apply_pyramid_attention_broadcast",
178180
"apply_taylorseer_cache",
179181
]
@@ -932,12 +934,14 @@
932934
FirstBlockCacheConfig,
933935
HookRegistry,
934936
LayerSkipConfig,
937+
MagCacheConfig,
935938
PyramidAttentionBroadcastConfig,
936939
SmoothedEnergyGuidanceConfig,
937940
TaylorSeerCacheConfig,
938941
apply_faster_cache,
939942
apply_first_block_cache,
940943
apply_layer_skip,
944+
apply_mag_cache,
941945
apply_pyramid_attention_broadcast,
942946
apply_taylorseer_cache,
943947
)

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
from .hooks import HookRegistry, ModelHook
2424
from .layer_skip import LayerSkipConfig, apply_layer_skip
2525
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
26+
from .mag_cache import MagCacheConfig, apply_mag_cache
2627
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
2728
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
2829
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache

src/diffusers/hooks/_common.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,13 @@
2323
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
2424
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
2525

26-
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
26+
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
27+
"blocks",
28+
"transformer_blocks",
29+
"single_transformer_blocks",
30+
"layers",
31+
"visual_transformer_blocks",
32+
)
2733
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
2834
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
2935

src/diffusers/hooks/_helpers.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ class AttentionProcessorMetadata:
2626
class TransformerBlockMetadata:
2727
return_hidden_states_index: int = None
2828
return_encoder_hidden_states_index: int = None
29+
hidden_states_argument_name: str = "hidden_states"
2930

3031
_cls: Type = None
3132
_cached_parameter_indices: Dict[str, int] = None
@@ -169,7 +170,7 @@ def _register_attention_processors_metadata():
169170

170171

171172
def _register_transformer_blocks_metadata():
172-
from ..models.attention import BasicTransformerBlock
173+
from ..models.attention import BasicTransformerBlock, JointTransformerBlock
173174
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
174175
from ..models.transformers.transformer_bria import BriaTransformerBlock
175176
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
@@ -184,6 +185,7 @@ def _register_transformer_blocks_metadata():
184185
HunyuanImageSingleTransformerBlock,
185186
HunyuanImageTransformerBlock,
186187
)
188+
from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
187189
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
188190
from ..models.transformers.transformer_mochi import MochiTransformerBlock
189191
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
@@ -331,6 +333,24 @@ def _register_transformer_blocks_metadata():
331333
),
332334
)
333335

336+
TransformerBlockRegistry.register(
337+
model_class=JointTransformerBlock,
338+
metadata=TransformerBlockMetadata(
339+
return_hidden_states_index=1,
340+
return_encoder_hidden_states_index=0,
341+
),
342+
)
343+
344+
# Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
345+
TransformerBlockRegistry.register(
346+
model_class=Kandinsky5TransformerDecoderBlock,
347+
metadata=TransformerBlockMetadata(
348+
return_hidden_states_index=0,
349+
return_encoder_hidden_states_index=None,
350+
hidden_states_argument_name="visual_embed",
351+
),
352+
)
353+
334354

335355
# fmt: off
336356
def _skip_attention___ret___hidden_states(self, *args, **kwargs):

0 commit comments

Comments (0)