Skip to content

Commit d5badc5

Browse files
committed
LTX2 context windows - Clean up unnecessary code
1 parent c9edd2d commit d5badc5

2 files changed

Lines changed: 18 additions & 29 deletions

File tree

comfy/context_windows.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,22 +182,22 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int
182182

183183

184184
def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC,
185-
aux_data: dict, dim: int) -> tuple[torch.Tensor, int]:
185+
window_data: 'WindowingContext', dim: int) -> tuple[torch.Tensor, int]:
186186
"""Inject overlapping guide frames into a context window slice.
187187
188-
Uses aux_data from WindowingContext to determine which guide frames overlap
189-
with this window's indices, concatenates them onto the video slice, and sets
190-
window attributes for downstream conditioning resize.
188+
Determines which guide frames overlap with this window's indices, concatenates
189+
them onto the video slice, and sets window attributes for downstream conditioning resize.
191190
192191
Returns (augmented_slice, num_guide_frames_added).
193192
"""
194-
guide_entries = aux_data["guide_entries"]
195-
guide_frames = aux_data["guide_frames"]
193+
guide_entries = window_data.aux_data["guide_entries"]
194+
guide_frames = window_data.guide_frames
196195
overlap = compute_guide_overlap(guide_entries, window.index_list)
197196
suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap
198197
window.guide_frames_indices = suffix_idx
199198
window.guide_overlap_info = overlap_info
200199
window.guide_kf_local_positions = kf_local_pos
200+
201201
# Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
202202
# guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
203203
guide_downscale_factors = []
@@ -207,6 +207,7 @@ def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWi
207207
entry_H = guide_entries[entry_idx]["latent_shape"][1]
208208
guide_downscale_factors.append(full_H // entry_H)
209209
window.guide_downscale_factors = guide_downscale_factors
210+
210211
if guide_frame_count > 0:
211212
idx = tuple([slice(None)] * dim + [suffix_idx])
212213
sliced_guide = guide_frames[idx]
@@ -220,7 +221,6 @@ class WindowingContext:
220221
guide_frames: torch.Tensor | None
221222
aux_data: Any
222223
latent_shapes: list | None
223-
is_multimodal: bool
224224

225225
@dataclass
226226
class ContextSchedule:
@@ -310,13 +310,13 @@ def _build_window_data(self, x_in: torch.Tensor, conds: list[list[dict]]) -> Win
310310
guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None
311311

312312
if guide_frame_count > 0:
313-
aux_data = {"guide_entries": guide_entries, "guide_frames": guide_frames}
313+
aux_data = {"guide_entries": guide_entries}
314314
else:
315315
aux_data = None
316316

317317
return WindowingContext(
318318
tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data,
319-
latent_shapes=latent_shapes, is_multimodal=is_multimodal)
319+
latent_shapes=latent_shapes)
320320

321321
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
322322
self._window_data = self._build_window_data(x_in, conds)
@@ -437,9 +437,14 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
437437
self.set_step(timestep, model_options)
438438

439439
window_data = self._window_data
440-
if window_data.is_multimodal or (window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0):
440+
is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
441+
has_guide_frames = window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0
442+
443+
# if multimodal or has concatenated guide frames on noise latent, use the extended execute path
444+
if is_multimodal or has_guide_frames:
441445
return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data)
442446

447+
# basic legacy execution path for single-modal video latent with no guide frames concatenated
443448
context_windows = self.get_context_windows(model, x_in, model_options)
444449
enumerated_context_windows = list(enumerate(context_windows))
445450

@@ -475,8 +480,9 @@ def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds:
475480
timestep: torch.Tensor, model_options: dict[str],
476481
window_data: WindowingContext):
477482
"""Extended execute path for multimodal models and models with guide frames appended to the noise latent."""
483+
478484
latents = self._unpack(x_in, window_data.latent_shapes)
479-
is_multimodal = window_data.is_multimodal
485+
is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
480486

481487
primary_frames = window_data.tensor
482488
num_guide_frames = window_data.guide_frames.size(self.dim) if window_data.guide_frames is not None else 0
@@ -538,7 +544,7 @@ def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds:
538544
# Slice video, then inject overlapping guide frames if present
539545
sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list)
540546
if window_data.aux_data is not None:
541-
sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data.aux_data, self.dim)
547+
sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data, self.dim)
542548
else:
543549
sliced_primary, num_guide_frames = sliced_video, 0
544550
sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))]

comfy/model_base.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,23 +1113,6 @@ def _get_guide_entries(conds):
11131113
if entries is not None and hasattr(entries, 'cond') and entries.cond:
11141114
return entries.cond
11151115
return None
1116-
1117-
def prepare_window_data(self, x_in, conds, dim, window_data):
1118-
primary = comfy.utils.unpack_latents(x_in, window_data.latent_shapes)[0] if window_data.is_multimodal else x_in
1119-
guide_entries = self._get_guide_entries(conds)
1120-
guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0
1121-
if guide_count <= 0:
1122-
return comfy.context_windows.WindowingContext(
1123-
tensor=primary, guide_frames=None, aux_data=None,
1124-
latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal)
1125-
video_len = primary.size(dim) - guide_count
1126-
video_primary = primary.narrow(dim, 0, video_len)
1127-
guide_frames = primary.narrow(dim, video_len, guide_count)
1128-
return comfy.context_windows.WindowingContext(
1129-
tensor=video_primary, guide_frames=guide_frames,
1130-
aux_data={"guide_entries": guide_entries, "guide_frames": guide_frames},
1131-
latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal)
1132-
11331116

11341117
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
11351118
# Audio denoise mask — slice using audio modality window

0 commit comments

Comments (0)