22import io
33import json
44import os
5- import subprocess
65import warnings
76from dataclasses import dataclass
87from typing import Dict , List , Optional , Tuple , Union
@@ -362,27 +361,35 @@ def get_audio_files_from_audio_path(self, audio_path):
362361
363362 return audio_files , mask_files
364363
def _get_image_resize_kwargs(self):
    """Build the keyword arguments forwarded to ``resize_image``.

    ``resize_mode`` and ``fixed_area`` are taken from ``self.input_info``
    when that object exists and holds a truthy value for them; otherwise
    they fall back to ``self.config`` (``resize_mode`` defaults to
    ``"adaptive"``, ``fixed_area`` to ``None``).  ``bucket_shape`` and
    ``fixed_shape`` are always read from ``self.config``.

    Returns:
        A dict with the keys ``resize_mode``, ``bucket_shape``,
        ``fixed_area`` and ``fixed_shape``.
    """
    # getattr() with a default already yields None when ``input_info`` is
    # missing or is None, so no separate None guard is required.
    input_info = getattr(self, "input_info", None)
    config = self.config
    return {
        "resize_mode": getattr(input_info, "resize_mode", None) or config.get("resize_mode", "adaptive"),
        "bucket_shape": config.get("bucket_shape", None),
        "fixed_area": getattr(input_info, "fixed_area", None) or config.get("fixed_area", None),
        "fixed_shape": config.get("fixed_shape", None),
    }
372+
def _resolve_patched_spatial_size(self, h, w):
    """Derive the patched grid size and the matching pixel-space target.

    ``h`` and ``w`` are floor-divided by the spatial VAE stride and then by
    the spatial patch size (indices 1 and 2 of ``config["vae_stride"]`` and
    ``config["patch_size"]``).  The result is passed through
    ``get_optimal_patched_size_with_sp`` and then scaled back up by the same
    factors to obtain the target height and width.

    Returns:
        A tuple ``(target_shape, patched_h, patched_w)`` where
        ``target_shape`` is the list ``[target_h, target_w]``.
    """
    vae_stride = self.config["vae_stride"]
    patch_size = self.config["patch_size"]
    stride_h, stride_w = vae_stride[1], vae_stride[2]
    patch_h, patch_w = patch_size[1], patch_size[2]

    # NOTE(review): the trailing 1 is presumably the sequence-parallel size;
    # confirm against get_optimal_patched_size_with_sp.
    grid_h, grid_w = get_optimal_patched_size_with_sp(
        h // stride_h // patch_h,
        w // stride_w // patch_w,
        1,
    )

    # Patched grid -> latent size -> pixel size.
    target_h = grid_h * patch_h * stride_h
    target_w = grid_w * patch_w * stride_w
    return [target_h, target_w], grid_h, grid_w
381+
def process_single_mask(self, mask_file):
    """Load a mask image and turn it into a binary mask on the patched grid.

    The image is loaded, mapped from ``to_tensor`` output to ``(x - 0.5) / 0.5``,
    given a leading batch dimension and moved to ``AI_DEVICE``.  It is then
    resized with the shared resize settings, bicubically interpolated to the
    resolved target shape, and interpolated again down to
    ``(patched_h, patched_w)``.  Values greater than 0 become 1, all others 0.

    Args:
        mask_file: Mask image location, passed straight to ``load_image``.

    Returns:
        An ``int8`` tensor holding the thresholded mask.
    """
    mask_img = load_image(mask_file)
    mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)

    # A three-channel (RGB) mask is reduced to its first channel.
    if mask_img.shape[1] == 3:
        mask_img = mask_img[:, :1]

    resize_kwargs = self._get_image_resize_kwargs()
    mask_img, h, w = resize_image(mask_img, **resize_kwargs)
    target_shape, patched_h, patched_w = self._resolve_patched_spatial_size(h, w)

    # First go to the pixel-space target, then down to the patched grid.
    pixel_size = (target_shape[0], target_shape[1])
    mask_img = F.interpolate(mask_img, size=pixel_size, mode="bicubic")
    mask_latent = F.interpolate(mask_img, size=(patched_h, patched_w), mode="bicubic")

    # Binarise: strictly positive values count as masked.
    return (mask_latent > 0).to(torch.int8)
388395
@@ -393,19 +400,9 @@ def read_image_input(self, img_path):
393400 ref_img = load_image (img_path )
394401 ref_img = TF .to_tensor (ref_img ).sub_ (0.5 ).div_ (0.5 ).unsqueeze (0 ).to (AI_DEVICE )
395402
396- ref_img , h , w = resize_image (
397- ref_img ,
398- resize_mode = getattr (self .input_info , "resize_mode" , None ) or self .config .get ("resize_mode" , "adaptive" ),
399- bucket_shape = self .config .get ("bucket_shape" , None ),
400- fixed_area = getattr (self .input_info , "fixed_area" , None ) or self .config .get ("fixed_area" , None ),
401- fixed_shape = self .config .get ("fixed_shape" , None ),
402- )
403+ ref_img , h , w = resize_image (ref_img , ** self ._get_image_resize_kwargs ())
403404 logger .info (f"[wan_audio] resize_image target_h: { h } , target_w: { w } " )
404- patched_h = h // self .config ["vae_stride" ][1 ] // self .config ["patch_size" ][1 ]
405- patched_w = w // self .config ["vae_stride" ][2 ] // self .config ["patch_size" ][2 ]
406-
407- patched_h , patched_w = get_optimal_patched_size_with_sp (patched_h , patched_w , 1 )
408-
405+ target_shape , patched_h , patched_w = self ._resolve_patched_spatial_size (h , w )
409406 latent_h = patched_h * self .config ["patch_size" ][1 ]
410407 latent_w = patched_w * self .config ["patch_size" ][2 ]
411408
@@ -415,11 +412,9 @@ def read_image_input(self, img_path):
415412 else :
416413 latent_shape = self .get_latent_shape_with_lat_hw (latent_h , latent_w )
417414
418- target_shape = [latent_h * self .config ["vae_stride" ][1 ], latent_w * self .config ["vae_stride" ][2 ]]
419-
420415 logger .info (f"[wan_audio] target_h: { target_shape [0 ]} , target_w: { target_shape [1 ]} , latent_h: { latent_h } , latent_w: { latent_w } " )
421416
422- ref_img = torch . nn . functional .interpolate (ref_img , size = (target_shape [0 ], target_shape [1 ]), mode = "bicubic" )
417+ ref_img = F .interpolate (ref_img , size = (target_shape [0 ], target_shape [1 ]), mode = "bicubic" )
423418 return ref_img , latent_shape , target_shape
424419
425420 @ProfilingContext4DebugL1 (
@@ -732,26 +727,11 @@ def run_main(self):
732727
733728 # fixed audio segments inputs
734729 if self .va_controller .reader is None :
735- # Save paths before super().run_main() clears input_info
736- out_path = getattr (self .input_info , "save_result_path" , None )
737- orig_audio = (getattr (self .input_info , "audio_path" , "" ) or "" ).split ("," )[0 ].strip () or None
738730 result = super ().run_main ()
739731 # Stop VARecorder so ffmpeg finishes writing the file
740732 if self .va_controller is not None :
741733 self .va_controller .clear ()
742734 self .va_controller = None
743- # Re-mux with original audio to replace 16kHz audio
744- if out_path and orig_audio and os .path .isfile (out_path ) and os .path .isfile (orig_audio ):
745- try :
746- tmp = out_path + ".remux.mp4"
747- cmd = ["ffmpeg" , "-y" , "-i" , out_path , "-i" , orig_audio ,
748- "-c:v" , "copy" , "-c:a" , "copy" ,
749- "-map" , "0:v:0" , "-map" , "1:a:0" , "-shortest" , tmp ]
750- subprocess .run (cmd , check = True , stdout = subprocess .DEVNULL , stderr = subprocess .DEVNULL )
751- os .replace (tmp , out_path )
752- logger .info (f"[wan_audio] Re-muxed with original audio: { orig_audio } " )
753- except Exception as exc :
754- logger .warning (f"[wan_audio] Re-mux failed: { exc } " )
755735 return result
756736
757737 self .va_controller .start ()
0 commit comments