@@ -367,18 +367,60 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
367367 self ._model = model
368368 self .set_step (timestep , model_options )
369369
370- # Decompose — single-modality: [x_in], multimodal: [video, audio, ...]
370+ # Check if multimodal or model has auxiliary frames requiring the extended path
371371 latent_shapes = self ._get_latent_shapes (conds )
372+ is_multimodal = latent_shapes is not None and len (latent_shapes ) > 1
373+ if is_multimodal :
374+ return self ._execute_extended (calc_cond_batch , model , conds , x_in , timestep , model_options , latent_shapes )
375+ window_data = model .prepare_for_windowing (x_in , conds , self .dim )
376+ if window_data .suffix is not None or window_data .aux_data is not None :
377+ return self ._execute_extended (calc_cond_batch , model , conds , x_in , timestep , model_options ,
378+ latent_shapes , window_data )
379+
380+ context_windows = self .get_context_windows (model , x_in , model_options )
381+ enumerated_context_windows = list (enumerate (context_windows ))
382+
383+ conds_final = [torch .zeros_like (x_in ) for _ in conds ]
384+ if self .fuse_method .name == ContextFuseMethods .RELATIVE :
385+ counts_final = [torch .ones (get_shape_for_dim (x_in , self .dim ), device = x_in .device ) for _ in conds ]
386+ else :
387+ counts_final = [torch .zeros (get_shape_for_dim (x_in , self .dim ), device = x_in .device ) for _ in conds ]
388+ biases_final = [([0.0 ] * x_in .shape [self .dim ]) for _ in conds ]
389+
390+ for callback in comfy .patcher_extension .get_all_callbacks (IndexListCallbacks .EXECUTE_START , self .callbacks ):
391+ callback (self , model , x_in , conds , timestep , model_options )
392+
393+ for enum_window in enumerated_context_windows :
394+ results = self .evaluate_context_windows (calc_cond_batch , model , x_in , conds , timestep , [enum_window ], model_options )
395+ for result in results :
396+ self .combine_context_window_results (x_in , result .sub_conds_out , result .sub_conds , result .window , result .window_idx , len (enumerated_context_windows ), timestep ,
397+ conds_final , counts_final , biases_final )
398+ try :
399+ if self .fuse_method .name == ContextFuseMethods .RELATIVE :
400+ del counts_final
401+ return conds_final
402+ else :
403+ for i in range (len (conds_final )):
404+ conds_final [i ] /= counts_final [i ]
405+ del counts_final
406+ return conds_final
407+ finally :
408+ for callback in comfy .patcher_extension .get_all_callbacks (IndexListCallbacks .EXECUTE_CLEANUP , self .callbacks ):
409+ callback (self , model , x_in , conds , timestep , model_options )
410+
411+ def _execute_extended (self , calc_cond_batch : Callable , model : BaseModel , conds : list [list [dict ]], x_in : torch .Tensor ,
412+ timestep : torch .Tensor , model_options : dict [str ],
413+ latent_shapes , window_data : WindowingContext = None ):
414+ """Extended execute path for multimodal models and models with auxiliary frames."""
372415 modalities = self ._decompose (x_in , latent_shapes )
373416 is_multimodal = len (modalities ) > 1
374- primary = modalities [0 ]
375417
376- # Let model strip auxiliary frames (e.g. guide frames)
377- window_data = model .prepare_for_windowing (primary , conds , self .dim )
418+ if window_data is None :
419+ window_data = model .prepare_for_windowing (modalities [0 ], conds , self .dim )
420+
378421 video_primary = window_data .tensor
379422 aux_count = window_data .suffix .size (self .dim ) if window_data .suffix is not None else 0
380423
381- # Windows from video portion only
382424 context_windows = self .get_context_windows (model , video_primary , model_options )
383425 enumerated_context_windows = list (enumerate (context_windows ))
384426 total_windows = len (enumerated_context_windows )
@@ -407,14 +449,13 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
407449 # Per-modality window indices
408450 if is_multimodal :
409451 map_shapes = latent_shapes
410- if video_primary .size (self .dim ) != primary .size (self .dim ):
452+ if video_primary .size (self .dim ) != modalities [ 0 ] .size (self .dim ):
411453 map_shapes = list (latent_shapes )
412454 video_shape = list (latent_shapes [0 ])
413455 video_shape [self .dim ] = video_primary .size (self .dim )
414456 map_shapes [0 ] = torch .Size (video_shape )
415457 per_mod_indices = model .map_context_window_to_modalities (
416458 window .index_list , map_shapes , self .dim ) if hasattr (model , 'map_context_window_to_modalities' ) else [window .index_list ]
417- # Build per-modality windows and attach to primary window
418459 modality_windows = {}
419460 for mod_idx in range (1 , len (modalities )):
420461 modality_windows [mod_idx ] = IndexListContextWindow (
@@ -423,11 +464,9 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
423464 window = IndexListContextWindow (
424465 window .index_list , dim = self .dim , total_frames = video_primary .shape [self .dim ],
425466 modality_windows = modality_windows )
426- else :
427- per_mod_indices = [window .index_list ]
428467
429- # Build per-modality windows list (including primary)
430- mod_windows = [window ] # primary window at index 0
468+ # Build per-modality windows list
469+ mod_windows = [window ]
431470 if is_multimodal :
432471 for mod_idx in range (1 , len (modalities )):
433472 mod_windows .append (modality_windows [mod_idx ])
@@ -438,10 +477,8 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
438477 sliced_video , window , window_data .aux_data , self .dim )
439478 sliced = [sliced_primary ] + [mod_windows [mi ].get_tensor (modalities [mi ]) for mi in range (1 , len (modalities ))]
440479
441- # Compose for pipeline
442480 sub_x , sub_shapes = self ._compose (sliced )
443481
444- # Callbacks
445482 for callback in comfy .patcher_extension .get_all_callbacks (IndexListCallbacks .EVALUATE_CONTEXT_WINDOWS , self .callbacks ):
446483 callback (self , model , x_in , conds , timestep , model_options , window_idx , window , model_options , None , None )
447484
@@ -462,7 +499,7 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
462499 for ci in range (len (sub_conds_out )):
463500 out_per_mod [ci ][0 ] = out_per_mod [ci ][0 ].narrow (self .dim , 0 , window_len )
464501
465- # Accumulate per modality (using video-only sizes)
502+ # Accumulate per modality
466503 for mod_idx in range (len (accum_modalities )):
467504 mw = mod_windows [mod_idx ]
468505 mod_sub_out = [out_per_mod [ci ][mod_idx ] for ci in range (len (sub_conds_out ))]
@@ -479,7 +516,6 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
479516 if self .fuse_method .name != ContextFuseMethods .RELATIVE :
480517 accum [mod_idx ][ci ] /= counts [mod_idx ][ci ]
481518 f = accum [mod_idx ][ci ]
482- # Re-append model's suffix (auxiliary frames stripped before windowing)
483519 if mod_idx == 0 and window_data .suffix is not None :
484520 f = torch .cat ([f , window_data .suffix ], dim = self .dim )
485521 finalized .append (f )
0 commit comments