@@ -63,7 +63,11 @@ def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_lis
6363 dim = self .dim
6464 if dim == 0 and full .shape [dim ] == 1 :
6565 return full
66- idx = tuple ([slice (None )] * dim + [self .index_list ])
66+ indices = self .index_list
67+ anchor_idx = getattr (self , 'causal_anchor_index' , None )
68+ if anchor_idx is not None and anchor_idx >= 0 :
69+ indices = [anchor_idx ] + list (indices )
70+ idx = tuple ([slice (None )] * dim + [indices ])
6771 window = full [idx ]
6872 if retain_index_list :
6973 idx = tuple ([slice (None )] * dim + [retain_index_list ])
@@ -113,7 +117,14 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
113117
114118 # skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
115119 if temporal_offset > 0 :
116- indices = [i - temporal_offset for i in window .index_list [temporal_offset :]]
120+ anchor_idx = getattr (window , 'causal_anchor_index' , None )
121+ if anchor_idx is not None and anchor_idx >= 0 :
122+ # anchor occupies one of the no-cond positions, so skip one fewer from window.index_list
123+ skip_count = temporal_offset - 1
124+ else :
125+ skip_count = temporal_offset
126+
127+ indices = [i - temporal_offset for i in window .index_list [skip_count :]]
117128 indices = [i for i in indices if 0 <= i ]
118129 else :
119130 indices = list (window .index_list )
@@ -150,7 +161,8 @@ class ContextFuseMethod:
150161ContextResults = collections .namedtuple ("ContextResults" , ['window_idx' , 'sub_conds_out' , 'sub_conds' , 'window' ])
151162class IndexListContextHandler (ContextHandlerABC ):
152163 def __init__ (self , context_schedule : ContextSchedule , fuse_method : ContextFuseMethod , context_length : int = 1 , context_overlap : int = 0 , context_stride : int = 1 ,
153- closed_loop : bool = False , dim :int = 0 , freenoise : bool = False , cond_retain_index_list : list [int ]= [], split_conds_to_windows : bool = False ):
164+ closed_loop : bool = False , dim :int = 0 , freenoise : bool = False , cond_retain_index_list : list [int ]= [], split_conds_to_windows : bool = False ,
165+ causal_window_fix : bool = True ):
154166 self .context_schedule = context_schedule
155167 self .fuse_method = fuse_method
156168 self .context_length = context_length
@@ -162,6 +174,7 @@ def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMe
162174 self .freenoise = freenoise
163175 self .cond_retain_index_list = [int (x .strip ()) for x in cond_retain_index_list .split ("," )] if cond_retain_index_list else []
164176 self .split_conds_to_windows = split_conds_to_windows
177+ self .causal_window_fix = causal_window_fix
165178
166179 self .callbacks = {}
167180
@@ -318,6 +331,14 @@ def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel,
318331 # allow processing to end between context window executions for faster Cancel
319332 comfy .model_management .throw_exception_if_processing_interrupted ()
320333
334+ # causal_window_fix: prepend a pre-window frame that will be stripped post-forward
335+ anchor_applied = False
336+ if self .causal_window_fix :
337+ anchor_idx = window .index_list [0 ] - 1
338+ if 0 <= anchor_idx < x_in .size (self .dim ):
339+ window .causal_anchor_index = anchor_idx
340+ anchor_applied = True
341+
321342 for callback in comfy .patcher_extension .get_all_callbacks (IndexListCallbacks .EVALUATE_CONTEXT_WINDOWS , self .callbacks ):
322343 callback (self , model , x_in , conds , timestep , model_options , window_idx , window , model_options , device , first_device )
323344
@@ -332,6 +353,12 @@ def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel,
332353 if device is not None :
333354 for i in range (len (sub_conds_out )):
334355 sub_conds_out [i ] = sub_conds_out [i ].to (x_in .device )
356+
357+ # strip causal_window_fix anchor if applied
358+ if anchor_applied :
359+ for i in range (len (sub_conds_out )):
360+ sub_conds_out [i ] = sub_conds_out [i ].narrow (self .dim , 1 , sub_conds_out [i ].shape [self .dim ] - 1 )
361+
335362 results .append (ContextResults (window_idx , sub_conds_out , sub_conds , window ))
336363 return results
337364
0 commit comments