@@ -182,22 +182,22 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int
 
 
 def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC,
-                                    aux_data: dict, dim: int) -> tuple[torch.Tensor, int]:
+                                    window_data: 'WindowingContext', dim: int) -> tuple[torch.Tensor, int]:
     """Inject overlapping guide frames into a context window slice.
 
-    Uses aux_data from WindowingContext to determine which guide frames overlap
-    with this window's indices, concatenates them onto the video slice, and sets
-    window attributes for downstream conditioning resize.
+    Determines which guide frames overlap with this window's indices, concatenates
+    them onto the video slice, and sets window attributes for downstream conditioning resize.
 
     Returns (augmented_slice, num_guide_frames_added).
     """
-    guide_entries = aux_data["guide_entries"]
-    guide_frames = aux_data["guide_frames"]
+    guide_entries = window_data.aux_data["guide_entries"]
+    guide_frames = window_data.guide_frames
     overlap = compute_guide_overlap(guide_entries, window.index_list)
     suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap
     window.guide_frames_indices = suffix_idx
     window.guide_overlap_info = overlap_info
     window.guide_kf_local_positions = kf_local_pos
+
     # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
     # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
     guide_downscale_factors = []
@@ -207,6 +207,7 @@ def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWi
         entry_H = guide_entries[entry_idx]["latent_shape"][1]
         guide_downscale_factors.append(full_H // entry_H)
     window.guide_downscale_factors = guide_downscale_factors
+
     if guide_frame_count > 0:
         idx = tuple([slice(None)] * dim + [suffix_idx])
         sliced_guide = guide_frames[idx]
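The two comments in the first hunk pin down the ratio being computed here: guide_frames carries post-dilation spatial dims while each entry's latent_shape recorded pre-dilation dims, so the integer ratio of heights recovers the per-entry downscale factor. A worked example with assumed numbers:

full_H = 60    # assumed guide-frame height after dilation
entry_H = 30   # assumed entry["latent_shape"][1] before dilation
latent_downscale_factor = full_H // entry_H   # -> 2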
@@ -220,7 +221,6 @@ class WindowingContext:
     guide_frames: torch.Tensor | None
     aux_data: Any
     latent_shapes: list | None
-    is_multimodal: bool
 
 @dataclass
 class ContextSchedule:
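With is_multimodal dropped from the dataclass, the flag is derived from latent_shapes wherever it is needed (see the execute and _execute_extended hunks below). One way to picture the equivalence is as a property on a slimmed-down stand-in (a sketch, not the repo's code):

from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class WindowingContextSketch:          # hypothetical stand-in for WindowingContext
    tensor: torch.Tensor
    guide_frames: torch.Tensor | None
    aux_data: Any
    latent_shapes: list | None

    @property
    def is_multimodal(self) -> bool:
        # Same predicate the diff inlines at both call sites.
        return self.latent_shapes is not None and len(self.latent_shapes) > 1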
@@ -310,13 +310,13 @@ def _build_window_data(self, x_in: torch.Tensor, conds: list[list[dict]]) -> Win
         guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None
 
         if guide_frame_count > 0:
-            aux_data = {"guide_entries": guide_entries, "guide_frames": guide_frames}
+            aux_data = {"guide_entries": guide_entries}
         else:
             aux_data = None
 
         return WindowingContext(
             tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data,
-            latent_shapes=latent_shapes, is_multimodal=is_multimodal)
+            latent_shapes=latent_shapes)
 
     def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
         self._window_data = self._build_window_data(x_in, conds)
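The primary/guide split above relies on torch.Tensor.narrow(dim, start, length); a minimal illustration with assumed frame counts (the real counts come from the guide entries):

import torch

dim = 2
primary_frame_count, guide_frame_count = 16, 3    # assumed counts for illustration
video_latent = torch.zeros(1, 4, primary_frame_count + guide_frame_count, 32, 32)

primary_frames = video_latent.narrow(dim, 0, primary_frame_count)                  # (1, 4, 16, 32, 32)
guide_frames = video_latent.narrow(dim, primary_frame_count, guide_frame_count)    # (1, 4, 3, 32, 32)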
@@ -437,9 +437,14 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
         self.set_step(timestep, model_options)
 
         window_data = self._window_data
-        if window_data.is_multimodal or (window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0):
+        is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
+        has_guide_frames = window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0
+
+        # If multimodal or guide frames are concatenated onto the noise latent, use the extended execute path.
+        if is_multimodal or has_guide_frames:
             return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data)
 
+        # Basic legacy execution path: single-modal video latent with no concatenated guide frames.
         context_windows = self.get_context_windows(model, x_in, model_options)
         enumerated_context_windows = list(enumerate(context_windows))
@@ -475,8 +480,9 @@ def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds:
                           timestep: torch.Tensor, model_options: dict[str],
                           window_data: WindowingContext):
         """Extended execute path for multimodal models and models with guide frames appended to the noise latent."""
+
         latents = self._unpack(x_in, window_data.latent_shapes)
-        is_multimodal = window_data.is_multimodal
+        is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1
 
         primary_frames = window_data.tensor
         num_guide_frames = window_data.guide_frames.size(self.dim) if window_data.guide_frames is not None else 0
@@ -538,7 +544,7 @@ def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds:
             # Slice video, then inject overlapping guide frames if present
             sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list)
             if window_data.aux_data is not None:
-                sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data.aux_data, self.dim)
+                sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data, self.dim)
             else:
                 sliced_primary, num_guide_frames = sliced_video, 0
             sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))]