2323from lightx2v .utils .profiler import ProfilingContext4DebugL1 , ProfilingContext4DebugL2
2424from lightx2v .utils .registry_factory import RUNNER_REGISTER
2525from lightx2v .utils .utils import is_main_process , save_to_video , wan_vae_to_comfy
26+ from lightx2v .utils .va_controller import VAController
2627from lightx2v_platform .base .global_var import AI_DEVICE
2728
2829torch_device_module = getattr (torch , AI_DEVICE )
@@ -106,8 +107,11 @@ def __init__(self, config):
106107 self .audio_sample_rate = int (self .config .get ("audio_sample_rate" , 16000 ))
107108 self .target_fps = int (self .config .get ("target_fps" , 25 ))
108109 self .video_audio_path = None
110+ self .video_audio_array = None
109111 self .cond_video_temp_path = None
110112 self .cond_video_duration = None
113+ self .va_controller = None
114+ self .stream_save_video = False
111115
112116 def init_scheduler (self ):
113117 self .scheduler = InfiniteTalkScheduler (self .config )
@@ -289,6 +293,7 @@ def _audio_prepare_multi(self, left_path, right_path, audio_type):
289293 return new_speech1 , new_speech2 , new_speech1 + new_speech2
290294
291295 def _write_sum_audio (self , input_data , audio_arrays ):
296+ self .video_audio_array = np .asarray (audio_arrays , dtype = np .float32 )
292297 if sf is not None :
293298 fd , audio_path = tempfile .mkstemp (prefix = "infinitetalk_sum_" , suffix = ".wav" )
294299 os .close (fd )
@@ -506,6 +511,7 @@ def _run_input_encoder_local_s2v(self):
506511 self .cond_video_duration = self ._get_cond_video_duration (self .cond_file_path )
507512 first_image = self ._extract_specific_frame (self .cond_file_path , 0 )
508513 self .src_h , self .src_w , self .target_h , self .target_w = self ._select_target_size (first_image )
514+ self .input_info .target_shape = [self .target_h , self .target_w ]
509515
510516 full_audio_embs = self ._prepare_audio_embeddings (input_data )
511517 if any (audio_emb .shape [0 ] <= 0 for audio_emb in full_audio_embs ):
@@ -623,7 +629,7 @@ def init_run(self):
623629
624630 self .cond_image = self ._prepare_cond_image (0 )
625631 self .cond_frame = None
626- self .gen_video_list = []
632+ self .gen_video_list = None if self . stream_save_video else []
627633
628634 def get_video_segment_num (self ):
629635 if self .expected_frames <= self .frame_num :
@@ -683,6 +689,47 @@ def run_segment(self, segment_idx=0):
683689 self ._run_dit_clip (self .dit_inputs )
684690 return self .scheduler .latents
685691
def _should_stream_save_video(self):
    """Decide whether decoded segments are streamed out instead of accumulated.

    Streaming is selected only when all three conditions hold:
    the ``stream_save_video`` config entry is truthy (it defaults to ``True``),
    the caller did not request the result tensor back, and the input info
    carries a truthy ``save_result_path``.

    Returns:
        bool: ``True`` when streaming should be used, ``False`` otherwise.
    """
    # Config switch first: when it is off nothing else is inspected.
    if not self.config.get("stream_save_video", True):
        return False
    # A caller asking for the tensor needs the full video kept in memory.
    if self.input_info.return_result_tensor:
        return False
    # ``save_result_path`` may be absent on the input info; treat that as "no path".
    return bool(getattr(self.input_info, "save_result_path", None))
694+
def _init_stream_video_controller(self):
    """Create the VAController used for streaming output, if streaming is enabled.

    When ``self.stream_save_video`` is falsy this is a no-op and
    ``self.va_controller`` is left as it was. Otherwise a new ``VAController``
    is built from this runner, stored on ``self.va_controller``, and its
    ``recorder`` and ``reader`` attributes are logged.
    """
    if self.stream_save_video:
        controller = VAController(self)
        self.va_controller = controller
        logger.info(f"init va_recorder: {controller.recorder} and va_reader: {controller.reader}")
700+
def _get_audio_segment(self, start_frame, frame_count):
    """Return the audio samples that line up with a run of video frames.

    Frame indices are converted to sample indices with
    ``audio_sample_rate / target_fps``. The result always has exactly as many
    samples as the frame range spans; any part of the range that falls
    outside the cached audio (before its start or after its end) is filled
    with zeros.

    Args:
        start_frame: Index of the first video frame of the segment.
        frame_count: Number of video frames in the segment.

    Returns:
        torch.Tensor: 1-D float32 tensor. Empty when the range spans no
        samples; all zeros when no audio has been cached on
        ``self.video_audio_array``. When no padding is needed the tensor may
        share memory with the cached array (``torch.from_numpy`` on a slice).
    """
    sample_start = int(round(start_frame * self.audio_sample_rate / self.target_fps))
    sample_end = int(round((start_frame + frame_count) * self.audio_sample_rate / self.target_fps))
    sample_count = max(sample_end - sample_start, 0)
    if sample_count == 0:
        return torch.zeros(0, dtype=torch.float32)

    if self.video_audio_array is None:
        return torch.zeros(sample_count, dtype=torch.float32)

    audio = self.video_audio_array.reshape(-1)
    total_samples = audio.shape[0]

    # Clamp the read window to [0, total_samples]. Without the lower clamp a
    # negative start would be interpreted by the slice as "count from the end"
    # and return samples from the wrong place.
    read_start = min(max(sample_start, 0), total_samples)
    read_end = min(max(sample_end, 0), total_samples)
    audio_seg = audio[read_start:read_end]

    # Silence for the portion before sample 0 and for the portion past the end.
    lead_pad = min(max(-sample_start, 0), sample_count)
    tail_pad = sample_count - lead_pad - audio_seg.shape[0]
    if lead_pad or tail_pad:
        audio_seg = np.pad(audio_seg, (lead_pad, tail_pad))
    return torch.from_numpy(audio_seg.astype(np.float32, copy=False))
716+
def _publish_video_segment(self, videos, start_frame):
    """Publish one decoded video segment, with its matching audio, to the stream.

    Nothing is published when there is no controller, when the controller has
    no recorder, or when the segment contains no frames.

    Args:
        videos: Decoded video tensor. Dimension 2 is used as the frame axis
            (presumably laid out as (B, C, T, H, W) -- confirm against the
            VAE decoder output).
        start_frame: Index of the segment's first frame; used to pick the
            audio samples that accompany it.
    """
    controller = self.va_controller
    if controller is None or controller.recorder is None:
        return
    frame_count = videos.shape[2]
    if frame_count <= 0:
        return

    # Convert to float32 on the host exactly once and reuse the result for
    # both the comfy-format conversion and the raw payload. Slicing the frame
    # axis to its own full length is a no-op, so it is not repeated here.
    video_seg = videos.to(torch.float32).cpu()
    comfy_video = wan_vae_to_comfy(video_seg)
    audio_seg = self._get_audio_segment(start_frame, frame_count)
    controller.pub_livestream(
        comfy_video,
        audio_seg,
        video_seg,
        valid_duration=frame_count / self.target_fps,
    )
732+
686733 @ProfilingContext4DebugL1 (
687734 "End run segment" ,
688735 recorder_mode = GET_RECORDER_MODE (),
@@ -692,9 +739,19 @@ def run_segment(self, segment_idx=0):
692739 def end_run_segment (self , segment_idx , latents ):
693740 videos = self .run_vae_decoder (latents ).cpu ()
694741 if self .is_first_segment :
695- self .gen_video_list .append (videos )
742+ output_videos = videos
743+ output_start_frame = 0
696744 else :
697- self .gen_video_list .append (videos [:, :, self .current_motion_frames_num :])
745+ output_videos = videos [:, :, self .current_motion_frames_num :]
746+ output_start_frame = self .audio_start_idx + self .current_motion_frames_num
747+
748+ valid_frames = min (output_videos .shape [2 ], max (self .expected_frames - output_start_frame , 0 ))
749+ if valid_frames > 0 :
750+ output_videos = output_videos [:, :, :valid_frames ]
751+ if self .stream_save_video :
752+ self ._publish_video_segment (output_videos , output_start_frame )
753+ else :
754+ self .gen_video_list .append (output_videos )
698755
699756 if segment_idx < self .video_segment_num - 1 :
700757 self .cond_frame = videos [:, :, - self .motion_frame :].to (torch .float32 ).to (AI_DEVICE )
@@ -706,8 +763,10 @@ def end_run_segment(self, segment_idx, latents):
706763
707764 @ProfilingContext4DebugL2 ("Run DiT + decode" )
708765 def run_main (self ):
766+ self .stream_save_video = self ._should_stream_save_video ()
709767 self .init_run ()
710768 self .get_video_segment_num ()
769+ self ._init_stream_video_controller ()
711770
712771 for segment_idx in range (self .video_segment_num ):
713772 logger .info (f"start InfiniteTalk segment { segment_idx + 1 } /{ self .video_segment_num } " )
@@ -716,11 +775,19 @@ def run_main(self):
716775 latents = self .run_segment (segment_idx )
717776 self .end_run_segment (segment_idx , latents )
718777
778+ if self .stream_save_video :
779+ return self .process_images_after_vae_decoder ()
780+
719781 self .gen_video = torch .cat (self .gen_video_list , dim = 2 )[:, :, : self .expected_frames ].to (torch .float32 )
720782 return self .process_images_after_vae_decoder ()
721783
722784 @ProfilingContext4DebugL1 ("Process after vae decoder" )
723785 def process_images_after_vae_decoder (self ):
786+ if self .stream_save_video :
787+ if self .input_info .save_result_path is not None and is_main_process ():
788+ logger .info (f"Video saved to { self .input_info .save_result_path } " )
789+ return {"video" : None }
790+
724791 self .gen_video_final = wan_vae_to_comfy (self .gen_video )
725792 if self .input_info .return_result_tensor :
726793 return {"video" : self .gen_video_final }
@@ -771,8 +838,13 @@ def _mux_audio(video_path, audio_path):
771838 os .remove (tmp_path )
772839
773840 def end_run (self ):
841+ if self .va_controller is not None :
842+ self .va_controller .clear ()
843+ self .va_controller = None
774844 self ._remove_video_audio_path ()
775845 self ._remove_cond_video_temp_path ()
846+ self .video_audio_array = None
847+ self .stream_save_video = False
776848 if hasattr (self , "inputs" ):
777849 del self .inputs
778850 torch .cuda .empty_cache ()
0 commit comments