Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
8dbb673
add magcache
AlanPonnachan Nov 29, 2025
a8a57c6
formatting
AlanPonnachan Nov 29, 2025
cbf4b5e
Merge branch 'main' into feat/mag-cache
AlanPonnachan Dec 2, 2025
a6a9fb4
add magcache support with calibration mode
AlanPonnachan Dec 4, 2025
bb0a069
Merge branch 'main' into feat/mag-cache
AlanPonnachan Dec 4, 2025
37f8826
add imports
AlanPonnachan Dec 6, 2025
0a05bec
improvements
AlanPonnachan Dec 7, 2025
535a14e
Merge branch 'main' into feat/mag-cache
AlanPonnachan Dec 7, 2025
ebbebbe
Apply style fixes
github-actions[bot] Dec 8, 2025
00e9b96
fix kandinsky errors
AlanPonnachan Dec 10, 2025
3603e6c
Merge branch 'main' into feat/mag-cache
AlanPonnachan Dec 10, 2025
a282057
add tests and documentation
AlanPonnachan Dec 13, 2025
f672d37
Merge branch 'main' into feat/mag-cache
AlanPonnachan Dec 13, 2025
a3a9d15
Apply style fixes
github-actions[bot] Dec 17, 2025
163ac73
improvements
AlanPonnachan Dec 18, 2025
233e99f
Merge branch 'main' into feat/mag-cache
AlanPonnachan Dec 19, 2025
2cfefb5
Merge branch 'main' into feat/mag-cache
AlanPonnachan Dec 19, 2025
abec8c0
Merge branch 'main' into feat/mag-cache
sayakpaul Dec 22, 2025
cc9685f
Apply style fixes
github-actions[bot] Dec 22, 2025
acc5371
Merge branch 'main' into feat/mag-cache
sayakpaul Dec 31, 2025
ce8c8e6
make fix-copies.
sayakpaul Dec 31, 2025
d60f3cb
Merge pull request #1 from huggingface/AlanPonnachan-feat/mag-cache
AlanPonnachan Dec 31, 2025
6a69fcf
Merge branch 'main' into feat/mag-cache
AlanPonnachan Jan 11, 2026
a9d736e
Merge branch 'main' into feat/mag-cache
sayakpaul Jan 15, 2026
af04ddd
Merge branch 'main' into feat/mag-cache
sayakpaul Feb 3, 2026
6bd5726
minor fixes
AlanPonnachan Feb 3, 2026
825ecac
Merge branch 'main' into feat/mag-cache
AlanPonnachan Feb 3, 2026
864fa2a
Merge branch 'main' into feat/mag-cache
sayakpaul Feb 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions docs/source/en/optimization/cache.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,57 @@ config = TaylorSeerCacheConfig(
)
pipe.transformer.enable_cache(config)
```

## MagCache

[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It observes that the magnitude of updates (Output - Input) decays predictably over the diffusion process. By accumulating an "error budget" based on pre-computed magnitude ratios, it dynamically decides when to skip computation and reuse the previous residual.

MagCache relies on **Magnitude Ratios** (`mag_ratios`), which describe this decay curve. These ratios are specific to the model checkpoint and scheduler.

### Usage

To use MagCache, you typically follow a two-step process: **Calibration** and **Inference**.

1. **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console.
2. **Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration.

```python
import torch
from diffusers import FluxPipeline, MagCacheConfig

pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16
).to("cuda")

# 1. Calibration Step
# Run full inference to measure model behavior.
calib_config = MagCacheConfig(calibrate=True, num_inference_steps=4)
pipe.transformer.enable_cache(calib_config)

# Run a prompt to trigger calibration
pipe("A cat playing chess", num_inference_steps=4)
# Logs will print something like: "MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]"

# 2. Inference Step
# Apply the specific ratios obtained from calibration for optimized speed.
# Note: For Flux models, you can also import defaults:
# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS
mag_config = MagCacheConfig(
mag_ratios=[1.0, 1.37, 0.97, 0.87],
num_inference_steps=4
)

pipe.transformer.enable_cache(mag_config)

image = pipe("A cat playing chess", num_inference_steps=4).images[0]
```

> [!NOTE]
> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20). The implementation uses interpolation to map the curve to the current number of inference steps.

> [!TIP]
> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional).

> [!TIP]
> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification.
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,14 @@
"FirstBlockCacheConfig",
"HookRegistry",
"LayerSkipConfig",
"MagCacheConfig",
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"TaylorSeerCacheConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_mag_cache",
"apply_pyramid_attention_broadcast",
"apply_taylorseer_cache",
]
Expand Down Expand Up @@ -932,12 +934,14 @@
FirstBlockCacheConfig,
HookRegistry,
LayerSkipConfig,
MagCacheConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_mag_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .mag_cache import MagCacheConfig, apply_mag_cache
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
8 changes: 7 additions & 1 deletion src/diffusers/hooks/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
"blocks",
"transformer_blocks",
"single_transformer_blocks",
"layers",
Comment thread
sayakpaul marked this conversation as resolved.
"visual_transformer_blocks",
)
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

Expand Down
22 changes: 21 additions & 1 deletion src/diffusers/hooks/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class AttentionProcessorMetadata:
class TransformerBlockMetadata:
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
hidden_states_argument_name: str = "hidden_states"

_cls: Type = None
_cached_parameter_indices: Dict[str, int] = None
Expand Down Expand Up @@ -169,7 +170,7 @@ def _register_attention_processors_metadata():


def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
from ..models.attention import BasicTransformerBlock, JointTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_bria import BriaTransformerBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
Expand All @@ -184,6 +185,7 @@ def _register_transformer_blocks_metadata():
HunyuanImageSingleTransformerBlock,
HunyuanImageTransformerBlock,
)
from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
Expand Down Expand Up @@ -331,6 +333,24 @@ def _register_transformer_blocks_metadata():
),
)

TransformerBlockRegistry.register(
model_class=JointTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
Comment thread
sayakpaul marked this conversation as resolved.

# Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
TransformerBlockRegistry.register(
model_class=Kandinsky5TransformerDecoderBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
hidden_states_argument_name="visual_embed",
),
)


# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
Expand Down
Loading
Loading