Add WanRollingKVCache for autoregressive Wan video generation
WanRollingKVCache is a per-block self-attention KV cache that lets a Wan
transformer generate video chunk by chunk while reusing the K/V tensors
computed for prior chunks instead of re-running the full attention over the
whole prefix on every step.
API:
- ``WanRollingKVCache(num_blocks, window_size=-1)`` — one cache per
transformer instance. ``window_size=-1`` keeps the full prefix; a finite
window evicts the oldest tokens once the cap is reached.
- ``cache.enable_append_mode()`` / ``cache.enable_overwrite_mode()`` — pick
the write semantics for the next forward pass. Append grows the cache
(or rolls when full); overwrite replaces the newest chunk in place — used
for additional denoising steps that redo the most recent chunk.
- ``cache.update(block_idx, key, value)`` — called from ``WanAttnProcessor``
during self-attention to merge the current chunk into the per-block
cache and return the K/V to attend over.
- ``cache.reset()`` — clear all blocks between videos.
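For concreteness, a minimal sketch of the write semantics above (not the
actual diffusers implementation; K/V tensors are assumed to be shaped
``[batch, heads, seq, head_dim]`` and concatenated along the sequence axis)::

    import torch

    class RollingKVCacheSketch:
        def __init__(self, num_blocks, window_size=-1):
            self.window_size = window_size
            self.keys = [None] * num_blocks
            self.values = [None] * num_blocks
            self.append = True  # write mode for the next forward pass

        def enable_append_mode(self):
            self.append = True

        def enable_overwrite_mode(self):
            self.append = False

        def update(self, block_idx, key, value):
            k, v = self.keys[block_idx], self.values[block_idx]
            chunk_len = key.shape[2]
            if k is None:
                k, v = key, value
            elif self.append:
                k = torch.cat([k, key], dim=2)
                v = torch.cat([v, value], dim=2)
                if self.window_size > 0 and k.shape[2] > self.window_size:
                    k = k[:, :, -self.window_size:]  # evict oldest tokens
                    v = v[:, :, -self.window_size:]
            else:
                # Overwrite: drop the newest chunk and write the fresh one.
                k = torch.cat([k[:, :, :-chunk_len], key], dim=2)
                v = torch.cat([v[:, :, :-chunk_len], value], dim=2)
            self.keys[block_idx], self.values[block_idx] = k, v
            return k, v

        def reset(self):
            self.keys = [None] * len(self.keys)
            self.values = [None] * len(self.values)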
Wan plumbing:
- ``WanTransformer3DModel.forward`` accepts ``frame_offset: int = 0`` and
forwards ``rolling_kv_cache`` (extracted from ``attention_kwargs``) plus
``block_idx`` to each transformer block.
- ``WanRotaryPosEmbed.forward`` takes ``frame_offset`` so RoPE can address
positions in the original (uncached) sequence even when the latent input
is just one chunk.
- ``WanAttnProcessor.__call__`` receives ``rolling_kv_cache`` / ``block_idx``;
on self-attention it calls ``cache.update(...)`` and uses the returned
K/V for SDPA. Cross-attention is unaffected.
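Condensed, the self-attention path looks roughly like this (a sketch, not the
real processor, which also handles projections, QK norm, and RoPE before this
point; cross-attention bypasses the cache entirely)::

    import torch.nn.functional as F

    def cached_self_attention(query, key, value, rolling_kv_cache=None,
                              block_idx=0):
        # query/key/value cover only the current chunk, post-norm and
        # post-RoPE, shaped [batch, heads, seq, head_dim].
        if rolling_kv_cache is not None:
            # Merge this chunk into the per-block cache and attend over
            # the returned (possibly windowed) prefix + current chunk.
            key, value = rolling_kv_cache.update(block_idx, key, value)
        return F.scaled_dot_product_attention(query, key, value)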
Caller usage::

    cache = WanRollingKVCache(num_blocks=len(transformer.blocks))

    for chunk_idx, latent_chunk in enumerate(chunks):
        cache.enable_append_mode()
        for step_idx, t in enumerate(denoising_steps):
            if step_idx > 0:
                cache.enable_overwrite_mode()
            transformer(
                hidden_states=latent_chunk,
                timestep=t,
                encoder_hidden_states=prompt_embeds,
                frame_offset=chunk_idx * patch_frames_per_chunk,
                attention_kwargs={"rolling_kv_cache": cache},
            )
Tests cover unbounded append, windowed append (with eviction across one and
multiple chunks), in-place overwrite of the newest chunk, the
read-from-prior-context contract, reset, and frame_offset's effect on RoPE.
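As an illustration of the windowed-append contract, one such test might read
(shapes ``[batch, heads, tokens, head_dim]`` and names are illustrative, not
the actual test file)::

    import torch

    def test_windowed_append_evicts_oldest():
        cache = WanRollingKVCache(num_blocks=1, window_size=4)
        cache.enable_append_mode()
        first = torch.randn(1, 2, 3, 8)   # chunk of 3 tokens
        second = torch.randn(1, 2, 3, 8)  # pushes the total to 6 > 4
        cache.update(0, first, first)
        k, _ = cache.update(0, second, second)
        assert k.shape[2] == 4                   # capped at window_size
        assert torch.equal(k[:, :, 1:], second)  # newest chunk intact
        assert torch.equal(k[:, :, :1], first[:, :, 2:])  # oldest evicted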
For autoregressive video generation that produces one chunk at a time, [`WanTransformer3DModel.forward`] accepts a `WanRollingKVCache` instance via `attention_kwargs={"rolling_kv_cache": cache}`. The cache holds post-norm, post-RoPE self-attention K/V tensors from prior chunks so subsequent chunks attend over the full prefix without recomputing it. The chunk's RoPE positions are picked via the `frame_offset` argument on `forward`.
The cache exposes two write modes that the caller toggles between denoising steps:

- `enable_append_mode()` — the next forward pass appends the chunk's K/V to the cache; once the cache reaches `window_size`, the oldest tokens are evicted from the front. Use this for the first denoising step of every new chunk.
- `enable_overwrite_mode()` — the next forward pass replaces the newest `chunk_size` tokens in place. Use this for subsequent denoising steps within the same chunk so re-running the chunk doesn't grow the cache.

```python
from diffusers import WanRollingKVCache, WanTransformer3DModel
```
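The snippet is truncated at the import in the diff. A plausible continuation, mirroring the caller-usage example from the commit message (`transformer`, `chunks`, `denoising_steps`, `prompt_embeds`, and `patch_frames_per_chunk` are placeholders, not symbols defined by the snippet):

```python
cache = WanRollingKVCache(num_blocks=len(transformer.blocks))

for chunk_idx, latent_chunk in enumerate(chunks):
    # First denoising step of a new chunk appends to the cache.
    cache.enable_append_mode()
    for step_idx, t in enumerate(denoising_steps):
        if step_idx > 0:
            # Later steps redo the same chunk, so overwrite in place.
            cache.enable_overwrite_mode()
        transformer(
            hidden_states=latent_chunk,
            timestep=t,
            encoder_hidden_states=prompt_embeds,
            # RoPE addresses positions in the original (uncached) sequence.
            frame_offset=chunk_idx * patch_frames_per_chunk,
            attention_kwargs={"rolling_kv_cache": cache},
        )

cache.reset()  # clear all per-block caches before the next video
```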