11from __future__ import annotations
2- from typing import TYPE_CHECKING , Callable
2+ from typing import TYPE_CHECKING , Any , Callable
33import torch
44import numpy as np
55import collections
@@ -181,6 +181,12 @@ def _compute_guide_overlap(guide_entries, window_index_list):
181181 return suffix_indices , overlap_info , kf_local_positions , len (suffix_indices )
182182
183183
184+ @dataclass
185+ class WindowingContext :
186+ tensor : torch .Tensor
187+ suffix : torch .Tensor | None
188+ aux_data : Any
189+
184190@dataclass
185191class ContextSchedule :
186192 name : str
@@ -242,18 +248,6 @@ def _patch_latent_shapes(self, sub_conds, new_shapes):
242248 if 'latent_shapes' in model_conds :
243249 model_conds ['latent_shapes' ] = comfy .conds .CONDConstant (new_shapes )
244250
245- def _get_guide_entries (self , conds ):
246- """Extract guide_attention_entries list from conditioning. Returns None if absent."""
247- for cond_list in conds :
248- if cond_list is None :
249- continue
250- for cond_dict in cond_list :
251- model_conds = cond_dict .get ('model_conds' , {})
252- gae = model_conds .get ('guide_attention_entries' )
253- if gae is not None and hasattr (gae , 'cond' ) and gae .cond :
254- return gae .cond
255- return None
256-
257251 def should_use_context (self , model : BaseModel , conds : list [list [dict ]], x_in : torch .Tensor , timestep : torch .Tensor , model_options : dict [str ]) -> bool :
258252 latent_shapes = self ._get_latent_shapes (conds )
259253 primary = self ._decompose (x_in , latent_shapes )[0 ]
@@ -379,24 +373,19 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
379373 is_multimodal = len (modalities ) > 1
380374 primary = modalities [0 ]
381375
382- # Separate guide frames from primary modality (guides are appended at the end)
383- guide_count = model .get_guide_frame_count (primary , conds ) if hasattr (model , 'get_guide_frame_count' ) else 0
384- if guide_count > 0 :
385- video_len = primary .size (self .dim ) - guide_count
386- video_primary = primary .narrow (self .dim , 0 , video_len )
387- guide_suffix = primary .narrow (self .dim , video_len , guide_count )
388- else :
389- video_primary = primary
390- guide_suffix = None
376+ # Let model strip auxiliary frames (e.g. guide frames)
377+ window_data = model .prepare_for_windowing (primary , conds , self .dim )
378+ video_primary = window_data .tensor
379+ aux_count = window_data .suffix .size (self .dim ) if window_data .suffix is not None else 0
391380
392- # Windows from video portion only (excluding guide frames)
381+ # Windows from video portion only
393382 context_windows = self .get_context_windows (model , video_primary , model_options )
394383 enumerated_context_windows = list (enumerate (context_windows ))
395384 total_windows = len (enumerated_context_windows )
396385
397386 # Accumulators sized to video portion for primary, full for other modalities
398387 accum_modalities = list (modalities )
399- if guide_suffix is not None :
388+ if window_data . suffix is not None :
400389 accum_modalities [0 ] = video_primary
401390
402391 accum = [[torch .zeros_like (m ) for _ in conds ] for m in accum_modalities ]
@@ -406,25 +395,22 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
406395 counts = [[torch .zeros (get_shape_for_dim (m , self .dim ), device = m .device ) for _ in conds ] for m in accum_modalities ]
407396 biases = [[([0.0 ] * m .shape [self .dim ]) for _ in conds ] for m in accum_modalities ]
408397
409- guide_entries = self ._get_guide_entries (conds ) if guide_count > 0 else None
410-
411398 for callback in comfy .patcher_extension .get_all_callbacks (IndexListCallbacks .EXECUTE_START , self .callbacks ):
412399 callback (self , model , x_in , conds , timestep , model_options )
413400
414401 for window_idx , window in enumerated_context_windows :
415402 comfy .model_management .throw_exception_if_processing_interrupted ()
416403 logging .info (f"Context window { window_idx + 1 } /{ total_windows } : frames { window .index_list [0 ]} -{ window .index_list [- 1 ]} of { video_primary .shape [self .dim ]} "
417- + (f" (+{ guide_count } guide )" if guide_count > 0 else "" )
404+ + (f" (+{ aux_count } aux )" if aux_count > 0 else "" )
418405 + (f" [{ len (modalities )} modalities]" if is_multimodal else "" ))
419406
420407 # Per-modality window indices
421408 if is_multimodal :
422- # Adjust latent_shapes so video shape reflects video-only frames (excludes guides)
423409 map_shapes = latent_shapes
424- if guide_count > 0 :
410+ if video_primary . size ( self . dim ) != primary . size ( self . dim ) :
425411 map_shapes = list (latent_shapes )
426412 video_shape = list (latent_shapes [0 ])
427- video_shape [self .dim ] = video_shape [ self .dim ] - guide_count
413+ video_shape [self .dim ] = video_primary . size ( self .dim )
428414 map_shapes [0 ] = torch .Size (video_shape )
429415 per_mod_indices = model .map_context_window_to_modalities (
430416 window .index_list , map_shapes , self .dim ) if hasattr (model , 'map_context_window_to_modalities' ) else [window .index_list ]
@@ -446,30 +432,10 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
446432 for mod_idx in range (1 , len (modalities )):
447433 mod_windows .append (modality_windows [mod_idx ])
448434
449- # Slice video, then select overlapping guide frames
435+ # Slice video, then let model inject auxiliary frames
450436 sliced_video = mod_windows [0 ].get_tensor (video_primary , retain_index_list = self .cond_retain_index_list )
451- num_guide_in_window = 0
452- if guide_suffix is not None and guide_entries is not None :
453- overlap = _compute_guide_overlap (guide_entries , window .index_list )
454- if overlap [3 ] > 0 :
455- suffix_idx , overlap_info , kf_local_pos , num_guide_in_window = overlap
456- idx = tuple ([slice (None )] * self .dim + [suffix_idx ])
457- sliced_guide = guide_suffix [idx ]
458- window .guide_suffix_indices = suffix_idx
459- window .guide_overlap_info = overlap_info
460- window .guide_kf_local_positions = kf_local_pos
461- else :
462- sliced_guide = None
463- window .guide_suffix_indices = []
464- window .guide_overlap_info = []
465- window .guide_kf_local_positions = []
466- else :
467- sliced_guide = None
468-
469- if sliced_guide is not None :
470- sliced_primary = torch .cat ([sliced_video , sliced_guide ], dim = self .dim )
471- else :
472- sliced_primary = sliced_video
437+ sliced_primary , num_aux = model .prepare_window_input (
438+ sliced_video , window , window_data .aux_data , self .dim )
473439 sliced = [sliced_primary ] + [mod_windows [mi ].get_tensor (modalities [mi ]) for mi in range (1 , len (modalities ))]
474440
475441 # Compose for pipeline
@@ -481,7 +447,6 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
481447
482448 model_options ["transformer_options" ]["context_window" ] = window
483449 sub_timestep = window .get_tensor (timestep , dim = 0 )
484- # Resize conds using video_primary as reference (excludes guide frames)
485450 sub_conds = [self .get_resized_cond (cond , video_primary , window ) for cond in conds ]
486451 if is_multimodal :
487452 self ._patch_latent_shapes (sub_conds , sub_shapes )
@@ -490,14 +455,12 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
490455
491456 # Decompose output per modality
492457 out_per_mod = [self ._decompose (sub_conds_out [i ], sub_shapes ) for i in range (len (sub_conds_out ))]
493- # out_per_mod[cond_idx][mod_idx] = tensor
494458
495- # Strip guide frames from primary output before accumulation
496- if num_guide_in_window > 0 :
459+ # Strip auxiliary frames from primary output before accumulation
460+ if num_aux > 0 :
497461 window_len = len (window .index_list )
498462 for ci in range (len (sub_conds_out )):
499- primary_out = out_per_mod [ci ][0 ]
500- out_per_mod [ci ][0 ] = primary_out .narrow (self .dim , 0 , window_len )
463+ out_per_mod [ci ][0 ] = out_per_mod [ci ][0 ].narrow (self .dim , 0 , window_len )
501464
502465 # Accumulate per modality (using video-only sizes)
503466 for mod_idx in range (len (accum_modalities )):
@@ -516,10 +479,9 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
516479 if self .fuse_method .name != ContextFuseMethods .RELATIVE :
517480 accum [mod_idx ][ci ] /= counts [mod_idx ][ci ]
518481 f = accum [mod_idx ][ci ]
519- # Re-append original guide_suffix (not model output — sampling loop
520- # respects denoise_mask and never modifies guide frame positions)
521- if mod_idx == 0 and guide_suffix is not None :
522- f = torch .cat ([f , guide_suffix ], dim = self .dim )
482+ # Re-append model's suffix (auxiliary frames stripped before windowing)
483+ if mod_idx == 0 and window_data .suffix is not None :
484+ f = torch .cat ([f , window_data .suffix ], dim = self .dim )
523485 finalized .append (f )
524486 composed , _ = self ._compose (finalized )
525487 result .append (composed )
0 commit comments