Skip to content

Commit 00c9da5

Browse files
authored
stream_save_video for infinitetalk (#1204)
1 parent 2cbe1f2 commit 00c9da5

5 files changed

Lines changed: 156 additions & 10 deletions

File tree

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
{
2+
"infer_steps": 4,
3+
"target_video_length": 81,
4+
"motion_frame": 9,
5+
"target_fps": 25,
6+
"video_duration": 300,
7+
"infinitetalk_mode": "single",
8+
"infinitetalk_size": "infinitetalk-720",
9+
"dit_quantized_ckpt": "/path/to/InfiniteTalk/seko/InfiniteTalk-4StepDistill-Mean-w-ITAudioAdaptorV6.1-fp8.safetensors",
10+
"dit_quantized": true,
11+
"dit_quant_scheme": "fp8-sgl",
12+
"adapter_model_path": "/path/to/InfiniteTalk/single/single/infinitetalk-fp8.safetensors",
13+
"adapter_quantized": true,
14+
"adapter_quant_scheme": "fp8-sgl",
15+
"audio_encoder_path": "/path/to/InfiniteTalk/TencentGameMate/chinese-wav2vec2-base",
16+
"clip_quantized": true,
17+
"clip_quant_scheme": "fp8-sgl",
18+
"t5_quantized": true,
19+
"t5_quant_scheme": "fp8-sgl",
20+
"audio_sample_rate": 16000,
21+
"sample_shift": 7,
22+
"sample_text_guide_scale": 1.0,
23+
"sample_audio_guide_scale": 1.0,
24+
"enable_cfg": false,
25+
"enable_text_cfg": false,
26+
"use_image_encoder": true,
27+
"feature_caching": "NoCaching",
28+
"cpu_offload": true,
29+
"offload_granularity": "block",
30+
"t5_cpu_offload": false,
31+
"vae_cpu_offload": false,
32+
"clip_cpu_offload": false,
33+
"self_attn_1_type": "sage_attn2",
34+
"cross_attn_1_type": "sage_attn2",
35+
"cross_attn_2_type": "sage_attn2",
36+
"audio_cross_attn_type": "sage_attn2",
37+
"text_len": 512,
38+
"target_height": 720,
39+
"target_width": 1280,
40+
"audio_window": 5,
41+
"infinitetalk_vae_scale": 4,
42+
"infinitetalk_context_tokens": 32,
43+
"infinitetalk_audio_output_dim": 768,
44+
"norm_output_audio": true,
45+
"mxfp8_fuse_enable": false,
46+
"use_timestep_transform": true,
47+
"parallel": {
48+
"seq_p_size": 8,
49+
"seq_p_attn_type": "ulysses-4090"
50+
}
51+
}

configs/infinitetalk/h100/infinitetalk_single_distilled.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
"video_duration": 300,
77
"infinitetalk_mode": "single",
88
"infinitetalk_size": "infinitetalk-720",
9-
"dit_quantized_ckpt": "/path/to/InfiniteTalk/seko/InfiniteTalk-4StepDistill-Mean-w-ITAudioAdaptorV6.1-fp8.safetensors",
9+
"dit_quantized_ckpt": "/data/nvme5/gushiqiao/models/InfiniteTalk/seko/InfiniteTalk-4StepDistill-Mean-w-ITAudioAdaptorV6.1-fp8.safetensors",
1010
"dit_quantized": true,
1111
"dit_quant_scheme": "fp8-sgl",
12-
"adapter_model_path": "/path/to/InfiniteTalk/single/single/infinitetalk-fp8.safetensors",
12+
"adapter_model_path": "/data/nvme5/gushiqiao/models/InfiniteTalk/single/single/infinitetalk-fp8.safetensors",
1313
"adapter_quantized": true,
1414
"adapter_quant_scheme": "fp8-sgl",
15-
"audio_encoder_path": "/path/to/InfiniteTalk/TencentGameMate/chinese-wav2vec2-base",
15+
"audio_encoder_path": "/data/nvme5/gushiqiao/models/InfiniteTalk/TencentGameMate/chinese-wav2vec2-base",
1616
"clip_quantized": true,
1717
"clip_quant_scheme": "fp8-sgl",
1818
"t5_quantized": true,

lightx2v/models/runners/wan/wan_infinitetalk_runner.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from lightx2v.utils.profiler import ProfilingContext4DebugL1, ProfilingContext4DebugL2
2424
from lightx2v.utils.registry_factory import RUNNER_REGISTER
2525
from lightx2v.utils.utils import is_main_process, save_to_video, wan_vae_to_comfy
26+
from lightx2v.utils.va_controller import VAController
2627
from lightx2v_platform.base.global_var import AI_DEVICE
2728

2829
torch_device_module = getattr(torch, AI_DEVICE)
@@ -106,8 +107,11 @@ def __init__(self, config):
106107
self.audio_sample_rate = int(self.config.get("audio_sample_rate", 16000))
107108
self.target_fps = int(self.config.get("target_fps", 25))
108109
self.video_audio_path = None
110+
self.video_audio_array = None
109111
self.cond_video_temp_path = None
110112
self.cond_video_duration = None
113+
self.va_controller = None
114+
self.stream_save_video = False
111115

112116
def init_scheduler(self):
113117
self.scheduler = InfiniteTalkScheduler(self.config)
@@ -289,6 +293,7 @@ def _audio_prepare_multi(self, left_path, right_path, audio_type):
289293
return new_speech1, new_speech2, new_speech1 + new_speech2
290294

291295
def _write_sum_audio(self, input_data, audio_arrays):
296+
self.video_audio_array = np.asarray(audio_arrays, dtype=np.float32)
292297
if sf is not None:
293298
fd, audio_path = tempfile.mkstemp(prefix="infinitetalk_sum_", suffix=".wav")
294299
os.close(fd)
@@ -506,6 +511,7 @@ def _run_input_encoder_local_s2v(self):
506511
self.cond_video_duration = self._get_cond_video_duration(self.cond_file_path)
507512
first_image = self._extract_specific_frame(self.cond_file_path, 0)
508513
self.src_h, self.src_w, self.target_h, self.target_w = self._select_target_size(first_image)
514+
self.input_info.target_shape = [self.target_h, self.target_w]
509515

510516
full_audio_embs = self._prepare_audio_embeddings(input_data)
511517
if any(audio_emb.shape[0] <= 0 for audio_emb in full_audio_embs):
@@ -623,7 +629,7 @@ def init_run(self):
623629

624630
self.cond_image = self._prepare_cond_image(0)
625631
self.cond_frame = None
626-
self.gen_video_list = []
632+
self.gen_video_list = None if self.stream_save_video else []
627633

628634
def get_video_segment_num(self):
629635
if self.expected_frames <= self.frame_num:
@@ -683,6 +689,47 @@ def run_segment(self, segment_idx=0):
683689
self._run_dit_clip(self.dit_inputs)
684690
return self.scheduler.latents
685691

692+
def _should_stream_save_video(self):
693+
return bool(self.config.get("stream_save_video", True) and not self.input_info.return_result_tensor and getattr(self.input_info, "save_result_path", None))
694+
695+
def _init_stream_video_controller(self):
696+
if not self.stream_save_video:
697+
return
698+
self.va_controller = VAController(self)
699+
logger.info(f"init va_recorder: {self.va_controller.recorder} and va_reader: {self.va_controller.reader}")
700+
701+
def _get_audio_segment(self, start_frame, frame_count):
702+
audio_sample_start = int(round(start_frame * self.audio_sample_rate / self.target_fps))
703+
audio_sample_end = int(round((start_frame + frame_count) * self.audio_sample_rate / self.target_fps))
704+
audio_sample_count = max(audio_sample_end - audio_sample_start, 0)
705+
if audio_sample_count == 0:
706+
return torch.zeros(0, dtype=torch.float32)
707+
708+
if self.video_audio_array is None:
709+
return torch.zeros(audio_sample_count, dtype=torch.float32)
710+
711+
audio = self.video_audio_array.reshape(-1)
712+
audio_seg = audio[audio_sample_start : min(audio_sample_end, audio.shape[0])]
713+
if audio_seg.shape[0] < audio_sample_count:
714+
audio_seg = np.pad(audio_seg, (0, audio_sample_count - audio_seg.shape[0]))
715+
return torch.from_numpy(audio_seg.astype(np.float32, copy=False))
716+
717+
def _publish_video_segment(self, videos, start_frame):
718+
if self.va_controller is None or self.va_controller.recorder is None:
719+
return
720+
frame_count = videos.shape[2]
721+
if frame_count <= 0:
722+
return
723+
video_seg = videos[:, :, :frame_count].to(torch.float32)
724+
comfy_video = wan_vae_to_comfy(video_seg.cpu())
725+
audio_seg = self._get_audio_segment(start_frame, frame_count)
726+
self.va_controller.pub_livestream(
727+
comfy_video,
728+
audio_seg,
729+
video_seg.cpu(),
730+
valid_duration=frame_count / self.target_fps,
731+
)
732+
686733
@ProfilingContext4DebugL1(
687734
"End run segment",
688735
recorder_mode=GET_RECORDER_MODE(),
@@ -692,9 +739,19 @@ def run_segment(self, segment_idx=0):
692739
def end_run_segment(self, segment_idx, latents):
693740
videos = self.run_vae_decoder(latents).cpu()
694741
if self.is_first_segment:
695-
self.gen_video_list.append(videos)
742+
output_videos = videos
743+
output_start_frame = 0
696744
else:
697-
self.gen_video_list.append(videos[:, :, self.current_motion_frames_num :])
745+
output_videos = videos[:, :, self.current_motion_frames_num :]
746+
output_start_frame = self.audio_start_idx + self.current_motion_frames_num
747+
748+
valid_frames = min(output_videos.shape[2], max(self.expected_frames - output_start_frame, 0))
749+
if valid_frames > 0:
750+
output_videos = output_videos[:, :, :valid_frames]
751+
if self.stream_save_video:
752+
self._publish_video_segment(output_videos, output_start_frame)
753+
else:
754+
self.gen_video_list.append(output_videos)
698755

699756
if segment_idx < self.video_segment_num - 1:
700757
self.cond_frame = videos[:, :, -self.motion_frame :].to(torch.float32).to(AI_DEVICE)
@@ -706,8 +763,10 @@ def end_run_segment(self, segment_idx, latents):
706763

707764
@ProfilingContext4DebugL2("Run DiT + decode")
708765
def run_main(self):
766+
self.stream_save_video = self._should_stream_save_video()
709767
self.init_run()
710768
self.get_video_segment_num()
769+
self._init_stream_video_controller()
711770

712771
for segment_idx in range(self.video_segment_num):
713772
logger.info(f"start InfiniteTalk segment {segment_idx + 1}/{self.video_segment_num}")
@@ -716,11 +775,19 @@ def run_main(self):
716775
latents = self.run_segment(segment_idx)
717776
self.end_run_segment(segment_idx, latents)
718777

778+
if self.stream_save_video:
779+
return self.process_images_after_vae_decoder()
780+
719781
self.gen_video = torch.cat(self.gen_video_list, dim=2)[:, :, : self.expected_frames].to(torch.float32)
720782
return self.process_images_after_vae_decoder()
721783

722784
@ProfilingContext4DebugL1("Process after vae decoder")
723785
def process_images_after_vae_decoder(self):
786+
if self.stream_save_video:
787+
if self.input_info.save_result_path is not None and is_main_process():
788+
logger.info(f"Video saved to {self.input_info.save_result_path}")
789+
return {"video": None}
790+
724791
self.gen_video_final = wan_vae_to_comfy(self.gen_video)
725792
if self.input_info.return_result_tensor:
726793
return {"video": self.gen_video_final}
@@ -771,8 +838,13 @@ def _mux_audio(video_path, audio_path):
771838
os.remove(tmp_path)
772839

773840
def end_run(self):
841+
if self.va_controller is not None:
842+
self.va_controller.clear()
843+
self.va_controller = None
774844
self._remove_video_audio_path()
775845
self._remove_cond_video_temp_path()
846+
self.video_audio_array = None
847+
self.stream_save_video = False
776848
if hasattr(self, "inputs"):
777849
del self.inputs
778850
torch.cuda.empty_cache()

scripts/infinitetalk/run_infinitetalk_single.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#!/bin/bash
22

33
# set path firstly
4-
lightx2v_path=/path/to/LightX2V
5-
model_path=/path/to/InfiniteTalk
4+
lightx2v_path=/data/nvme4/gushiqiao/new/debug/LightX2V
5+
model_path=/data/nvme5/gushiqiao/models/InfiniteTalk
66

7-
export CUDA_VISIBLE_DEVICES=0
7+
export CUDA_VISIBLE_DEVICES=7
88

99

1010
# set environment variables
@@ -14,7 +14,7 @@ python -m lightx2v.infer \
1414
--model_cls infinitetalk \
1515
--task s2v \
1616
--model_path $model_path \
17-
--config_json ${lightx2v_path}/configs/infinitetalk/fp8/infinitetalk_single_distilled.json \
17+
--config_json ${lightx2v_path}/configs/infinitetalk/h100/infinitetalk_single_distilled.json \
1818
--prompt "让角色根据音频内容自然说话" \
1919
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
2020
--image_path /data/nvme5/gushiqiao/cases/wecom-temp-3950334-bfa56035a08485356431b5a1c5c28a82.png \
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#!/bin/bash

# Launch InfiniteTalk single-speaker inference with the 8-GPU distilled config.

# set path firstly
lightx2v_path=/path/to/LightX2V
model_path=/path/to/InfiniteTalk
image_path=/path/to/input_image.png

# The referenced config is the 8-GPU variant, so expose eight devices
# rather than a single one.
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7


# set environment variables
source "${lightx2v_path}/scripts/base/base.sh"

# One worker process per visible GPU; a plain `python` launch starts only one process.
torchrun --nproc_per_node=8 -m lightx2v.infer \
--model_cls infinitetalk \
--task s2v \
--model_path "$model_path" \
--config_json "${lightx2v_path}/configs/infinitetalk/5090/infinitetalk_single_distilled_8gpus.json" \
--prompt "让角色根据音频内容自然说话" \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path "$image_path" \
--audio_path "${lightx2v_path}/assets/inputs/audio/seko_input.mp3" \
--save_result_path "${lightx2v_path}/save_results/infinitetalk_single_720p_dist_8gpus.mp4" \
--seed 42

0 commit comments

Comments
 (0)