Commit fed8d5e

feat: Auto-regressive video generation (CORE-25) (#13082)
1 parent 9aef025 commit fed8d5e

6 files changed

Lines changed: 488 additions & 0 deletions


comfy/k_diffusion/sampling.py

Lines changed: 99 additions & 0 deletions
@@ -1810,3 +1810,102 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
    """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
    return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)


@torch.no_grad()
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
                    num_frame_per_block=1):
    """
    Autoregressive video sampler: block-by-block denoising with KV cache
    and flow-match re-noising for Causal Forcing / Self-Forcing models.

    Requires a Causal-WAN compatible model (diffusion_model must expose
    init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].

    All AR-loop parameters are passed via the SamplerARVideo node, not read
    from the checkpoint or transformer_options.
    """
    extra_args = {} if extra_args is None else extra_args
    model_options = extra_args.get("model_options", {})
    transformer_options = model_options.get("transformer_options", {})

    if x.ndim != 5:
        raise ValueError(
            f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
            "This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
        )

    inner_model = model.inner_model.inner_model
    causal_model = inner_model.diffusion_model

    if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
        raise TypeError(
            "ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
            "exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
            "does not support this interface; choose a different sampler."
        )

    seed = extra_args.get("seed", 0)

    bs, c, lat_t, lat_h, lat_w = x.shape
    frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2)  # ceiling division
    num_blocks = -(-lat_t // num_frame_per_block)  # ceiling division
    device = x.device
    model_dtype = inner_model.get_dtype()

    kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
    crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)

    output = torch.zeros_like(x)
    s_in = x.new_ones([x.shape[0]])
    current_start_frame = 0
    num_sigma_steps = len(sigmas) - 1
    total_real_steps = num_blocks * num_sigma_steps
    step_count = 0

    try:
        for block_idx in trange(num_blocks, disable=disable):
            bf = min(num_frame_per_block, lat_t - current_start_frame)
            fs, fe = current_start_frame, current_start_frame + bf
            noisy_input = x[:, :, fs:fe]

            ar_state = {
                "start_frame": current_start_frame,
                "kv_caches": kv_caches,
                "crossattn_caches": crossattn_caches,
            }
            transformer_options["ar_state"] = ar_state

            for i in range(num_sigma_steps):
                denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)

                if callback is not None:
                    scaled_i = step_count * num_sigma_steps // total_real_steps
                    callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
                              "sigma_hat": sigmas[i], "denoised": denoised})

                if sigmas[i + 1] == 0:
                    noisy_input = denoised
                else:
                    sigma_next = sigmas[i + 1]
                    torch.manual_seed(seed + block_idx * 1000 + i)
                    fresh_noise = torch.randn_like(denoised)
                    noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise

                    # Rewind the KV cache so the next denoising step of this block
                    # overwrites the entries written by the forward pass above.
                    for cache in kv_caches:
                        cache["end"] -= bf * frame_seq_len

                step_count += 1

            output[:, :, fs:fe] = noisy_input

            # Rewind the entries written by the final denoising step, then run a
            # zero-sigma pass so the clean block is committed to the KV cache.
            for cache in kv_caches:
                cache["end"] -= bf * frame_seq_len
            zero_sigma = sigmas.new_zeros([1])
            _ = model(noisy_input, zero_sigma * s_in, **extra_args)

            current_start_frame += bf
    finally:
        transformer_options.pop("ar_state", None)

    return output
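The re-noising step inside the sigma loop is the flow-matching interpolation: given the model's clean prediction and the next sigma, the block is pushed back to the intermediate noise level as x = (1 - sigma) * x0 + sigma * noise, so sigma = 1 is pure noise and sigma = 0 recovers the clean latent. Below is a minimal standalone sketch of just that step; the function name and toy shapes are illustrative and not part of this commit.

import torch

def flow_match_renoise(denoised, sigma_next, generator=None):
    # Interpolate between the clean prediction and fresh Gaussian noise.
    # sigma_next = 1.0 gives pure noise; sigma_next = 0.0 returns the clean latent.
    noise = torch.randn(denoised.shape, generator=generator,
                        dtype=denoised.dtype, device=denoised.device)
    return (1.0 - sigma_next) * denoised + sigma_next * noise

# Toy usage on a [B, C, T, H, W] block of latents.
x0 = torch.zeros(1, 16, 3, 8, 8)
x_mid = flow_match_renoise(x0, sigma_next=0.5)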

comfy/ldm/wan/ar_model.py

Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
"""
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
autoregressive (frame-by-frame) video generation via Causal Forcing.

Weight-compatible with the standard WanModel -- same layer names, same shapes.
The difference is purely in the forward pass: this model processes one temporal
block at a time and maintains a KV cache across blocks.

Reference: https://github.com/thu-ml/Causal-Forcing
"""

import torch
import torch.nn as nn

from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.wan.model import (
    sinusoidal_embedding_1d,
    repeat_e,
    WanModel,
    WanAttentionBlock,
)
import comfy.ldm.common_dit
import comfy.model_management


class CausalWanSelfAttention(nn.Module):
    """Self-attention with KV cache support for autoregressive inference."""

    def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
                 eps=1e-6, operation_settings={}):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qk_norm = qk_norm
        self.eps = eps

        ops = operation_settings.get("operations")
        device = operation_settings.get("device")
        dtype = operation_settings.get("dtype")

        self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
        self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
        self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
        self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
        self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
        self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()

    def forward(self, x, freqs, kv_cache=None, transformer_options={}):
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
        k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
        v = self.v(x).view(b, s, n, d)

        if kv_cache is None:
            x = optimized_attention(
                q.view(b, s, n * d),
                k.view(b, s, n * d),
                v.view(b, s, n * d),
                heads=self.num_heads,
                transformer_options=transformer_options,
            )
        else:
            end = kv_cache["end"]
            new_end = end + s

            # RoPE'd K and plain V go into the cache
            kv_cache["k"][:, end:new_end] = k
            kv_cache["v"][:, end:new_end] = v
            kv_cache["end"] = new_end

            x = optimized_attention(
                q.view(b, s, n * d),
                kv_cache["k"][:, :new_end].view(b, new_end, n * d),
                kv_cache["v"][:, :new_end].view(b, new_end, n * d),
                heads=self.num_heads,
                transformer_options=transformer_options,
            )

        x = self.o(x)
        return x


class CausalWanAttentionBlock(WanAttentionBlock):
    """Transformer block with KV-cached self-attention and cross-attention caching."""

    def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
                 window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
                 eps=1e-6, operation_settings={}):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
                         window_size, qk_norm, cross_attn_norm, eps,
                         operation_settings=operation_settings)
        self.self_attn = CausalWanSelfAttention(
            dim, num_heads, window_size, qk_norm, eps,
            operation_settings=operation_settings)

    def forward(self, x, e, freqs, context, context_img_len=257,
                kv_cache=None, crossattn_cache=None, transformer_options={}):
        if e.ndim < 4:
            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
        else:
            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)

        # Self-attention with optional KV cache
        x = x.contiguous()
        y = self.self_attn(
            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
            freqs, kv_cache=kv_cache, transformer_options=transformer_options)
        x = torch.addcmul(x, y, repeat_e(e[2], x))
        del y

        # Cross-attention with optional caching
        if crossattn_cache is not None and crossattn_cache.get("is_init"):
            q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
            x_ca = optimized_attention(
                q, crossattn_cache["k"], crossattn_cache["v"],
                heads=self.num_heads, transformer_options=transformer_options)
            x = x + self.cross_attn.o(x_ca)
        else:
            x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
            if crossattn_cache is not None:
                crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
                crossattn_cache["v"] = self.cross_attn.v(context)
                crossattn_cache["is_init"] = True

        # FFN
        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
        x = torch.addcmul(x, y, repeat_e(e[5], x))
        return x


class CausalWanModel(WanModel):
    """
    Wan 2.1 diffusion backbone with causal KV-cache support.

    Same weight structure as WanModel -- loads identical state dicts.
    Adds forward_block() for frame-by-frame autoregressive inference.
    """

    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 image_model=None,
                 device=None,
                 dtype=None,
                 operations=None):
        super().__init__(
            model_type=model_type, patch_size=patch_size, text_len=text_len,
            in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
            text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
            num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
            cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
            wan_attn_block_class=CausalWanAttentionBlock,
            device=device, dtype=dtype, operations=operations)

    def forward_block(self, x, timestep, context, start_frame,
                      kv_caches, crossattn_caches, clip_fea=None):
        """
        Forward one temporal block for autoregressive inference.

        Args:
            x: [B, C, block_frames, H, W] input latent for the current block
            timestep: [B, block_frames] per-frame timesteps
            context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
            start_frame: temporal frame index for RoPE offset
            kv_caches: list of per-layer KV cache dicts
            crossattn_caches: list of per-layer cross-attention cache dicts
            clip_fea: optional CLIP features for I2V

        Returns:
            flow_pred: [B, C_out, block_frames, H, W] flow prediction
        """
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
        bs, c, t, h, w = x.shape

        x = self.patch_embedding(x.float()).to(x.dtype)
        grid_sizes = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)

        # Per-frame time embedding
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
        e = e.reshape(timestep.shape[0], -1, e.shape[-1])
        e0 = self.time_projection(e).unflatten(2, (6, self.dim))

        # Text embedding (reuses crossattn_cache after the first block)
        context = self.text_embedding(context)

        context_img_len = None
        if clip_fea is not None and self.img_emb is not None:
            context_clip = self.img_emb(clip_fea)
            context = torch.concat([context_clip, context], dim=1)
            context_img_len = clip_fea.shape[-2]

        # RoPE for the current block's temporal position
        freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)

        # Transformer blocks
        for i, block in enumerate(self.blocks):
            x = block(x, e=e0, freqs=freqs, context=context,
                      context_img_len=context_img_len,
                      kv_cache=kv_caches[i],
                      crossattn_cache=crossattn_caches[i])

        # Head
        x = self.head(x, e)

        # Unpatchify
        x = self.unpatchify(x, grid_sizes)
        return x[:, :, :t, :h, :w]

    def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
        """Create fresh KV caches for all layers."""
        caches = []
        for _ in range(self.num_layers):
            caches.append({
                "k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
                "v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
                "end": 0,
            })
        return caches

    def init_crossattn_caches(self, batch_size, device, dtype):
        """Create fresh cross-attention caches for all layers."""
        caches = []
        for _ in range(self.num_layers):
            caches.append({"is_init": False})
        return caches

    def reset_kv_caches(self, kv_caches):
        """Reset KV caches to empty (reuse allocated memory)."""
        for cache in kv_caches:
            cache["end"] = 0

    def reset_crossattn_caches(self, crossattn_caches):
        """Reset cross-attention caches."""
        for cache in crossattn_caches:
            cache["is_init"] = False

    @property
    def head_dim(self):
        return self.dim // self.num_heads

    def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
        ar_state = transformer_options.get("ar_state")
        if ar_state is not None:
            bs = x.shape[0]
            block_frames = x.shape[2]
            t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
            return self.forward_block(
                x=x, timestep=t_per_frame, context=context,
                start_frame=ar_state["start_frame"],
                kv_caches=ar_state["kv_caches"],
                crossattn_caches=ar_state["crossattn_caches"],
                clip_fea=clip_fea,
            )

        return super().forward(x, timestep, context, clip_fea=clip_fea,
                               time_dim_concat=time_dim_concat,
                               transformer_options=transformer_options, **kwargs)
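The per-layer caches returned by init_kv_caches are plain dicts with preallocated "k"/"v" tensors and an "end" write pointer: each forward appends the current block's keys and values after the committed prefix, attends over everything up to "end", and the sampler rewinds the pointer between denoising steps so only the final zero-sigma pass is committed. The snippet below is a small self-contained sketch of that pointer discipline using the same dict layout; the tensor sizes and the append_block helper are illustrative only and not part of this commit.

import torch

# Illustrative cache with the same fields as init_kv_caches() produces.
batch, max_len, heads, head_dim = 1, 16, 2, 4
cache = {
    "k": torch.zeros(batch, max_len, heads, head_dim),
    "v": torch.zeros(batch, max_len, heads, head_dim),
    "end": 0,
}

def append_block(cache, k_new, v_new):
    # Write the current block's keys/values after the committed prefix.
    end = cache["end"]
    new_end = end + k_new.shape[1]
    cache["k"][:, end:new_end] = k_new
    cache["v"][:, end:new_end] = v_new
    cache["end"] = new_end
    return new_end

block_len = 4
k_blk = torch.randn(batch, block_len, heads, head_dim)
v_blk = torch.randn(batch, block_len, heads, head_dim)

# One denoising step: append, attend over [:end], then rewind so the next
# step of the same block overwrites these (still noisy) entries.
new_end = append_block(cache, k_blk, v_blk)
visible_k = cache["k"][:, :new_end]   # what self-attention would see this step
cache["end"] -= block_len             # rewind: the block is not committed yet

# Final zero-sigma pass: append again and keep the pointer advanced,
# committing the clean block for all later blocks to attend over.
append_block(cache, k_blk, v_blk)
assert cache["end"] == block_len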
