@@ -124,13 +124,12 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
124124 return latents
125125
126126
127- def _normalize_latents (
128- latents : torch .Tensor , latents_mean : torch . Tensor , latents_std : torch . Tensor , scaling_factor : float = 1.0
127+ def _unpack_latents (
128+ latents : torch .Tensor , num_frames : int , height : int , width : int , patch_size : int = 1 , patch_size_t : int = 1
129129) -> torch .Tensor :
130- # Normalize latents across the channel dimension [B, C, F, H, W]
131- latents_mean = latents_mean .view (1 , - 1 , 1 , 1 , 1 ).to (latents .device , latents .dtype )
132- latents_std = latents_std .view (1 , - 1 , 1 , 1 , 1 ).to (latents .device , latents .dtype )
133- latents = (latents - latents_mean ) * scaling_factor / latents_std
130+ batch_size = latents .size (0 )
131+ latents = latents .reshape (batch_size , num_frames , height , width , - 1 , patch_size_t , patch_size , patch_size )
132+ latents = latents .permute (0 , 4 , 1 , 5 , 2 , 6 , 3 , 7 ).flatten (6 , 7 ).flatten (4 , 5 ).flatten (2 , 3 )
134133 return latents
135134
136135
@@ -343,19 +342,19 @@ class LTXImage2VideoPrepareLatentsStep(ModularPipelineBlocks):
343342 @property
344343 def description (self ) -> str :
345344 return (
346- "Prepare latents step for image-to-video: takes pre-encoded image latents and creates a conditioning mask"
345+ "Prepare image-to-video latents: adds noise to pre-encoded image latents and creates a conditioning mask. "
346+ "Expects pure noise `latents` from LTXPrepareLatentsStep."
347347 )
348348
349349 @property
350350 def inputs (self ) -> list [InputParam ]:
351351 return [
352352 InputParam ("image_latents" , type_hint = torch .Tensor , required = True ),
353+ InputParam .template ("latents" , required = True ),
353354 InputParam .template ("height" , default = 512 ),
354355 InputParam .template ("width" , default = 704 ),
355356 InputParam ("num_frames" , type_hint = int , default = 161 ),
356- InputParam .template ("latents" ),
357357 InputParam .template ("num_images_per_prompt" , name = "num_videos_per_prompt" ),
358- InputParam .template ("generator" ),
359358 InputParam .template ("batch_size" , required = True ),
360359 ]
361360
@@ -377,37 +376,31 @@ def __call__(self, components: LTXModularPipeline, state: PipelineState) -> Pipe
377376 width = block_state .width // components .vae_spatial_compression_ratio
378377 num_frames = (block_state .num_frames - 1 ) // components .vae_temporal_compression_ratio + 1
379378
380- mask_shape = (batch_size , 1 , num_frames , height , width )
381-
382- if block_state .latents is not None :
383- conditioning_mask = block_state .latents .new_zeros (mask_shape )
384- conditioning_mask [:, :, 0 ] = 1.0
385- conditioning_mask = _pack_latents (
386- conditioning_mask ,
387- components .transformer_spatial_patch_size ,
388- components .transformer_temporal_patch_size ,
389- ).squeeze (- 1 )
390- block_state .latents = block_state .latents .to (device = device , dtype = torch .float32 )
391- block_state .conditioning_mask = conditioning_mask
392- self .set_block_state (state , block_state )
393- return components , state
394-
395379 init_latents = block_state .image_latents .to (device = device , dtype = torch .float32 )
396380 if init_latents .shape [0 ] < batch_size :
397381 init_latents = init_latents .repeat_interleave (batch_size // init_latents .shape [0 ], dim = 0 )
398382 init_latents = init_latents .repeat (1 , 1 , num_frames , 1 , 1 )
399383
400- actual_mask_shape = (
384+ conditioning_mask = torch . zeros (
401385 init_latents .shape [0 ],
402386 1 ,
403387 init_latents .shape [2 ],
404388 init_latents .shape [3 ],
405389 init_latents .shape [4 ],
390+ device = device ,
391+ dtype = torch .float32 ,
406392 )
407- conditioning_mask = torch .zeros (actual_mask_shape , device = device , dtype = torch .float32 )
408393 conditioning_mask [:, :, 0 ] = 1.0
409394
410- noise = randn_tensor (init_latents .shape , generator = block_state .generator , device = device , dtype = torch .float32 )
395+ # Unpack the pure noise latents from LTXPrepareLatentsStep to mix with image latents
396+ noise = _unpack_latents (
397+ block_state .latents ,
398+ num_frames ,
399+ height ,
400+ width ,
401+ components .transformer_spatial_patch_size ,
402+ components .transformer_temporal_patch_size ,
403+ )
411404 latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask )
412405
413406 conditioning_mask = _pack_latents (
0 commit comments