Skip to content

Commit 7f5b5db

Browse files
committed
add helper func
1 parent 7b32b43 commit 7f5b5db

1 file changed

Lines changed: 25 additions & 45 deletions

File tree

lightx2v/models/runners/wan/wan_audio_runner.py

Lines changed: 25 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import io
33
import json
44
import os
5-
import subprocess
65
import warnings
76
from dataclasses import dataclass
87
from typing import Dict, List, Optional, Tuple, Union
@@ -362,27 +361,35 @@ def get_audio_files_from_audio_path(self, audio_path):
362361

363362
return audio_files, mask_files
364363

364+
def _get_image_resize_kwargs(self):
365+
input_info = getattr(self, "input_info", None)
366+
return {
367+
"resize_mode": (getattr(input_info, "resize_mode", None) if input_info is not None else None) or self.config.get("resize_mode", "adaptive"),
368+
"bucket_shape": self.config.get("bucket_shape", None),
369+
"fixed_area": (getattr(input_info, "fixed_area", None) if input_info is not None else None) or self.config.get("fixed_area", None),
370+
"fixed_shape": self.config.get("fixed_shape", None),
371+
}
372+
373+
def _resolve_patched_spatial_size(self, h, w):
374+
patched_h = h // self.config["vae_stride"][1] // self.config["patch_size"][1]
375+
patched_w = w // self.config["vae_stride"][2] // self.config["patch_size"][2]
376+
patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1)
377+
latent_h = patched_h * self.config["patch_size"][1]
378+
latent_w = patched_w * self.config["patch_size"][2]
379+
target_shape = [latent_h * self.config["vae_stride"][1], latent_w * self.config["vae_stride"][2]]
380+
return target_shape, patched_h, patched_w
381+
365382
def process_single_mask(self, mask_file):
366383
mask_img = load_image(mask_file)
367384
mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
368385

369386
if mask_img.shape[1] == 3: # If it is an RGB three-channel image
370387
mask_img = mask_img[:, :1] # Only take the first channel
371388

372-
mask_img, h, w = resize_image(
373-
mask_img,
374-
resize_mode=self.config.get("resize_mode", "adaptive"),
375-
bucket_shape=self.config.get("bucket_shape", None),
376-
fixed_area=self.config.get("fixed_area", None),
377-
fixed_shape=self.config.get("fixed_shape", None),
378-
)
379-
380-
mask_latent = torch.nn.functional.interpolate(
381-
mask_img, # (1, 1, H, W)
382-
size=(h // 16, w // 16),
383-
mode="bicubic",
384-
)
385-
389+
mask_img, h, w = resize_image(mask_img, **self._get_image_resize_kwargs())
390+
target_shape, patched_h, patched_w = self._resolve_patched_spatial_size(h, w)
391+
mask_img = F.interpolate(mask_img, size=(target_shape[0], target_shape[1]), mode="bicubic")
392+
mask_latent = F.interpolate(mask_img, size=(patched_h, patched_w), mode="bicubic")
386393
mask_latent = (mask_latent > 0).to(torch.int8)
387394
return mask_latent
388395

@@ -393,19 +400,9 @@ def read_image_input(self, img_path):
393400
ref_img = load_image(img_path)
394401
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
395402

396-
ref_img, h, w = resize_image(
397-
ref_img,
398-
resize_mode=getattr(self.input_info, "resize_mode", None) or self.config.get("resize_mode", "adaptive"),
399-
bucket_shape=self.config.get("bucket_shape", None),
400-
fixed_area=getattr(self.input_info, "fixed_area", None) or self.config.get("fixed_area", None),
401-
fixed_shape=self.config.get("fixed_shape", None),
402-
)
403+
ref_img, h, w = resize_image(ref_img, **self._get_image_resize_kwargs())
403404
logger.info(f"[wan_audio] resize_image target_h: {h}, target_w: {w}")
404-
patched_h = h // self.config["vae_stride"][1] // self.config["patch_size"][1]
405-
patched_w = w // self.config["vae_stride"][2] // self.config["patch_size"][2]
406-
407-
patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1)
408-
405+
target_shape, patched_h, patched_w = self._resolve_patched_spatial_size(h, w)
409406
latent_h = patched_h * self.config["patch_size"][1]
410407
latent_w = patched_w * self.config["patch_size"][2]
411408

@@ -415,11 +412,9 @@ def read_image_input(self, img_path):
415412
else:
416413
latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
417414

418-
target_shape = [latent_h * self.config["vae_stride"][1], latent_w * self.config["vae_stride"][2]]
419-
420415
logger.info(f"[wan_audio] target_h: {target_shape[0]}, target_w: {target_shape[1]}, latent_h: {latent_h}, latent_w: {latent_w}")
421416

422-
ref_img = torch.nn.functional.interpolate(ref_img, size=(target_shape[0], target_shape[1]), mode="bicubic")
417+
ref_img = F.interpolate(ref_img, size=(target_shape[0], target_shape[1]), mode="bicubic")
423418
return ref_img, latent_shape, target_shape
424419

425420
@ProfilingContext4DebugL1(
@@ -732,26 +727,11 @@ def run_main(self):
732727

733728
# fixed audio segments inputs
734729
if self.va_controller.reader is None:
735-
# Save paths before super().run_main() clears input_info
736-
out_path = getattr(self.input_info, "save_result_path", None)
737-
orig_audio = (getattr(self.input_info, "audio_path", "") or "").split(",")[0].strip() or None
738730
result = super().run_main()
739731
# Stop VARecorder so ffmpeg finishes writing the file
740732
if self.va_controller is not None:
741733
self.va_controller.clear()
742734
self.va_controller = None
743-
# Re-mux with original audio to replace 16kHz audio
744-
if out_path and orig_audio and os.path.isfile(out_path) and os.path.isfile(orig_audio):
745-
try:
746-
tmp = out_path + ".remux.mp4"
747-
cmd = ["ffmpeg", "-y", "-i", out_path, "-i", orig_audio,
748-
"-c:v", "copy", "-c:a", "copy",
749-
"-map", "0:v:0", "-map", "1:a:0", "-shortest", tmp]
750-
subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
751-
os.replace(tmp, out_path)
752-
logger.info(f"[wan_audio] Re-muxed with original audio: {orig_audio}")
753-
except Exception as exc:
754-
logger.warning(f"[wan_audio] Re-mux failed: {exc}")
755735
return result
756736

757737
self.va_controller.start()

0 commit comments

Comments
 (0)