Skip to content

Commit e5369c0

Browse files
authored
feat: Context windows - add causal_window_fix to improve blending of context windows (CORE-100) (#13563)
* Context windows: add causal_window_fix toggle * Fix slice_cond to correctly handle causal anchor index for temporal offsets
1 parent 1655f80 commit e5369c0

2 files changed

Lines changed: 34 additions & 5 deletions

File tree

comfy/context_windows.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_lis
6363
dim = self.dim
6464
if dim == 0 and full.shape[dim] == 1:
6565
return full
66-
idx = tuple([slice(None)] * dim + [self.index_list])
66+
indices = self.index_list
67+
anchor_idx = getattr(self, 'causal_anchor_index', None)
68+
if anchor_idx is not None and anchor_idx >= 0:
69+
indices = [anchor_idx] + list(indices)
70+
idx = tuple([slice(None)] * dim + [indices])
6771
window = full[idx]
6872
if retain_index_list:
6973
idx = tuple([slice(None)] * dim + [retain_index_list])
@@ -113,7 +117,14 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
113117

114118
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
115119
if temporal_offset > 0:
116-
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
120+
anchor_idx = getattr(window, 'causal_anchor_index', None)
121+
if anchor_idx is not None and anchor_idx >= 0:
122+
# anchor occupies one of the no-cond positions, so skip one fewer from window.index_list
123+
skip_count = temporal_offset - 1
124+
else:
125+
skip_count = temporal_offset
126+
127+
indices = [i - temporal_offset for i in window.index_list[skip_count:]]
117128
indices = [i for i in indices if 0 <= i]
118129
else:
119130
indices = list(window.index_list)
@@ -150,7 +161,8 @@ class ContextFuseMethod:
150161
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
151162
class IndexListContextHandler(ContextHandlerABC):
152163
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
153-
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
164+
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
165+
causal_window_fix: bool=True):
154166
self.context_schedule = context_schedule
155167
self.fuse_method = fuse_method
156168
self.context_length = context_length
@@ -162,6 +174,7 @@ def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMe
162174
self.freenoise = freenoise
163175
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
164176
self.split_conds_to_windows = split_conds_to_windows
177+
self.causal_window_fix = causal_window_fix
165178

166179
self.callbacks = {}
167180

@@ -318,6 +331,14 @@ def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel,
318331
# allow processing to end between context window executions for faster Cancel
319332
comfy.model_management.throw_exception_if_processing_interrupted()
320333

334+
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
335+
anchor_applied = False
336+
if self.causal_window_fix:
337+
anchor_idx = window.index_list[0] - 1
338+
if 0 <= anchor_idx < x_in.size(self.dim):
339+
window.causal_anchor_index = anchor_idx
340+
anchor_applied = True
341+
321342
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
322343
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
323344

@@ -332,6 +353,12 @@ def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel,
332353
if device is not None:
333354
for i in range(len(sub_conds_out)):
334355
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
356+
357+
# strip causal_window_fix anchor if applied
358+
if anchor_applied:
359+
for i in range(len(sub_conds_out)):
360+
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
361+
335362
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
336363
return results
337364

comfy_extras/nodes_context_windows.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def define_schema(cls) -> io.Schema:
2929
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
3030
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
3131
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
32+
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
3233
],
3334
outputs=[
3435
io.Model.Output(tooltip="The model with context windows applied during sampling."),
@@ -38,7 +39,7 @@ def define_schema(cls) -> io.Schema:
3839

3940
@classmethod
4041
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
41-
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
42+
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model:
4243
model = model.clone()
4344
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
4445
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@@ -50,7 +51,8 @@ def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int
5051
dim=dim,
5152
freenoise=freenoise,
5253
cond_retain_index_list=cond_retain_index_list,
53-
split_conds_to_windows=split_conds_to_windows
54+
split_conds_to_windows=split_conds_to_windows,
55+
causal_window_fix=causal_window_fix,
5456
)
5557
# make memory usage calculation only take into account the context window latents
5658
comfy.context_windows.create_prepare_sampling_wrapper(model)

0 commit comments

Comments
 (0)