22import torch , torchvision , imageio , os
33import imageio .v3 as iio
44from PIL import Image
5- import torchaudio
6- from diffsynth .utils .data .audio import read_audio
75
86
97class DataProcessingPipeline :
@@ -249,23 +247,27 @@ def __call__(self, data):
class LoadAudio(DataProcessingOperator):
    """Data operator that loads an audio file through ``librosa.load``.

    The loader callable is resolved once in ``__init__`` and kept on
    ``self.audio_loader``, so ``librosa`` is imported only when an instance
    of this operator is actually created rather than at module import time.

    Args:
        sr: Sample rate forwarded to the loader as its ``sr`` keyword
            (defaults to 16000).
    """

    def __init__(self, sr=16000):
        self.sr = sr
        # Lazy import: keeps librosa an optional dependency of this module.
        import librosa
        self.audio_loader = librosa.load

    def __call__(self, data: str):
        """Load the audio referenced by ``data`` and return only the samples.

        Args:
            data: Audio source handed straight to the loader
                (presumably a file path -- confirm against callers).

        Returns:
            The first element of the loader's ``(audio, sample_rate)`` result;
            the returned sample rate is discarded.
        """
        input_audio, _ = self.audio_loader(data, sr=self.sr)
        return input_audio
256256
257257
258258class LoadAudioWithTorchaudio (DataProcessingOperator , FrameSamplerByRateMixin ):
259259
260260 def __init__ (self , num_frames = 121 , time_division_factor = 8 , time_division_remainder = 1 , frame_rate = 24 , fix_frame_rate = True ):
261261 FrameSamplerByRateMixin .__init__ (self , num_frames , time_division_factor , time_division_remainder , frame_rate , fix_frame_rate )
262+ import torchaudio
263+ self .audio_loader = torchaudio .load
262264
263265 def __call__ (self , data : str ):
264266 try :
265267 reader = self .get_reader (data )
266268 num_frames = self .get_num_frames (reader )
267269 duration = num_frames / self .frame_rate
268- waveform , sample_rate = torchaudio . load (data )
270+ waveform , sample_rate = self . audio_loader (data )
269271 target_samples = int (duration * sample_rate )
270272 current_samples = waveform .shape [- 1 ]
271273 if current_samples > target_samples :
@@ -285,10 +287,12 @@ def __init__(self, target_sample_rate=None, target_duration=None):
285287 self .target_sample_rate = target_sample_rate
286288 self .target_duration = target_duration
287289 self .resample = True if target_sample_rate is not None else False
290+ from diffsynth .utils .data .audio import read_audio
291+ self .audio_loader = read_audio
288292
289293 def __call__ (self , data : str ):
290294 try :
291- waveform , sample_rate = read_audio (data , resample = self .resample , resample_rate = self .target_sample_rate )
295+ waveform , sample_rate = self . audio_loader (data , resample = self .resample , resample_rate = self .target_sample_rate )
292296 if self .target_duration is not None :
293297 target_samples = int (self .target_duration * sample_rate )
294298 current_samples = waveform .shape [- 1 ]
0 commit comments