"""
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
autoregressive (frame-by-frame) video generation via Causal Forcing.

Weight-compatible with the standard WanModel -- same layer names, same shapes.
The difference is purely in the forward pass: this model processes one temporal
block at a time and maintains a KV cache across blocks.

Reference: https://github.com/thu-ml/Causal-Forcing
"""

import torch
import torch.nn as nn

from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.wan.model import (
    sinusoidal_embedding_1d,
    repeat_e,
    WanModel,
    WanAttentionBlock,
)
import comfy.ldm.common_dit
import comfy.model_management

class CausalWanSelfAttention(nn.Module):
    """Self-attention with KV cache support for autoregressive inference."""

    def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
                 eps=1e-6, operation_settings={}):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qk_norm = qk_norm
        self.eps = eps

        ops = operation_settings.get("operations")
        device = operation_settings.get("device")
        dtype = operation_settings.get("dtype")

        self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
        self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
        self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
        self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
        self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
        self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()

    def forward(self, x, freqs, kv_cache=None, transformer_options={}):
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # Project, QK-normalize, and apply rotary embedding before caching
        q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
        k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
        v = self.v(x).view(b, s, n, d)

        if kv_cache is None:
            # No cache: plain full attention over the current sequence
            x = optimized_attention(
                q.view(b, s, n * d),
                k.view(b, s, n * d),
                v.view(b, s, n * d),
                heads=self.num_heads,
                transformer_options=transformer_options,
            )
        else:
            end = kv_cache["end"]
            new_end = end + s

            # RoPE'd K and plain V go into the cache
            kv_cache["k"][:, end:new_end] = k
            kv_cache["v"][:, end:new_end] = v
            kv_cache["end"] = new_end

            # Queries attend over the full cached prefix (block-causal)
            x = optimized_attention(
                q.view(b, s, n * d),
                kv_cache["k"][:, :new_end].view(b, new_end, n * d),
                kv_cache["v"][:, :new_end].view(b, new_end, n * d),
                heads=self.num_heads,
                transformer_options=transformer_options,
            )

        x = self.o(x)
        return x

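# --- Illustration (not part of the original file) ---------------------------
# A minimal, hypothetical sketch of the cache contract CausalWanSelfAttention
# relies on, using torch's scaled_dot_product_attention in place of comfy's
# optimized_attention: each block appends its RoPE'd keys and plain values at
# `end`, and queries attend over the entire cached prefix -- causal across
# blocks, full attention within a block.
def _kv_cache_demo():
    import torch.nn.functional as F
    b, n, d, max_len = 1, 4, 8, 16          # batch, heads, head_dim, capacity
    cache = {"k": torch.zeros(b, max_len, n, d),
             "v": torch.zeros(b, max_len, n, d),
             "end": 0}
    for s in (6, 6, 4):                     # three temporal blocks of tokens
        q, k, v = (torch.randn(b, s, n, d) for _ in range(3))
        end, new_end = cache["end"], cache["end"] + s
        cache["k"][:, end:new_end] = k      # append this block's K/V
        cache["v"][:, end:new_end] = v
        cache["end"] = new_end
        out = F.scaled_dot_product_attention(   # expects [b, heads, len, dim]
            q.transpose(1, 2),
            cache["k"][:, :new_end].transpose(1, 2),
            cache["v"][:, :new_end].transpose(1, 2),
        ).transpose(1, 2)
        assert out.shape == (b, s, n, d)
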
class CausalWanAttentionBlock(WanAttentionBlock):
    """Transformer block with KV-cached self-attention and cross-attention caching."""

    def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
                 window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
                 eps=1e-6, operation_settings={}):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
                         window_size, qk_norm, cross_attn_norm, eps,
                         operation_settings=operation_settings)
        # Swap in the KV-cache-aware self-attention; weights line up 1:1
        self.self_attn = CausalWanSelfAttention(
            dim, num_heads, window_size, qk_norm, eps,
            operation_settings=operation_settings)

    def forward(self, x, e, freqs, context, context_img_len=257,
                kv_cache=None, crossattn_cache=None, transformer_options={}):
        # Modulation: e is [B, 6, dim] for a shared timestep, or
        # [B, T, 6, dim] for per-frame timesteps
        if e.ndim < 4:
            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
        else:
            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)

        # Self-attention with optional KV cache
        x = x.contiguous()
        y = self.self_attn(
            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
            freqs, kv_cache=kv_cache, transformer_options=transformer_options)
        x = torch.addcmul(x, y, repeat_e(e[2], x))
        del y

        # Cross-attention: the text context is constant across blocks, so its
        # K/V projections are computed once and then reused from the cache
        if crossattn_cache is not None and crossattn_cache.get("is_init"):
            q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
            x_ca = optimized_attention(
                q, crossattn_cache["k"], crossattn_cache["v"],
                heads=self.num_heads, transformer_options=transformer_options)
            x = x + self.cross_attn.o(x_ca)
        else:
            x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
            if crossattn_cache is not None:
                # Populate the cache on the first block
                crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
                crossattn_cache["v"] = self.cross_attn.v(context)
                crossattn_cache["is_init"] = True

        # FFN
        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
        x = torch.addcmul(x, y, repeat_e(e[5], x))
        return x

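# Side note (illustrative assumption, not in the original file): attending over
# the growing KV cache is equivalent to one full-sequence pass under a
# block-causal mask -- each token sees its own temporal block and all earlier
# ones, never later ones. A hypothetical helper that materializes that mask:
def _block_causal_mask(num_blocks, tokens_per_block):
    """Boolean [L, L] mask where True marks an allowed attention edge."""
    block_idx = torch.arange(num_blocks * tokens_per_block) // tokens_per_block
    return block_idx[:, None] >= block_idx[None, :]
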

class CausalWanModel(WanModel):
    """
    Wan 2.1 diffusion backbone with causal KV-cache support.

    Same weight structure as WanModel -- loads identical state dicts.
    Adds forward_block() for frame-by-frame autoregressive inference.
    """

    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 image_model=None,
                 device=None,
                 dtype=None,
                 operations=None):
        super().__init__(
            model_type=model_type, patch_size=patch_size, text_len=text_len,
            in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
            text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
            num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
            cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
            wan_attn_block_class=CausalWanAttentionBlock,
            device=device, dtype=dtype, operations=operations)

    def forward_block(self, x, timestep, context, start_frame,
                      kv_caches, crossattn_caches, clip_fea=None):
        """
        Forward one temporal block for autoregressive inference.

        Args:
            x: [B, C, block_frames, H, W] input latent for the current block
            timestep: [B, block_frames] per-frame timesteps
            context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
            start_frame: temporal frame index for the RoPE offset
            kv_caches: list of per-layer KV cache dicts
            crossattn_caches: list of per-layer cross-attention cache dicts
            clip_fea: optional CLIP features for I2V

        Returns:
            flow_pred: [B, C_out, block_frames, H, W] flow prediction
        """
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
        bs, c, t, h, w = x.shape

        x = self.patch_embedding(x.float()).to(x.dtype)
        grid_sizes = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)

        # Per-frame time embedding
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
        e = e.reshape(timestep.shape[0], -1, e.shape[-1])
        e0 = self.time_projection(e).unflatten(2, (6, self.dim))

        # Text embedding (its cross-attn K/V are cached after the first block)
        context = self.text_embedding(context)

        context_img_len = None
        if clip_fea is not None and self.img_emb is not None:
            context_clip = self.img_emb(clip_fea)
            context = torch.concat([context_clip, context], dim=1)
            context_img_len = clip_fea.shape[-2]

        # RoPE for the current block's temporal position
        freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)

        # Transformer blocks
        for i, block in enumerate(self.blocks):
            x = block(x, e=e0, freqs=freqs, context=context,
                      context_img_len=context_img_len,
                      kv_cache=kv_caches[i],
                      crossattn_cache=crossattn_caches[i])

        # Head
        x = self.head(x, e)

        # Unpatchify
        x = self.unpatchify(x, grid_sizes)
        return x[:, :, :t, :h, :w]

    def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
        """Create fresh KV caches for all layers."""
        caches = []
        for _ in range(self.num_layers):
            caches.append({
                "k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
                "v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
                "end": 0,
            })
        return caches

    def init_crossattn_caches(self, batch_size, device, dtype):
        """Create fresh cross-attention caches for all layers. K/V tensors are
        allocated lazily on the first block, so only the init flag is stored here."""
        caches = []
        for _ in range(self.num_layers):
            caches.append({"is_init": False})
        return caches

    def reset_kv_caches(self, kv_caches):
        """Reset KV caches to empty (reuses the allocated buffers)."""
        for cache in kv_caches:
            cache["end"] = 0

    def reset_crossattn_caches(self, crossattn_caches):
        """Reset cross-attention caches."""
        for cache in crossattn_caches:
            cache["is_init"] = False

    @property
    def head_dim(self):
        return self.dim // self.num_heads

    def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
        # Autoregressive path: the sampler supplies caches and the temporal
        # offset through transformer_options["ar_state"]
        ar_state = transformer_options.get("ar_state")
        if ar_state is not None:
            bs = x.shape[0]
            block_frames = x.shape[2]
            t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
            return self.forward_block(
                x=x, timestep=t_per_frame, context=context,
                start_frame=ar_state["start_frame"],
                kv_caches=ar_state["kv_caches"],
                crossattn_caches=ar_state["crossattn_caches"],
                clip_fea=clip_fea,
            )

        # Fall back to the standard full-sequence WanModel forward
        return super().forward(x, timestep, context, clip_fea=clip_fea,
                               time_dim_concat=time_dim_concat,
                               transformer_options=transformer_options, **kwargs)

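# --- Hedged usage sketch (not part of the original file) --------------------
# How a sampler loop might drive the model autoregressively. The helper name,
# block size, and single zero timestep are illustrative assumptions; only the
# ar_state/cache plumbing mirrors forward() above.
def autoregressive_rollout(model, latents, context, block_frames=3):
    """Run `latents` [B, C, T, H, W] through `model` one temporal block at a
    time, reusing KV and cross-attention caches across blocks."""
    bs, _, total_frames, h, w = latents.shape
    # Tokens per latent frame after patch embedding, assuming the default
    # patch_size of (1, 2, 2)
    tokens_per_frame = (h // 2) * (w // 2)
    kv_caches = model.init_kv_caches(
        bs, total_frames * tokens_per_frame,
        device=latents.device, dtype=latents.dtype)
    crossattn_caches = model.init_crossattn_caches(
        bs, device=latents.device, dtype=latents.dtype)

    outputs = []
    for start in range(0, total_frames, block_frames):
        block = latents[:, :, start:start + block_frames]
        # A real sampler would pass its current noise level; each forward call
        # appends this block's K/V to the caches, so every block runs once here
        timestep = torch.zeros(bs, device=latents.device, dtype=latents.dtype)
        flow = model(block, timestep, context,
                     transformer_options={"ar_state": {
                         "start_frame": start,
                         "kv_caches": kv_caches,
                         "crossattn_caches": crossattn_caches,
                     }})
        outputs.append(flow)
    return torch.cat(outputs, dim=2)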