@@ -179,30 +179,12 @@ def main():
179179 logger .info (f"Model { args .model_filename } downloaded successfully." )
180180 sys .exit (0 )
181181
182- if not hasattr (args , "audio_files" ):
183- parser .print_help ()
184- sys .exit (1 )
185-
186- # Path processing: if a directory is specified, collect all audio files from it
187- audio_files = []
188- for path in args .audio_files :
189- if os .path .isdir (path ):
190- # If the path is a directory, recursively search for all audio files
191- for root , dirs , files in os .walk (path ):
192- for file in files :
193- # Check the file extension to ensure it's an audio file
194- if file .endswith ((".wav" , ".flac" , ".mp3" , ".ogg" , ".opus" , ".m4a" , ".aiff" , ".ac3" )): # Add other formats if needed
195- audio_files .append (os .path .join (root , file ))
196- else :
197- # If the path is a file, add it to the list
198- audio_files .append (path )
199-
200- # If no audio files are found, log an error and exit the program
182+ audio_files = list (getattr (args , "audio_files" , []))
201183 if not audio_files :
202- logger . error ( "No valid audio files found in the specified path(s)." )
184+ parser . print_help ( )
203185 sys .exit (1 )
204186
205- logger .info (f"Separator version { package_version } beginning with input file (s): { ', ' .join (audio_files )} " )
187+ logger .info (f"Separator version { package_version } beginning with input path (s): { ', ' .join (audio_files )} " )
206188
207189 separator = Separator (
208190 log_formatter = log_formatter ,
@@ -234,7 +216,12 @@ def main():
234216 "post_process_threshold" : args .vr_post_process_threshold ,
235217 "high_end_process" : args .vr_high_end_process ,
236218 },
237- demucs_params = {"segment_size" : args .demucs_segment_size , "shifts" : args .demucs_shifts , "overlap" : args .demucs_overlap , "segments_enabled" : args .demucs_segments_enabled },
219+ demucs_params = {
220+ "segment_size" : args .demucs_segment_size ,
221+ "shifts" : args .demucs_shifts ,
222+ "overlap" : args .demucs_overlap ,
223+ "segments_enabled" : args .demucs_segments_enabled ,
224+ },
238225 mdxc_params = {
239226 "segment_size" : args .mdxc_segment_size ,
240227 "batch_size" : args .mdxc_batch_size ,
@@ -246,6 +233,5 @@ def main():
246233
247234 separator .load_model (model_filename = args .model_filename )
248235
249- for audio_file in audio_files :
250- output_files = separator .separate (audio_file , custom_output_names = args .custom_output_names )
251- logger .info (f"Separation complete! Output file(s): { ' ' .join (output_files )} " )
236+ output_files = separator .separate (audio_files , custom_output_names = args .custom_output_names )
237+ logger .info (f"Separation complete! Output file(s): { ' ' .join (output_files )} " )
0 commit comments