 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import math
+from dataclasses import dataclass
 from typing import Any
 
 import torch
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+def _get_qkv_projections(attn: WanAttention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
     # encoder_hidden_states is only passed for cross-attention
     if encoder_hidden_states is None:
         encoder_hidden_states = hidden_states
@@ -56,7 +59,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
     return query, key, value
 
 
-def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+def _get_added_kv_projections(attn: WanAttention, encoder_hidden_states_img: torch.Tensor):
     if attn.fused_projections:
         key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
     else:
@@ -65,6 +68,115 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
     return key_img, value_img
 
 
+@dataclass
+class WanRollingKVBlockCache:
+    """Per-block rolling KV cache state for autoregressive WAN inference.
+
+    ``cached_key`` and ``cached_value`` hold the post-norm, post-RoPE K/V from prior chunks
+    with shape ``(batch_size, cached_seq_len, num_heads, head_dim)``.
+    """
+
+    cached_key: torch.Tensor | None = None
+    cached_value: torch.Tensor | None = None
+
+    def reset(self) -> None:
+        self.__init__()
+
+
+class WanRollingKVCache:
+    """Rolling KV cache for autoregressive WAN video generation.
+
+    Holds a per-block ``WanRollingKVBlockCache`` for every transformer block, plus shared
+    write-control state. Pass an instance via ``attention_kwargs`` on each transformer forward
+    call. ``WanAttnProcessor`` calls :py:meth:`update` to merge the current chunk's K/V into
+    the cache and get back the (possibly trimmed) attention K/V.
+
+    TODO: cross-attention K/V projections are currently recomputed on every forward pass even
+    though the text embeddings are constant across chunks. A future change can add cross-attn
+    caching alongside the existing self-attn cache.
+
+    Args:
+        num_blocks (`int`): Number of transformer blocks (``len(transformer.blocks)``).
+        window_size (`int`, defaults to ``-1``): Maximum cached tokens per block. ``-1`` keeps
+            the full prefix.
+
+    Example:
+
+        ```python
+        >>> cache = WanRollingKVCache(num_blocks=len(transformer.blocks))
+        >>> transformer(..., attention_kwargs={"rolling_kv_cache": cache})
+        ```
+    """
+
+    def __init__(self, num_blocks: int, window_size: int = -1):
+        self.block_caches: list[WanRollingKVBlockCache] = [WanRollingKVBlockCache() for _ in range(num_blocks)]
+        self.window_size: int = window_size
+        self.overwrite_newest: bool = False
+
+    def enable_append_mode(self) -> None:
+        """Next forward pass appends the new chunk's K/V to the cache (cache grows, or oldest gets evicted)."""
+        self.overwrite_newest = False
+
+    def enable_overwrite_mode(self) -> None:
+        """Next forward pass replaces the newest ``chunk_size`` tokens in place (cache size unchanged)."""
+        self.overwrite_newest = True
+
+    def reset(self) -> None:
+        """Clear all cached K/V tensors and reset write-control state."""
+        for bc in self.block_caches:
+            bc.reset()
+        self.overwrite_newest = False
+
+    def update(
+        self,
+        block_idx: int,
+        new_key: torch.Tensor,
+        new_value: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Merge the current chunk's K/V into block ``block_idx``'s cache and return the
+        K/V that the self-attention should attend over.
+
+        Two paths:
+        - **Overwrite-newest** (``overwrite_newest=True`` and the cache already holds at
+          least ``new_key.shape[1]`` tokens): write the new K/V *in place* into the trailing
+          positions of the existing tensor. No allocation, no concat.
+        - **Append** (default): concatenate the existing prefix with the new K/V, then trim
+          the oldest tokens from the front if the result exceeds ``window_size``.
+        """
+        block_cache = self.block_caches[block_idx]
+        prefix_k = block_cache.cached_key
+        prefix_v = block_cache.cached_value
+        n = new_key.shape[1]
+
+        if self.overwrite_newest:
+            if prefix_k is None or prefix_k.shape[1] < n:
+                raise RuntimeError(
+                    "overwrite_newest requires the cache to already hold at least one chunk's worth of tokens "
+                    f"(>= {n}); cached length is {0 if prefix_k is None else prefix_k.shape[1]}. "
+                    "Use enable_append_mode() for the first write of a new chunk."
+                )
+            # Slide the new K/V into the trailing positions of the existing tensors.
+            prefix_k[:, -n:] = new_key
+            prefix_v[:, -n:] = new_value
+            new_key, new_value = prefix_k, prefix_v
+        elif prefix_k is not None:
+            # Drop the part of the prefix that would be evicted anyway, so the cat never
+            # allocates a tensor larger than ``window_size``.
+            keep_prefix = max(0, self.window_size - n) if self.window_size > 0 else prefix_k.shape[1]
+            if keep_prefix > 0:
+                new_key = torch.cat([prefix_k[:, -keep_prefix:], new_key], dim=1)
+                new_value = torch.cat([prefix_v[:, -keep_prefix:], new_value], dim=1)
+
+        # Cap when the new chunk alone exceeds window_size (no-op in the common case).
+        keep = self.window_size if self.window_size > 0 else new_key.shape[1]
+        new_key = new_key[:, -keep:]
+        new_value = new_value[:, -keep:]
+
+        block_cache.cached_key = new_key
+        block_cache.cached_value = new_value
+        return new_key, new_value
+
+
 class WanAttnProcessor:
     _attention_backend = None
     _parallel_config = None
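The append/overwrite/trim behaviour of `WanRollingKVCache.update` can be exercised in isolation with dummy tensors. The sketch below is illustrative only and is not part of this diff; batch size, head count, head dim, and `window_size` are made-up numbers, and the K/V tensors follow the `(batch, seq, heads, head_dim)` layout documented on `WanRollingKVBlockCache`.

```python
import torch

cache = WanRollingKVCache(num_blocks=1, window_size=6)

# First write of a chunk (4 tokens): append mode, cache grows from 0 to 4 tokens.
k1 = torch.randn(1, 4, 2, 8)  # (batch, seq, heads, head_dim)
v1 = torch.randn(1, 4, 2, 8)
cache.enable_append_mode()
k, v = cache.update(0, k1, v1)
assert k.shape[1] == 4

# Later denoising steps of the same chunk: overwrite mode replaces the newest
# 4 tokens in place, so the cached length stays at 4.
cache.enable_overwrite_mode()
k, v = cache.update(0, torch.randn(1, 4, 2, 8), torch.randn(1, 4, 2, 8))
assert k.shape[1] == 4

# First write of the next chunk: append again. 4 + 4 = 8 exceeds window_size=6,
# so the 2 oldest tokens are evicted from the front.
cache.enable_append_mode()
k, v = cache.update(0, torch.randn(1, 4, 2, 8), torch.randn(1, 4, 2, 8))
assert k.shape[1] == 6
```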
@@ -77,11 +189,13 @@ def __init__(self):
 
     def __call__(
         self,
-        attn: "WanAttention",
+        attn: WanAttention,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
         rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        rolling_kv_cache: WanRollingKVCache | None = None,
+        block_idx: int | None = None,
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
@@ -117,6 +231,11 @@ def apply_rotary_emb(
             query = apply_rotary_emb(query, *rotary_emb)
             key = apply_rotary_emb(key, *rotary_emb)
 
+        # Self-attention rolling KV cache: merge the current chunk's K/V into the per-block
+        # cache and use the (possibly trimmed) result for attention.
+        if rolling_kv_cache is not None and encoder_hidden_states is None:
+            key, value = rolling_kv_cache.update(block_idx, key, value)
+
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
@@ -392,7 +511,7 @@ def __init__(
         self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
         self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, frame_offset: int = 0) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
@@ -402,11 +521,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
         freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
 
-        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_f = freqs_cos[0][frame_offset : frame_offset + ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
 
-        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_f = freqs_sin[0][frame_offset : frame_offset + ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
 
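`frame_offset` is counted in post-patch (temporally patchified) frames and shifts only the temporal frequency table, so a chunk generated later in the video reuses the same absolute RoPE positions it would have had in one long forward pass, while the spatial tables still start at 0. A small illustrative calculation (the numbers are made up, not taken from this diff):

```python
# Suppose patch_size[0] == 1 and 9 latent frames of the video have already been
# generated. The next chunk spans 3 latent frames, so ppf == 3 and the temporal
# rows taken from freqs_cos[0] / freqs_sin[0] are 9, 10, 11 rather than 0, 1, 2 --
# the same absolute positions a single long forward pass would have used.
p_t = 1
num_prev_latent_frames = 9
frame_offset = num_prev_latent_frames // p_t  # frame_offset is in post-patch frame units
ppf = 3
temporal_rows = list(range(frame_offset, frame_offset + ppf))
assert temporal_rows == [9, 10, 11]
```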
@@ -465,6 +584,8 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
+        rolling_kv_cache: WanRollingKVCache | None = None,
+        block_idx: int | None = None,
     ) -> torch.Tensor:
         if temb.ndim == 4:
             # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
@@ -486,7 +607,14 @@ def forward(
 
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
+        attn_output = self.attn1(
+            norm_hidden_states,
+            None,
+            None,
+            rotary_emb,
+            rolling_kv_cache=rolling_kv_cache,
+            block_idx=block_idx,
+        )
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
 
         # 2. Cross-attention
@@ -634,14 +762,16 @@ def forward(
         encoder_hidden_states_image: torch.Tensor | None = None,
         return_dict: bool = True,
         attention_kwargs: dict[str, Any] | None = None,
+        frame_offset: int = 0,
     ) -> torch.Tensor | dict[str, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.config.patch_size
         post_patch_num_frames = num_frames // p_t
         post_patch_height = height // p_h
         post_patch_width = width // p_w
 
-        rotary_emb = self.rope(hidden_states)
+        rotary_emb = self.rope(hidden_states, frame_offset=frame_offset)
+        rolling_kv_cache: WanRollingKVCache | None = (attention_kwargs or {}).pop("rolling_kv_cache", None)
 
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
@@ -668,13 +798,26 @@ def forward(
 
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for block in self.blocks:
+            for block_idx, block in enumerate(self.blocks):
                 hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    timestep_proj,
+                    rotary_emb,
+                    rolling_kv_cache,
+                    block_idx,
                 )
         else:
-            for block in self.blocks:
-                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+            for block_idx, block in enumerate(self.blocks):
+                hidden_states = block(
+                    hidden_states,
+                    encoder_hidden_states,
+                    timestep_proj,
+                    rotary_emb,
+                    rolling_kv_cache=rolling_kv_cache,
+                    block_idx=block_idx,
+                )
 
         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:
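Putting the pieces together, a chunked autoregressive sampling loop might drive the two new arguments as sketched below. This is not part of the diff: `transformer` is an instance of this model, while `scheduler`, `prompt_embeds`, `latent_chunks`, and `timesteps` are placeholder names for whatever pipeline owns the loop, and the first-step-append / later-steps-overwrite schedule follows the requirement stated in `WanRollingKVCache.update` rather than any pipeline code shown here.

```python
cache = WanRollingKVCache(num_blocks=len(transformer.blocks), window_size=-1)
p_t = transformer.config.patch_size[0]

frame_offset = 0
for chunk_latents in latent_chunks:  # each chunk: (batch, channels, frames, height, width)
    for step_idx, t in enumerate(timesteps):
        # The first forward of a chunk appends its K/V to every block's cache;
        # subsequent denoising steps of the same chunk overwrite those newest tokens.
        if step_idx == 0:
            cache.enable_append_mode()
        else:
            cache.enable_overwrite_mode()
        noise_pred = transformer(
            hidden_states=chunk_latents,
            timestep=t.expand(chunk_latents.shape[0]),
            encoder_hidden_states=prompt_embeds,
            attention_kwargs={"rolling_kv_cache": cache},
            frame_offset=frame_offset,
            return_dict=False,
        )[0]
        chunk_latents = scheduler.step(noise_pred, t, chunk_latents).prev_sample
    frame_offset += chunk_latents.shape[2] // p_t  # advance by this chunk's post-patch frames

cache.reset()  # start from an empty cache before generating an unrelated video
```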