Commit 447e571

Authored by sippycoder, nmnWithNucleus, dg845, and github-actions[bot]
NucleusMoE-Image (#13317)
* adding NucleusMoE-Image model
* update system prompt
* Add text kv caching
* Class/function name changes
* add missing imports
* add RoPE credits
* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* update defaults
* Update src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* review updates
* fix the tests
* clean up
* update apply_text_kv_cache
* SwiGLUExperts addition
* fuse SwiGLUExperts up and gate proj
* Update src/diffusers/hooks/text_kv_cache.py
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Update src/diffusers/hooks/text_kv_cache.py
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Update src/diffusers/hooks/text_kv_cache.py
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Update src/diffusers/hooks/text_kv_cache.py
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* _SharedCacheKey -> TextKVCacheState
* Apply style fixes
* Run python utils/check_copies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite
* Apply style fixes
* run `make fix-copies`
* fix import
* refactor text KV cache to be managed by StateManager

---------

Co-authored-by: Murali Nandan Nagarapu <nmn@withnucleus.ai>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 5adc544 · commit 447e571

17 files changed (+2445, −3 lines)

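The new caching technique wired up by this commit plugs into the existing `CacheMixin.enable_cache` API (see the `src/diffusers/models/cache_utils.py` diff below). A minimal usage sketch (the checkpoint id is a placeholder, not a published repo):

import torch
from diffusers import NucleusMoEImagePipeline, TextKVCacheConfig

# Placeholder checkpoint id; substitute the real NucleusMoE-Image weights.
pipe = NucleusMoEImagePipeline.from_pretrained(
    "nucleus/nucleusmoe-image", torch_dtype=torch.bfloat16
).to("cuda")

# Text K/V projections are computed once per unique prompt and reused across
# all denoising steps; the cache is exact (lossless), so outputs are unchanged.
pipe.transformer.enable_cache(TextKVCacheConfig())

image = pipe("a corgi astronaut on the moon").images[0]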

src/diffusers/__init__.py

Lines changed: 10 additions & 2 deletions
@@ -169,22 +169,23 @@
             "PyramidAttentionBroadcastConfig",
             "SmoothedEnergyGuidanceConfig",
             "TaylorSeerCacheConfig",
+            "TextKVCacheConfig",
             "apply_faster_cache",
             "apply_first_block_cache",
             "apply_layer_skip",
             "apply_mag_cache",
             "apply_pyramid_attention_broadcast",
             "apply_taylorseer_cache",
+            "apply_text_kv_cache",
         ]
     )
     _import_structure["image_processor"] = [
-        "IPAdapterMaskProcessor",
         "InpaintProcessor",
+        "IPAdapterMaskProcessor",
         "PixArtImageProcessor",
         "VaeImageProcessor",
         "VaeImageProcessorLDM3D",
     ]
-    _import_structure["video_processor"] = ["VideoProcessor"]
     _import_structure["models"].extend(
         [
             "AllegroTransformer3DModel",
@@ -262,6 +263,7 @@
             "MotionAdapter",
             "MultiAdapter",
             "MultiControlNetModel",
+            "NucleusMoEImageTransformer2DModel",
             "OmniGenTransformer2DModel",
             "OvisImageTransformer2DModel",
             "ParallelConfig",
@@ -396,6 +398,7 @@
         ]
     )
     _import_structure["training_utils"] = ["EMAModel"]
+    _import_structure["video_processor"] = ["VideoProcessor"]

 try:
     if not (is_torch_available() and is_scipy_available()):
@@ -613,6 +616,7 @@
             "MarigoldNormalsPipeline",
             "MochiPipeline",
             "MusicLDMPipeline",
+            "NucleusMoEImagePipeline",
             "OmniGenPipeline",
             "OvisImagePipeline",
             "PaintByExamplePipeline",
@@ -967,12 +971,14 @@
         PyramidAttentionBroadcastConfig,
         SmoothedEnergyGuidanceConfig,
         TaylorSeerCacheConfig,
+        TextKVCacheConfig,
         apply_faster_cache,
         apply_first_block_cache,
         apply_layer_skip,
         apply_mag_cache,
         apply_pyramid_attention_broadcast,
         apply_taylorseer_cache,
+        apply_text_kv_cache,
     )
     from .image_processor import (
         InpaintProcessor,
@@ -1057,6 +1063,7 @@
             MotionAdapter,
             MultiAdapter,
             MultiControlNetModel,
+            NucleusMoEImageTransformer2DModel,
             OmniGenTransformer2DModel,
             OvisImageTransformer2DModel,
             ParallelConfig,
@@ -1384,6 +1391,7 @@
             MarigoldNormalsPipeline,
             MochiPipeline,
             MusicLDMPipeline,
+            NucleusMoEImagePipeline,
             OmniGenPipeline,
             OvisImagePipeline,
             PaintByExamplePipeline,

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -27,3 +27,4 @@
 from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
 from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
 from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
+from .text_kv_cache import TextKVCacheConfig, apply_text_kv_cache

src/diffusers/hooks/text_kv_cache.py (new file)

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+import torch
+
+from .hooks import BaseState, HookRegistry, ModelHook, StateManager
+
+
+_TEXT_KV_CACHE_TRANSFORMER_HOOK = "text_kv_cache_transformer"
+_TEXT_KV_CACHE_BLOCK_HOOK = "text_kv_cache_block"
+
+
+@dataclass
+class TextKVCacheConfig:
+    """Enable exact (lossless) text K/V caching for transformer models.
+
+    Pre-computes per-block text key and value projections once before the denoising loop and reuses them across all
+    steps. Positive and negative prompts are distinguished via a stable cache key captured by a transformer-level hook
+    before any intermediate tensor allocations.
+    """
+
+    pass
+
+
+class TextKVCacheState(BaseState):
+    """Shared state between the transformer-level and block-level hooks.
+
+    The transformer hook writes the stable ``encoder_hidden_states`` ``data_ptr()`` (captured *before* ``txt_norm``) so
+    that block hooks can use it as a reliable cache key across denoising steps.
+    """
+
+    def __init__(self):
+        self.key: int | None = None
+
+    def reset(self):
+        self.key = None
+
+
+class TextKVCacheBlockState(BaseState):
+    """Per-block state holding cached text key/value projections."""
+
+    def __init__(self):
+        self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}
+
+    def reset(self):
+        self.kv_cache.clear()
+
+
+class TextKVCacheTransformerHook(ModelHook):
+    """Captures ``encoder_hidden_states.data_ptr()`` before ``txt_norm``
+    and writes it to shared state for the block hooks to read."""
+
+    _is_stateful = True
+
+    def __init__(self, state_manager: StateManager):
+        super().__init__()
+        self.state_manager = state_manager
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        if self.state_manager._current_context is None:
+            self.state_manager.set_context("inference")
+
+        encoder_hidden_states = kwargs.get("encoder_hidden_states")
+        if encoder_hidden_states is not None:
+            state: TextKVCacheState = self.state_manager.get_state()
+            state.key = encoder_hidden_states.data_ptr()
+        return self.fn_ref.original_forward(*args, **kwargs)
+
+    def reset_state(self, module: torch.nn.Module):
+        self.state_manager.reset()
+        return module
+
+
+class TextKVCacheBlockHook(ModelHook):
+    """Caches ``(txt_key, txt_value)`` per block per unique prompt using
+    the stable cache key from the shared state."""
+
+    _is_stateful = True
+
+    def __init__(self, state_manager: StateManager, block_state_manager: StateManager):
+        super().__init__()
+        self.state_manager = state_manager
+        self.block_state_manager = block_state_manager
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus
+
+        if self.state_manager._current_context is None:
+            self.state_manager.set_context("inference")
+
+        if self.block_state_manager._current_context is None:
+            self.block_state_manager.set_context("inference")
+
+        if "encoder_hidden_states" in kwargs:
+            encoder_hidden_states = kwargs["encoder_hidden_states"]
+        else:
+            encoder_hidden_states = args[1]
+
+        if "image_rotary_emb" in kwargs:
+            image_rotary_emb = kwargs["image_rotary_emb"]
+        elif len(args) > 3:
+            image_rotary_emb = args[3]
+        else:
+            image_rotary_emb = None
+
+        state: TextKVCacheState = self.state_manager.get_state()
+        cache_key = state.key
+
+        block_state: TextKVCacheBlockState = self.block_state_manager.get_state()
+
+        if cache_key not in block_state.kv_cache:
+            context = module.encoder_proj(encoder_hidden_states)
+
+            attn = module.attn
+            head_dim = attn.inner_dim // attn.heads
+            num_kv_heads = attn.inner_kv_dim // head_dim
+
+            txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1))
+            txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1))
+
+            if attn.norm_added_k is not None:
+                txt_key = attn.norm_added_k(txt_key)
+
+            if image_rotary_emb is not None:
+                _, txt_freqs = image_rotary_emb
+                txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)
+
+            block_state.kv_cache[cache_key] = (txt_key, txt_value)
+
+        txt_key, txt_value = block_state.kv_cache[cache_key]
+
+        attn_kwargs = kwargs.get("attention_kwargs") or {}
+        attn_kwargs["cached_txt_key"] = txt_key
+        attn_kwargs["cached_txt_value"] = txt_value
+        kwargs["attention_kwargs"] = attn_kwargs
+
+        return self.fn_ref.original_forward(*args, **kwargs)
+
+    def reset_state(self, module: torch.nn.Module):
+        self.block_state_manager.reset()
+        return module
+
+
+def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None:
+    from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock
+
+    HookRegistry.check_if_exists_or_initialize(module)
+
+    state_manager = StateManager(TextKVCacheState)
+
+    transformer_hook = TextKVCacheTransformerHook(state_manager)
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    registry.register_hook(transformer_hook, _TEXT_KV_CACHE_TRANSFORMER_HOOK)
+
+    for _, submodule in module.named_modules():
+        if isinstance(submodule, NucleusMoEImageTransformerBlock):
+            block_state_manager = StateManager(TextKVCacheBlockState)
+            hook = TextKVCacheBlockHook(state_manager, block_state_manager)
+            block_registry = HookRegistry.check_if_exists_or_initialize(submodule)
+            block_registry.register_hook(hook, _TEXT_KV_CACHE_BLOCK_HOOK)
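The same hooks can also be attached directly at the model level via `apply_text_kv_cache`, without going through a pipeline. A sketch, assuming a hypothetical local checkpoint path:

from diffusers import NucleusMoEImageTransformer2DModel
from diffusers.hooks import TextKVCacheConfig, apply_text_kv_cache

# Hypothetical local path to a NucleusMoE-Image checkpoint.
transformer = NucleusMoEImageTransformer2DModel.from_pretrained(
    "./nucleusmoe-image-checkpoint", subfolder="transformer"
)

# Registers one transformer-level hook (captures the prompt tensor's data_ptr()
# as a stable cache key) plus one hook per NucleusMoEImageTransformerBlock
# (caches that block's txt_key/txt_value projections for each unique prompt).
apply_text_kv_cache(transformer, TextKVCacheConfig())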

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -116,6 +116,7 @@
     _import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
     _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
+    _import_structure["transformers.transformer_nucleusmoe_image"] = ["NucleusMoEImageTransformer2DModel"]
     _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
     _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"]
     _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
@@ -236,6 +237,7 @@
         Lumina2Transformer2DModel,
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
+        NucleusMoEImageTransformer2DModel,
         OmniGenTransformer2DModel,
         OvisImageTransformer2DModel,
         PixArtTransformer2DModel,

src/diffusers/models/cache_utils.py

Lines changed: 11 additions & 1 deletion
@@ -41,11 +41,12 @@ def enable_cache(self, config) -> None:
         Enable caching techniques on the model.

         Args:
-            config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig`):
+            config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig | TextKVCacheConfig`):
                 The configuration for applying the caching technique. Currently supported caching techniques are:
                 - [`~hooks.PyramidAttentionBroadcastConfig`]
                 - [`~hooks.FasterCacheConfig`]
                 - [`~hooks.FirstBlockCacheConfig`]
+                - [`~hooks.TextKVCacheConfig`]

         Example:

@@ -71,11 +72,13 @@ def enable_cache(self, config) -> None:
             MagCacheConfig,
             PyramidAttentionBroadcastConfig,
             TaylorSeerCacheConfig,
+            TextKVCacheConfig,
             apply_faster_cache,
             apply_first_block_cache,
             apply_mag_cache,
             apply_pyramid_attention_broadcast,
             apply_taylorseer_cache,
+            apply_text_kv_cache,
         )

         if self.is_cache_enabled:
@@ -89,6 +92,8 @@ def enable_cache(self, config) -> None:
             apply_first_block_cache(self, config)
         elif isinstance(config, MagCacheConfig):
             apply_mag_cache(self, config)
+        elif isinstance(config, TextKVCacheConfig):
+            apply_text_kv_cache(self, config)
         elif isinstance(config, PyramidAttentionBroadcastConfig):
             apply_pyramid_attention_broadcast(self, config)
         elif isinstance(config, TaylorSeerCacheConfig):
@@ -106,12 +111,14 @@ def disable_cache(self) -> None:
             MagCacheConfig,
             PyramidAttentionBroadcastConfig,
             TaylorSeerCacheConfig,
+            TextKVCacheConfig,
         )
         from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
         from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
         from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
         from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
         from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
+        from ..hooks.text_kv_cache import _TEXT_KV_CACHE_BLOCK_HOOK, _TEXT_KV_CACHE_TRANSFORMER_HOOK

         if self._cache_config is None:
             logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -129,6 +136,9 @@ def disable_cache(self) -> None:
             registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
         elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
             registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
+        elif isinstance(self._cache_config, TextKVCacheConfig):
+            registry.remove_hook(_TEXT_KV_CACHE_TRANSFORMER_HOOK, recurse=True)
+            registry.remove_hook(_TEXT_KV_CACHE_BLOCK_HOOK, recurse=True)
         elif isinstance(self._cache_config, TaylorSeerCacheConfig):
             registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
         else:
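A sketch of the round trip through the dispatch added above, reusing the `transformer` from the earlier model-level example:

from diffusers import TextKVCacheConfig

transformer.enable_cache(TextKVCacheConfig())  # routes to apply_text_kv_cache(self, config)
assert transformer.is_cache_enabled

transformer.disable_cache()  # removes the transformer- and block-level hooks by name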

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@
     from .transformer_ltx2 import LTX2VideoTransformer3DModel
     from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
+    from .transformer_nucleusmoe_image import NucleusMoEImageTransformer2DModel
     from .transformer_omnigen import OmniGenTransformer2DModel
     from .transformer_ovis_image import OvisImageTransformer2DModel
     from .transformer_prx import PRXTransformer2DModel
