@@ -67,8 +67,9 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
6767 LTX2AudioVideoUnit_SwitchStage2 (),
6868 LTX2AudioVideoUnit_NoiseInitializer (),
6969 LTX2AudioVideoUnit_LatentsUpsampler (),
70- LTX2AudioVideoUnit_SetScheduleStage2 (),
7170 LTX2AudioVideoUnit_InputImagesEmbedder (),
71+ LTX2AudioVideoUnit_InputAudioEmbedder (),
72+ LTX2AudioVideoUnit_SetScheduleStage2 (),
7273 ]
7374 self .model_fn = model_fn_ltx2
7475
@@ -155,8 +156,9 @@ def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scal
155156 ** models , timestep = timestep , progress_id = progress_id
156157 )
157158 inputs_shared ["video_latents" ] = self .step (self .scheduler , inputs_shared ["video_latents" ], progress_id = progress_id , noise_pred = noise_pred_video ,
158- inpaint_mask = inputs_shared .get ("denoise_mask_video" , None ), input_latents = inputs_shared .get ("input_latents_video" , None ), ** inputs_shared )
159- inputs_shared ["audio_latents" ] = self .step (self .scheduler , inputs_shared ["audio_latents" ], progress_id = progress_id , noise_pred = noise_pred_audio , ** inputs_shared )
159+ inpaint_mask = inputs_shared .get ("video_denoise_mask" , None ), input_latents = inputs_shared .get ("video_input_latents" , None ), ** inputs_shared )
160+ inputs_shared ["audio_latents" ] = self .step (self .scheduler , inputs_shared ["audio_latents" ], progress_id = progress_id , noise_pred = noise_pred_audio ,
161+ inpaint_mask = inputs_shared .get ("audio_denoise_mask" , None ), input_latents = inputs_shared .get ("audio_input_latents" , None ), ** inputs_shared )
160162 return inputs_shared , inputs_posi , inputs_nega
161163
162164 @torch .no_grad ()
@@ -173,6 +175,9 @@ def __call__(
173175 # In-Context Video Control
174176 in_context_videos : Optional [list [list [Image .Image ]]] = None ,
175177 in_context_downsample_factor : Optional [int ] = 2 ,
178+ # Audio-to-video
179+ input_audio : Optional [torch .Tensor ] = None ,
180+ audio_sample_rate : Optional [int ] = 48000 ,
176181 # Randomness
177182 seed : Optional [int ] = None ,
178183 rand_device : Optional [str ] = "cpu" ,
@@ -210,6 +215,7 @@ def __call__(
210215 }
211216 inputs_shared = {
212217 "input_images" : input_images , "input_images_indexes" : input_images_indexes , "input_images_strength" : input_images_strength ,
218+ "input_audio" : (input_audio , audio_sample_rate ) if input_audio is not None else None ,
213219 "in_context_videos" : in_context_videos , "in_context_downsample_factor" : in_context_downsample_factor ,
214220 "seed" : seed , "rand_device" : rand_device ,
215221 "height" : height , "width" : width , "num_frames" : num_frames , "frame_rate" : frame_rate ,
@@ -361,7 +367,8 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled,
361367 input_video = pipe .preprocess_video (input_video )
362368 input_latents = pipe .video_vae_encoder .encode (input_video , tiled , tile_size_in_pixels , tile_overlap_in_pixels ).to (dtype = pipe .torch_dtype , device = pipe .device )
363369 if pipe .scheduler .training :
364- return {"video_latents" : input_latents , "input_latents" : input_latents }
370+ # input_latents key is for training to add noise. with no prefix "video" to keep loss function keyword consistent.
371+ return {"video_latents" : video_noise , "input_latents" : input_latents }
365372 else :
366373 raise NotImplementedError ("Video-to-video not implemented yet." )
367374
@@ -370,7 +377,7 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
370377 def __init__ (self ):
371378 super ().__init__ (
372379 input_params = ("input_audio" , "audio_noise" ),
373- output_params = ("audio_latents" , "audio_input_latents" , "audio_positions" , "audio_latent_shape" ),
380+ output_params = ("audio_latents" , "audio_input_latents" , "audio_noise" , " audio_positions" , "audio_latent_shape" , "audio_denoise_mask " ),
374381 onload_model_names = ("audio_vae_encoder" ,)
375382 )
376383
@@ -380,21 +387,37 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
380387 else :
381388 input_audio , sample_rate = input_audio
382389 pipe .load_models_to_device (self .onload_model_names )
383- input_audio = pipe .audio_processor .waveform_to_mel (input_audio .unsqueeze (0 ), waveform_sample_rate = sample_rate ).to (dtype = pipe .torch_dtype )
390+ input_audio = pipe .audio_processor .waveform_to_mel (input_audio .unsqueeze (0 ), waveform_sample_rate = sample_rate ).to (dtype = pipe .torch_dtype , device = pipe . device )
384391 audio_input_latents = pipe .audio_vae_encoder (input_audio )
385392 audio_latent_shape = AudioLatentShape .from_torch_shape (audio_input_latents .shape )
386393 audio_positions = pipe .audio_patchifier .get_patch_grid_bounds (audio_latent_shape , device = pipe .device )
387394 if pipe .scheduler .training :
388- return {"audio_latents" : audio_input_latents , "audio_input_latents" : audio_input_latents , "audio_positions" : audio_positions , "audio_latent_shape" : audio_latent_shape }
395+ return {
396+ "audio_latents" : audio_input_latents ,
397+ "audio_input_latents" : audio_input_latents ,
398+ "audio_positions" : audio_positions ,
399+ "audio_latent_shape" : audio_latent_shape ,
400+ }
389401 else :
390- raise NotImplementedError ("Audio-to-video not supported." )
402+ b , c , t , f = audio_input_latents .shape
403+ audio_denoise_mask = torch .zeros ((b , 1 , t , 1 ), device = audio_input_latents .device , dtype = audio_input_latents .dtype )
404+ audio_noise = torch .rand_like (audio_input_latents )
405+ audio_latents = pipe .scheduler .add_noise (audio_input_latents , audio_noise , pipe .scheduler .timesteps [0 ])
406+ return {
407+ "audio_latents" : audio_latents ,
408+ "audio_input_latents" : audio_input_latents ,
409+ "audio_noise" : audio_noise ,
410+ "audio_positions" : audio_positions ,
411+ "audio_latent_shape" : audio_latent_shape ,
412+ "audio_denoise_mask" : audio_denoise_mask ,
413+ }
391414
392415
393416class LTX2AudioVideoUnit_InputImagesEmbedder (PipelineUnit ):
394417 def __init__ (self ):
395418 super ().__init__ (
396419 input_params = ("input_images" , "input_images_indexes" , "input_images_strength" , "video_latents" , "height" , "width" , "frame_rate" , "tiled" , "tile_size_in_pixels" , "tile_overlap_in_pixels" , "initial_latents" ),
397- output_params = ("denoise_mask_video " , "input_latents_video " , "ref_frames_latents" , "ref_frames_positions" ),
420+ output_params = ("video_denoise_mask " , "video_input_latents " , "ref_frames_latents" , "ref_frames_positions" ),
398421 onload_model_names = ("video_vae_encoder" )
399422 )
400423
@@ -406,9 +429,9 @@ def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in
406429 latents = pipe .video_vae_encoder .encode (image , tiled , tile_size_in_pixels , tile_overlap_in_pixels ).to (pipe .device )
407430 return latents
408431
409- def apply_input_images_to_latents (self , latents , input_latents , input_indexes , input_strength = 1.0 , initial_latents = None , denoise_mask_video = None ):
432+ def apply_input_images_to_latents (self , latents , input_latents , input_indexes , input_strength = 1.0 , initial_latents = None , video_denoise_mask = None ):
410433 b , _ , f , h , w = latents .shape
411- denoise_mask = torch .ones ((b , 1 , f , h , w ), dtype = latents .dtype , device = latents .device ) if denoise_mask_video is None else denoise_mask_video
434+ denoise_mask = torch .ones ((b , 1 , f , h , w ), dtype = latents .dtype , device = latents .device ) if video_denoise_mask is None else video_denoise_mask
412435 initial_latents = torch .zeros_like (latents ) if initial_latents is None else initial_latents
413436 for idx , input_latent in zip (input_indexes , input_latents ):
414437 idx = min (max (1 + (idx - 1 ) // 8 , 0 ), f - 1 )
@@ -424,13 +447,13 @@ def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, hei
424447 if len (input_images_indexes ) != len (set (input_images_indexes )):
425448 raise ValueError ("Input images must have unique indexes." )
426449 pipe .load_models_to_device (self .onload_model_names )
427- frame_conditions = {"input_latents_video " : None , "denoise_mask_video " : None , "ref_frames_latents" : [], "ref_frames_positions" : []}
450+ frame_conditions = {"video_input_latents " : None , "video_denoise_mask " : None , "ref_frames_latents" : [], "ref_frames_positions" : []}
428451 for img , index in zip (input_images , input_images_indexes ):
429452 latents = self .get_image_latent (pipe , img , height , width , tiled , tile_size_in_pixels , tile_overlap_in_pixels )
430453 # first_frame by replacing latents
431454 if index == 0 :
432- input_latents_video , denoise_mask_video = self .apply_input_images_to_latents (video_latents , [latents ], [0 ], input_images_strength , initial_latents )
433- frame_conditions .update ({"input_latents_video " : input_latents_video , "denoise_mask_video " : denoise_mask_video })
455+ video_input_latents , video_denoise_mask = self .apply_input_images_to_latents (video_latents , [latents ], [0 ], input_images_strength , initial_latents )
456+ frame_conditions .update ({"video_input_latents " : video_input_latents , "video_denoise_mask " : video_denoise_mask })
434457 # other frames by adding reference latents
435458 else :
436459 latent_coords = pipe .video_patchifier .get_patch_grid_bounds (output_shape = VideoLatentShape .from_torch_shape (latents .shape ), device = pipe .device )
@@ -560,14 +583,17 @@ def model_fn_ltx2(
560583 audio_patchifier = None ,
561584 timestep = None ,
562585 # First Frame Conditioning
563- input_latents_video = None ,
564- denoise_mask_video = None ,
586+ video_input_latents = None ,
587+ video_denoise_mask = None ,
565588 # Other Frames Conditioning
566589 ref_frames_latents = None ,
567590 ref_frames_positions = None ,
568591 # In-Context Conditioning
569592 in_context_video_latents = None ,
570593 in_context_video_positions = None ,
594+ # Audio Inputs
595+ audio_input_latents = None ,
596+ audio_denoise_mask = None ,
571597 # Gradient Checkpointing
572598 use_gradient_checkpointing = False ,
573599 use_gradient_checkpointing_offload = False ,
@@ -581,12 +607,12 @@ def model_fn_ltx2(
581607 seq_len_video = video_latents .shape [1 ]
582608 video_timesteps = timestep .repeat (1 , video_latents .shape [1 ], 1 )
583609 # Frist frame conditioning by replacing the video latents
584- if input_latents_video is not None :
585- denoise_mask_video = video_patchifier .patchify (denoise_mask_video )
586- video_latents = video_latents * denoise_mask_video + video_patchifier .patchify (input_latents_video ) * (1.0 - denoise_mask_video )
587- video_timesteps = denoise_mask_video * video_timesteps
588-
589- # Conditioning by replacing the video latents
610+ if video_input_latents is not None :
611+ video_denoise_mask = video_patchifier .patchify (video_denoise_mask )
612+ video_latents = video_latents * video_denoise_mask + video_patchifier .patchify (video_input_latents ) * (1.0 - video_denoise_mask )
613+ video_timesteps = video_denoise_mask * video_timesteps
614+
615+ # Reference conditioning by appending the reference video or frame latents
590616 total_ref_latents = ref_frames_latents if ref_frames_latents is not None else []
591617 total_ref_positions = ref_frames_positions if ref_frames_positions is not None else []
592618 total_ref_latents += [in_context_video_latents ] if in_context_video_latents is not None else []
@@ -605,6 +631,10 @@ def model_fn_ltx2(
605631 audio_timesteps = timestep .repeat (1 , audio_latents .shape [1 ], 1 )
606632 else :
607633 audio_timesteps = None
634+ if audio_input_latents is not None :
635+ audio_denoise_mask = audio_patchifier .patchify (audio_denoise_mask )
636+ audio_latents = audio_latents * audio_denoise_mask + audio_patchifier .patchify (audio_input_latents ) * (1.0 - audio_denoise_mask )
637+ audio_timesteps = audio_denoise_mask * audio_timesteps
608638
609639 vx , ax = dit (
610640 video_latents = video_latents ,
0 commit comments