Skip to content

Commit 874690c

Browse files
committed
LTX2 context windows - Refactor guide logic from context_windows into LTXAV model hooks
1 parent 3502376 commit 874690c

2 files changed

Lines changed: 76 additions & 63 deletions

File tree

comfy/context_windows.py

Lines changed: 25 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
from typing import TYPE_CHECKING, Callable
2+
from typing import TYPE_CHECKING, Any, Callable
33
import torch
44
import numpy as np
55
import collections
@@ -181,6 +181,12 @@ def _compute_guide_overlap(guide_entries, window_index_list):
181181
return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
182182

183183

184+
@dataclass
class WindowingContext:
    """Result of a model's ``prepare_for_windowing`` hook.

    Bundles the tensor that context windows should be computed over with any
    auxiliary frames the model stripped off before windowing, so they can be
    re-appended after the windows are fused.
    """
    # Primary tensor with auxiliary (e.g. guide) frames removed.
    tensor: torch.Tensor
    # Trailing frames stripped from the primary tensor, re-appended verbatim
    # after window fusion; None when the model stripped nothing.
    suffix: torch.Tensor | None
    # Opaque model-specific payload handed back to prepare_window_input()
    # for each window — presumably guide-overlap bookkeeping; see the model
    # subclass hooks. TODO(review): confirm shape/contents per model.
    aux_data: Any
189+
184190
@dataclass
185191
class ContextSchedule:
186192
name: str
@@ -242,18 +248,6 @@ def _patch_latent_shapes(self, sub_conds, new_shapes):
242248
if 'latent_shapes' in model_conds:
243249
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
244250

245-
def _get_guide_entries(self, conds):
246-
"""Extract guide_attention_entries list from conditioning. Returns None if absent."""
247-
for cond_list in conds:
248-
if cond_list is None:
249-
continue
250-
for cond_dict in cond_list:
251-
model_conds = cond_dict.get('model_conds', {})
252-
gae = model_conds.get('guide_attention_entries')
253-
if gae is not None and hasattr(gae, 'cond') and gae.cond:
254-
return gae.cond
255-
return None
256-
257251
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
258252
latent_shapes = self._get_latent_shapes(conds)
259253
primary = self._decompose(x_in, latent_shapes)[0]
@@ -379,24 +373,19 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
379373
is_multimodal = len(modalities) > 1
380374
primary = modalities[0]
381375

382-
# Separate guide frames from primary modality (guides are appended at the end)
383-
guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0
384-
if guide_count > 0:
385-
video_len = primary.size(self.dim) - guide_count
386-
video_primary = primary.narrow(self.dim, 0, video_len)
387-
guide_suffix = primary.narrow(self.dim, video_len, guide_count)
388-
else:
389-
video_primary = primary
390-
guide_suffix = None
376+
# Let model strip auxiliary frames (e.g. guide frames)
377+
window_data = model.prepare_for_windowing(primary, conds, self.dim)
378+
video_primary = window_data.tensor
379+
aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0
391380

392-
# Windows from video portion only (excluding guide frames)
381+
# Windows from video portion only
393382
context_windows = self.get_context_windows(model, video_primary, model_options)
394383
enumerated_context_windows = list(enumerate(context_windows))
395384
total_windows = len(enumerated_context_windows)
396385

397386
# Accumulators sized to video portion for primary, full for other modalities
398387
accum_modalities = list(modalities)
399-
if guide_suffix is not None:
388+
if window_data.suffix is not None:
400389
accum_modalities[0] = video_primary
401390

402391
accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities]
@@ -406,25 +395,22 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
406395
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
407396
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities]
408397

409-
guide_entries = self._get_guide_entries(conds) if guide_count > 0 else None
410-
411398
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
412399
callback(self, model, x_in, conds, timestep, model_options)
413400

414401
for window_idx, window in enumerated_context_windows:
415402
comfy.model_management.throw_exception_if_processing_interrupted()
416403
logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}"
417-
+ (f" (+{guide_count} guide)" if guide_count > 0 else "")
404+
+ (f" (+{aux_count} aux)" if aux_count > 0 else "")
418405
+ (f" [{len(modalities)} modalities]" if is_multimodal else ""))
419406

420407
# Per-modality window indices
421408
if is_multimodal:
422-
# Adjust latent_shapes so video shape reflects video-only frames (excludes guides)
423409
map_shapes = latent_shapes
424-
if guide_count > 0:
410+
if video_primary.size(self.dim) != primary.size(self.dim):
425411
map_shapes = list(latent_shapes)
426412
video_shape = list(latent_shapes[0])
427-
video_shape[self.dim] = video_shape[self.dim] - guide_count
413+
video_shape[self.dim] = video_primary.size(self.dim)
428414
map_shapes[0] = torch.Size(video_shape)
429415
per_mod_indices = model.map_context_window_to_modalities(
430416
window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list]
@@ -446,30 +432,10 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
446432
for mod_idx in range(1, len(modalities)):
447433
mod_windows.append(modality_windows[mod_idx])
448434

449-
# Slice video, then select overlapping guide frames
435+
# Slice video, then let model inject auxiliary frames
450436
sliced_video = mod_windows[0].get_tensor(video_primary, retain_index_list=self.cond_retain_index_list)
451-
num_guide_in_window = 0
452-
if guide_suffix is not None and guide_entries is not None:
453-
overlap = _compute_guide_overlap(guide_entries, window.index_list)
454-
if overlap[3] > 0:
455-
suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap
456-
idx = tuple([slice(None)] * self.dim + [suffix_idx])
457-
sliced_guide = guide_suffix[idx]
458-
window.guide_suffix_indices = suffix_idx
459-
window.guide_overlap_info = overlap_info
460-
window.guide_kf_local_positions = kf_local_pos
461-
else:
462-
sliced_guide = None
463-
window.guide_suffix_indices = []
464-
window.guide_overlap_info = []
465-
window.guide_kf_local_positions = []
466-
else:
467-
sliced_guide = None
468-
469-
if sliced_guide is not None:
470-
sliced_primary = torch.cat([sliced_video, sliced_guide], dim=self.dim)
471-
else:
472-
sliced_primary = sliced_video
437+
sliced_primary, num_aux = model.prepare_window_input(
438+
sliced_video, window, window_data.aux_data, self.dim)
473439
sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))]
474440

475441
# Compose for pipeline
@@ -481,7 +447,6 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
481447

482448
model_options["transformer_options"]["context_window"] = window
483449
sub_timestep = window.get_tensor(timestep, dim=0)
484-
# Resize conds using video_primary as reference (excludes guide frames)
485450
sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds]
486451
if is_multimodal:
487452
self._patch_latent_shapes(sub_conds, sub_shapes)
@@ -490,14 +455,12 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
490455

491456
# Decompose output per modality
492457
out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
493-
# out_per_mod[cond_idx][mod_idx] = tensor
494458

495-
# Strip guide frames from primary output before accumulation
496-
if num_guide_in_window > 0:
459+
# Strip auxiliary frames from primary output before accumulation
460+
if num_aux > 0:
497461
window_len = len(window.index_list)
498462
for ci in range(len(sub_conds_out)):
499-
primary_out = out_per_mod[ci][0]
500-
out_per_mod[ci][0] = primary_out.narrow(self.dim, 0, window_len)
463+
out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len)
501464

502465
# Accumulate per modality (using video-only sizes)
503466
for mod_idx in range(len(accum_modalities)):
@@ -516,10 +479,9 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
516479
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
517480
accum[mod_idx][ci] /= counts[mod_idx][ci]
518481
f = accum[mod_idx][ci]
519-
# Re-append original guide_suffix (not model output — sampling loop
520-
# respects denoise_mask and never modifies guide frame positions)
521-
if mod_idx == 0 and guide_suffix is not None:
522-
f = torch.cat([f, guide_suffix], dim=self.dim)
482+
# Re-append model's suffix (auxiliary frames stripped before windowing)
483+
if mod_idx == 0 and window_data.suffix is not None:
484+
f = torch.cat([f, window_data.suffix], dim=self.dim)
523485
finalized.append(f)
524486
composed, _ = self._compose(finalized)
525487
result.append(composed)

comfy/model_base.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,12 @@ def concat_cond(self, **kwargs):
287287
return data
288288
return None
289289

290+
def prepare_for_windowing(self, primary, conds, dim):
    """Default windowing hook: strip nothing.

    Wraps *primary* unchanged in a WindowingContext with no suffix and no
    auxiliary data. Model subclasses override this to peel off trailing
    frames (e.g. guide frames) before context windows are computed.
    """
    ctx = comfy.context_windows.WindowingContext(
        tensor=primary,
        suffix=None,
        aux_data=None,
    )
    return ctx
292+
293+
def prepare_window_input(self, video_slice, window, aux_data, dim):
    """Default per-window hook: inject no auxiliary frames.

    Returns the window's video slice untouched together with an auxiliary
    frame count of zero. Model subclasses override this to append e.g.
    guide frames that overlap the current window.
    """
    return (video_slice, 0)
295+
290296
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
291297
"""Override in subclasses to handle model-specific cond slicing for context windows.
292298
Return a sliced cond object, or None to fall through to default handling.
@@ -1113,6 +1119,51 @@ def get_guide_frame_count(self, x, conds):
11131119
return sum(e["latent_shape"][0] for e in gae.cond)
11141120
return 0
11151121

1122+
@staticmethod
def _get_guide_entries(conds):
    """Extract the guide_attention_entries cond list from conditioning.

    Scans every cond dict in *conds* (skipping None entries) and returns
    the first non-empty ``guide_attention_entries`` cond found, or None
    when no cond carries one.
    """
    for entry_list in filter(None, conds):
        for entry in entry_list:
            gae = entry.get('model_conds', {}).get('guide_attention_entries')
            # Accept only objects that expose a truthy .cond payload.
            if gae is not None and getattr(gae, 'cond', None):
                return gae.cond
    return None
1133+
1134+
def prepare_for_windowing(self, primary, conds, dim):
    """Split trailing guide frames off *primary* before context windowing.

    Guide frames are appended at the end of the primary modality along
    *dim*; they must not participate in window scheduling, so they are
    stripped here and carried in the returned WindowingContext. The
    aux_data payload (guide entries + the stripped suffix) is consumed
    per-window by prepare_window_input().
    """
    n_guide = self.get_guide_frame_count(primary, conds)
    if n_guide <= 0:
        # Nothing to strip — behave like the base-class hook.
        return comfy.context_windows.WindowingContext(
            tensor=primary, suffix=None, aux_data=None)
    n_video = primary.size(dim) - n_guide
    video_part = primary.narrow(dim, 0, n_video)
    guide_part = primary.narrow(dim, n_video, n_guide)
    payload = {
        "guide_entries": self._get_guide_entries(conds),
        "guide_suffix": guide_part,
    }
    return comfy.context_windows.WindowingContext(
        tensor=video_part, suffix=guide_part, aux_data=payload)
1145+
1146+
def prepare_window_input(self, video_slice, window, aux_data, dim):
    """Append the guide frames that overlap this context window.

    *aux_data* is the payload built by prepare_for_windowing(): the full
    guide suffix tensor plus the guide entries from conditioning. The
    overlap computation also records per-window guide bookkeeping on
    *window* (suffix indices, overlap info, local keyframe positions) for
    downstream attention code. Returns the (possibly extended) slice and
    the number of guide frames appended.
    """
    if aux_data is None:
        # No guide frames were stripped for this run.
        return video_slice, 0

    entries = aux_data["guide_entries"]
    suffix = aux_data["guide_suffix"]

    if entries is None:
        # Suffix exists but conditioning carries no entries: annotate the
        # window with empty bookkeeping and inject nothing.
        window.guide_suffix_indices = []
        window.guide_overlap_info = []
        window.guide_kf_local_positions = []
        return video_slice, 0

    suffix_idx, overlap_info, kf_local, n_overlap = (
        comfy.context_windows._compute_guide_overlap(entries, window.index_list))
    window.guide_suffix_indices = suffix_idx
    window.guide_overlap_info = overlap_info
    window.guide_kf_local_positions = kf_local

    if n_overlap <= 0:
        return video_slice, 0

    # Fancy-index the suffix along *dim* with the overlapping indices,
    # then append those guide frames after the video slice.
    selector = (slice(None),) * dim + (suffix_idx,)
    overlapping = suffix[selector]
    return torch.cat([video_slice, overlapping], dim=dim), n_overlap
1166+
11161167
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
11171168
# Audio denoise mask — slice using audio modality window
11181169
if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:

0 commit comments

Comments
 (0)