Commit 7c6255e

Add WanKVCache for autoregressive Wan video generation
WanKVCache is a per-block self-attention KV cache that lets a Wan transformer generate video chunk by chunk while reusing the K/V tensors computed for prior chunks instead of re-running the full attention over the whole prefix on every step.

API:

- ``WanKVCache(num_blocks, window_size=-1)`` — one cache per transformer instance. ``window_size=-1`` keeps the full prefix; a finite window evicts the oldest tokens once the cap is reached.
- ``cache.enable_append_mode()`` / ``cache.enable_overwrite_mode()`` — pick the write semantics for the next forward pass. Append grows the cache (or rolls when full); overwrite replaces the newest chunk in place — used for additional denoising steps that redo the most recent chunk.
- ``cache.update(block_idx, key, value)`` — called from ``WanAttnProcessor`` during self-attention to merge the current chunk into the per-block cache and return the K/V to attend over.
- ``cache.reset()`` — clear all blocks between videos.

Wan plumbing:

- ``WanTransformer3DModel.forward`` accepts ``frame_offset: int = 0`` and forwards ``kv_cache`` (extracted from ``attention_kwargs``) plus ``block_idx`` to each transformer block.
- ``WanRotaryPosEmbed.forward`` takes ``frame_offset`` so RoPE can address positions in the original (uncached) sequence even when the latent input is just one chunk.
- ``WanAttnProcessor.__call__`` receives ``kv_cache`` / ``block_idx``; on self-attention it calls ``cache.update(...)`` and uses the returned K/V for SDPA. Cross-attention is unaffected.

Caller usage::

    cache = WanKVCache(num_blocks=len(transformer.blocks))
    for chunk_idx, latent_chunk in enumerate(chunks):
        cache.enable_append_mode()
        for step_idx, t in enumerate(denoising_steps):
            if step_idx > 0:
                cache.enable_overwrite_mode()
            transformer(
                hidden_states=latent_chunk,
                timestep=t,
                encoder_hidden_states=prompt_embeds,
                frame_offset=chunk_idx * patch_frames_per_chunk,
                attention_kwargs={"kv_cache": cache},
            )

Tests cover unbounded append, windowed append (with eviction across one and multiple chunks), in-place overwrite of the newest chunk, the read-from-prior-context contract, reset, and frame_offset's effect on RoPE.
1 parent c8eba43 commit 7c6255e
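A minimal sketch of the window semantics described in the commit message, using only the public API added by this commit; the batch/head/token sizes below are hypothetical and chosen small for illustration. The only layout assumption is the `(batch, seq_len, heads, head_dim)` K/V shape documented in the diff:

```python
import torch

from diffusers import WanKVCache

batch, heads, head_dim = 1, 2, 8
chunk_len = 4  # hypothetical tokens contributed per chunk


def fake_kv():
    # Stand-in for one chunk's post-norm, post-RoPE K/V.
    k = torch.randn(batch, chunk_len, heads, head_dim)
    v = torch.randn(batch, chunk_len, heads, head_dim)
    return k, v


cache = WanKVCache(num_blocks=1, window_size=8)

cache.enable_append_mode()
k, _ = cache.update(0, *fake_kv())  # cache holds 4 tokens
assert k.shape[1] == 4
k, _ = cache.update(0, *fake_kv())  # cache full at window_size=8
assert k.shape[1] == 8
k, _ = cache.update(0, *fake_kv())  # still 8: the oldest 4 tokens were evicted
assert k.shape[1] == 8

# Extra denoising steps on the same chunk replace its newest 4 tokens
# in place instead of growing the cache further.
cache.enable_overwrite_mode()
k, _ = cache.update(0, *fake_kv())
assert k.shape[1] == 8

cache.reset()  # between videos
```

With `window_size=8` and 4-token chunks, the cache grows 4 to 8 and then rolls, which is the eviction behavior the commit's tests exercise.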

6 files changed

Lines changed: 396 additions & 15 deletions

docs/source/en/api/models/wan_transformer_3d.md

Lines changed: 40 additions & 0 deletions
@@ -25,6 +25,46 @@ transformer = WanTransformer3DModel.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diff
 
 [[autodoc]] WanTransformer3DModel
 
+## Rolling KV cache
+
+For autoregressive video generation that produces one chunk at a time, [`WanTransformer3DModel.forward`] accepts a `WanKVCache` instance via `attention_kwargs={"kv_cache": cache}`. The cache holds post-norm, post-RoPE self-attention K/V tensors from prior chunks so subsequent chunks attend over the full prefix without recomputing it. The chunk's RoPE positions are picked via the `frame_offset` argument on `forward`.
+
+The cache exposes two write modes that the caller toggles between denoising steps:
+
+- `enable_append_mode()` — the next forward pass appends the chunk's K/V to the cache; once the cache reaches `window_size`, the oldest tokens are evicted from the front. Use this for the first denoising step of every new chunk.
+- `enable_overwrite_mode()` — the next forward pass replaces the newest `chunk_size` tokens in place. Use this for subsequent denoising steps within the same chunk so re-running the chunk doesn't grow the cache.
+
+```python
+from diffusers import WanKVCache, WanTransformer3DModel
+
+transformer = WanTransformer3DModel.from_pretrained(...)
+cache = WanKVCache(num_blocks=len(transformer.blocks))
+
+for chunk_idx, latent_chunk in enumerate(chunks):
+    for step_idx, t in enumerate(denoising_steps):
+        if step_idx == 0:
+            cache.enable_append_mode()
+        else:
+            cache.enable_overwrite_mode()
+        transformer(
+            hidden_states=latent_chunk,
+            timestep=t,
+            encoder_hidden_states=prompt_embeds,
+            frame_offset=chunk_idx * patch_frames_per_chunk,
+            attention_kwargs={"kv_cache": cache},
+        )
+
+cache.reset()  # between videos
+```
+
+## WanKVCache
+
+[[autodoc]] WanKVCache
+
+## WanKVBlockCache
+
+[[autodoc]] WanKVBlockCache
+
 ## Transformer2DModelOutput
 
 [[autodoc]] models.modeling_outputs.Transformer2DModelOutput

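One point worth making concrete from the docs above: `frame_offset` counts post-patch frames, because `WanRotaryPosEmbed.forward` slices `freqs[0][frame_offset : frame_offset + num_frames // p_t]` (see the transformer diff below). A small sketch of the arithmetic, with hypothetical chunk sizes; `p_t` is the temporal patch size, `transformer.config.patch_size[0]`:

```python
# Hypothetical sizes for illustration only.
latent_frames_per_chunk = 4
p_t = 1
patch_frames_per_chunk = latent_frames_per_chunk // p_t

offsets = [chunk_idx * patch_frames_per_chunk for chunk_idx in range(3)]
assert offsets == [0, 4, 8]
# Chunk 0 occupies temporal RoPE rows 0..3, chunk 1 rows 4..7, chunk 2 rows 8..11,
# so each chunk is addressed at its position in the original, uncached sequence.
```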
src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -301,6 +301,8 @@
         "UVit2DModel",
         "VQModel",
         "WanAnimateTransformer3DModel",
+        "WanKVBlockCache",
+        "WanKVCache",
         "WanTransformer3DModel",
         "WanVACETransformer3DModel",
         "ZImageControlNetModel",
@@ -1117,6 +1119,8 @@
         UVit2DModel,
         VQModel,
         WanAnimateTransformer3DModel,
+        WanKVBlockCache,
+        WanKVCache,
         WanTransformer3DModel,
         WanVACETransformer3DModel,
         ZImageControlNetModel,

src/diffusers/models/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -129,7 +129,11 @@
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
-    _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
+    _import_structure["transformers.transformer_wan"] = [
+        "WanKVBlockCache",
+        "WanKVCache",
+        "WanTransformer3DModel",
+    ]
     _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
     _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
     _import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"]
@@ -261,6 +265,8 @@
         Transformer2DModel,
         TransformerTemporalModel,
         WanAnimateTransformer3DModel,
+        WanKVBlockCache,
+        WanKVCache,
         WanTransformer3DModel,
         WanVACETransformer3DModel,
         ZImageTransformer2DModel,

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
     from .transformer_temporal import TransformerTemporalModel
-    from .transformer_wan import WanTransformer3DModel
+    from .transformer_wan import WanKVBlockCache, WanKVCache, WanTransformer3DModel
     from .transformer_wan_animate import WanAnimateTransformer3DModel
     from .transformer_wan_vace import WanVACETransformer3DModel
     from .transformer_z_image import ZImageTransformer2DModel

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 155 additions & 12 deletions
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import math
+from dataclasses import dataclass
 from typing import Any
 
 import torch
@@ -36,7 +39,7 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+def _get_qkv_projections(attn: WanAttention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
     # encoder_hidden_states is only passed for cross-attention
     if encoder_hidden_states is None:
         encoder_hidden_states = hidden_states
@@ -56,7 +59,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
     return query, key, value
 
 
-def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+def _get_added_kv_projections(attn: WanAttention, encoder_hidden_states_img: torch.Tensor):
     if attn.fused_projections:
         key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
     else:
@@ -65,6 +68,115 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
     return key_img, value_img
 
 
+@dataclass
+class WanKVBlockCache:
+    """Per-block rolling KV cache state for autoregressive WAN inference.
+
+    ``cached_key`` and ``cached_value`` hold the post-norm, post-RoPE K/V from prior chunks
+    with shape ``(batch_size, cached_seq_len, num_heads, head_dim)``.
+    """
+
+    cached_key: torch.Tensor | None = None
+    cached_value: torch.Tensor | None = None
+
+    def reset(self) -> None:
+        self.__init__()
+
+
+class WanKVCache:
+    """Rolling KV cache for autoregressive WAN video generation.
+
+    Holds a per-block ``WanKVBlockCache`` for every transformer block, plus shared
+    write-control state. Pass an instance via ``attention_kwargs`` on each transformer forward
+    call. ``WanAttnProcessor`` calls :py:meth:`update` to merge the current chunk's K/V into
+    the cache and get back the (possibly trimmed) attention K/V.
+
+    TODO: cross-attention K/V projections are currently recomputed on every forward pass even
+    though the text embeddings are constant across chunks. A future change can add cross-attn
+    caching alongside the existing self-attn cache.
+
+    Args:
+        num_blocks (`int`): Number of transformer blocks (``len(transformer.blocks)``).
+        window_size (`int`, defaults to ``-1``): Maximum cached tokens per block. ``-1`` keeps
+            the full prefix.
+
+    Example:
+
+    ```python
+    >>> cache = WanKVCache(num_blocks=len(transformer.blocks))
+    >>> transformer(..., attention_kwargs={"kv_cache": cache})
+    ```
+    """
+
+    def __init__(self, num_blocks: int, window_size: int = -1):
+        self.block_caches: list[WanKVBlockCache] = [WanKVBlockCache() for _ in range(num_blocks)]
+        self.window_size: int = window_size
+        self.overwrite_newest: bool = False
+
+    def enable_append_mode(self) -> None:
+        """Next forward pass appends the new chunk's K/V to the cache (cache grows, or oldest gets evicted)."""
+        self.overwrite_newest = False
+
+    def enable_overwrite_mode(self) -> None:
+        """Next forward pass replaces the newest ``chunk_size`` tokens in place (cache size unchanged)."""
+        self.overwrite_newest = True
+
+    def reset(self) -> None:
+        """Clear all cached K/V tensors and reset write-control state."""
+        for bc in self.block_caches:
+            bc.reset()
+        self.overwrite_newest = False
+
+    def update(
+        self,
+        block_idx: int,
+        new_key: torch.Tensor,
+        new_value: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Merge the current chunk's K/V into block ``block_idx``'s cache and return the
+        K/V that the self-attention should attend over.
+
+        Two paths:
+        - **Overwrite-newest** (``overwrite_newest=True`` and the cache already holds at
+          least ``new_key.shape[1]`` tokens): write the new K/V *in place* into the trailing
+          positions of the existing tensor. No allocation, no concat.
+        - **Append** (default): concatenate the existing prefix with the new K/V, then trim
+          the oldest tokens from the front if the result exceeds ``window_size``.
+        """
+        block_cache = self.block_caches[block_idx]
+        prefix_k = block_cache.cached_key
+        prefix_v = block_cache.cached_value
+        n = new_key.shape[1]
+
+        if self.window_size > 0 and n > self.window_size:
+            raise RuntimeError(f"new chunk has {n} tokens, which exceeds window_size={self.window_size}.")
+
+        if self.overwrite_newest:
+            if prefix_k is None or prefix_k.shape[1] < n:
+                raise RuntimeError(
+                    "overwrite_newest requires the cache to already hold at least one chunk's worth of tokens "
+                    f"(>= {n}); cached length is {0 if prefix_k is None else prefix_k.shape[1]}. "
+                    "Use enable_append_mode() for the first write of a new chunk."
+                )
+            # In-place update of the cached tensors; block_cache already references them.
+            prefix_k[:, -n:] = new_key
+            prefix_v[:, -n:] = new_value
+            return prefix_k, prefix_v
+
+        if prefix_k is None:
+            block_cache.cached_key = new_key
+            block_cache.cached_value = new_value
+            return new_key, new_value
+
+        keep_prefix = self.window_size - n if self.window_size > 0 else prefix_k.shape[1]
+        if keep_prefix > 0:
+            new_key = torch.cat([prefix_k[:, -keep_prefix:], new_key], dim=1)
+            new_value = torch.cat([prefix_v[:, -keep_prefix:], new_value], dim=1)
+        block_cache.cached_key = new_key
+        block_cache.cached_value = new_value
+        return new_key, new_value
+
+
 class WanAttnProcessor:
     _attention_backend = None
     _parallel_config = None
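Two guard rails in `update()` above are worth demonstrating, along with the aliasing noted in the overwrite path's comment; a minimal sketch with hypothetical tiny shapes:

```python
import torch

from diffusers import WanKVCache


def kv(num_tokens):
    # Hypothetical K/V pair: (batch=1, num_tokens, heads=2, head_dim=4).
    return torch.randn(1, num_tokens, 2, 4), torch.randn(1, num_tokens, 2, 4)


cache = WanKVCache(num_blocks=1, window_size=4)

# A single chunk may never exceed the window.
try:
    cache.update(0, *kv(6))
except RuntimeError as err:
    print(err)  # new chunk has 6 tokens, which exceeds window_size=4.

# Overwrite mode needs at least one appended chunk to overwrite.
cache.reset()
cache.enable_overwrite_mode()
try:
    cache.update(0, *kv(2))
except RuntimeError as err:
    print(err)  # points the caller at enable_append_mode()

# After an in-place overwrite, the returned K/V alias the cache's own tensors,
# so callers should not mutate them.
cache.reset()
cache.enable_append_mode()
cache.update(0, *kv(4))
cache.enable_overwrite_mode()
k_out, _ = cache.update(0, *kv(4))
assert k_out.data_ptr() == cache.block_caches[0].cached_key.data_ptr()
```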
@@ -77,11 +189,13 @@ def __init__(self):
 
     def __call__(
         self,
-        attn: "WanAttention",
+        attn: WanAttention,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
         rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        kv_cache: WanKVCache | None = None,
+        block_idx: int | None = None,
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
@@ -117,6 +231,11 @@ def apply_rotary_emb(
             query = apply_rotary_emb(query, *rotary_emb)
             key = apply_rotary_emb(key, *rotary_emb)
 
+        # Self-attention rolling KV cache: merge the current chunk's K/V into the per-block
+        # cache and use the (possibly trimmed) result for attention.
+        if kv_cache is not None and encoder_hidden_states is None:
+            key, value = kv_cache.update(block_idx, key, value)
+
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
@@ -392,7 +511,7 @@ def __init__(
         self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
         self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, frame_offset: int = 0) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
@@ -402,11 +521,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
         freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
 
-        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_f = freqs_cos[0][frame_offset : frame_offset + ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
 
-        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_f = freqs_sin[0][frame_offset : frame_offset + ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
 
@@ -465,6 +584,8 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
+        kv_cache: WanKVCache | None = None,
+        block_idx: int | None = None,
     ) -> torch.Tensor:
         if temb.ndim == 4:
             # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
@@ -486,7 +607,14 @@ def forward(
 
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
+        attn_output = self.attn1(
+            norm_hidden_states,
+            None,
+            None,
+            rotary_emb,
+            kv_cache=kv_cache,
+            block_idx=block_idx,
+        )
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
 
         # 2. Cross-attention
@@ -634,14 +762,16 @@ def forward(
         encoder_hidden_states_image: torch.Tensor | None = None,
         return_dict: bool = True,
         attention_kwargs: dict[str, Any] | None = None,
+        frame_offset: int = 0,
     ) -> torch.Tensor | dict[str, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.config.patch_size
         post_patch_num_frames = num_frames // p_t
         post_patch_height = height // p_h
         post_patch_width = width // p_w
 
-        rotary_emb = self.rope(hidden_states)
+        rotary_emb = self.rope(hidden_states, frame_offset=frame_offset)
+        kv_cache: WanKVCache | None = (attention_kwargs or {}).pop("kv_cache", None)
 
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
@@ -668,13 +798,26 @@ def forward(
 
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for block in self.blocks:
+            for block_idx, block in enumerate(self.blocks):
                 hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    timestep_proj,
+                    rotary_emb,
+                    kv_cache,
+                    block_idx,
                 )
         else:
-            for block in self.blocks:
-                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+            for block_idx, block in enumerate(self.blocks):
+                hidden_states = block(
+                    hidden_states,
+                    encoder_hidden_states,
+                    timestep_proj,
+                    rotary_emb,
+                    kv_cache=kv_cache,
+                    block_idx=block_idx,
+                )
 
         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:

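Finally, a sketch of the read-from-prior-context contract listed in the test coverage: with `window_size=-1`, the append path must hand the attention exactly the K/V that a full-prefix recomputation would concatenate. Shapes are again hypothetical:

```python
import torch

from diffusers import WanKVCache

torch.manual_seed(0)
cache = WanKVCache(num_blocks=1, window_size=-1)  # keep the full prefix
chunks = [(torch.randn(1, 3, 2, 8), torch.randn(1, 3, 2, 8)) for _ in range(4)]

cache.enable_append_mode()
for k_new, v_new in chunks:
    k_attn, v_attn = cache.update(0, k_new, v_new)

# After the last chunk, the cache-provided K/V equal the concatenation over
# all chunks, so chunk 3's queries attend over chunks 0..3.
assert torch.equal(k_attn, torch.cat([k for k, _ in chunks], dim=1))
assert torch.equal(v_attn, torch.cat([v for _, v in chunks], dim=1))
```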