Skip to content

Commit 1dc6cf1

Browse files
authored
single call to Separator.separate for all inputs (#233)
* single call to Separator.separate for all inputs * fix unit tests
1 parent b75eb88 commit 1dc6cf1

2 files changed

Lines changed: 16 additions & 29 deletions

File tree

audio_separator/utils/cli.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -179,30 +179,12 @@ def main():
179179
logger.info(f"Model {args.model_filename} downloaded successfully.")
180180
sys.exit(0)
181181

182-
if not hasattr(args, "audio_files"):
183-
parser.print_help()
184-
sys.exit(1)
185-
186-
# Path processing: if a directory is specified, collect all audio files from it
187-
audio_files = []
188-
for path in args.audio_files:
189-
if os.path.isdir(path):
190-
# If the path is a directory, recursively search for all audio files
191-
for root, dirs, files in os.walk(path):
192-
for file in files:
193-
# Check the file extension to ensure it's an audio file
194-
if file.endswith((".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")): # Add other formats if needed
195-
audio_files.append(os.path.join(root, file))
196-
else:
197-
# If the path is a file, add it to the list
198-
audio_files.append(path)
199-
200-
# If no audio files are found, log an error and exit the program
182+
audio_files = list(getattr(args, "audio_files", []))
201183
if not audio_files:
202-
logger.error("No valid audio files found in the specified path(s).")
184+
parser.print_help()
203185
sys.exit(1)
204186

205-
logger.info(f"Separator version {package_version} beginning with input file(s): {', '.join(audio_files)}")
187+
logger.info(f"Separator version {package_version} beginning with input path(s): {', '.join(audio_files)}")
206188

207189
separator = Separator(
208190
log_formatter=log_formatter,
@@ -234,7 +216,12 @@ def main():
234216
"post_process_threshold": args.vr_post_process_threshold,
235217
"high_end_process": args.vr_high_end_process,
236218
},
237-
demucs_params={"segment_size": args.demucs_segment_size, "shifts": args.demucs_shifts, "overlap": args.demucs_overlap, "segments_enabled": args.demucs_segments_enabled},
219+
demucs_params={
220+
"segment_size": args.demucs_segment_size,
221+
"shifts": args.demucs_shifts,
222+
"overlap": args.demucs_overlap,
223+
"segments_enabled": args.demucs_segments_enabled,
224+
},
238225
mdxc_params={
239226
"segment_size": args.mdxc_segment_size,
240227
"batch_size": args.mdxc_batch_size,
@@ -246,6 +233,5 @@ def main():
246233

247234
separator.load_model(model_filename=args.model_filename)
248235

249-
for audio_file in audio_files:
250-
output_files = separator.separate(audio_file, custom_output_names=args.custom_output_names)
251-
logger.info(f"Separation complete! Output file(s): {' '.join(output_files)}")
236+
output_files = separator.separate(audio_files, custom_output_names=args.custom_output_names)
237+
logger.info(f"Separation complete! Output file(s): {' '.join(output_files)}")

tests/unit/test_cli.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,9 @@ def test_cli_multiple_filenames():
7272
# Call the main function
7373
main()
7474

75-
# Check if separate was called twice (once for each input file)
76-
assert mock_separate.call_count == 2
75+
mock_separate.assert_called_once()
76+
args, kwargs = mock_separate.call_args
77+
assert args[0] == ["test1.mp3", "test2.mp3"]
7778

7879
# Check if the logger captured information about both files
7980
log_messages = [call[0][0] for call in mock_logger.info.call_args_list]
@@ -258,7 +259,7 @@ def test_cli_custom_output_names_argument(common_expected_args):
258259

259260
# Assertions
260261
mock_separator.assert_called_once_with(**common_expected_args)
261-
mock_separator_instance.separate.assert_called_once_with("test_audio.mp3", custom_output_names=custom_names)
262+
mock_separator_instance.separate.assert_called_once_with(["test_audio.mp3"], custom_output_names=custom_names)
262263

263264

264265
# Test using custom_output_names arguments
@@ -280,4 +281,4 @@ def test_cli_demucs_output_names_argument(common_expected_args):
280281

281282
# Assertions
282283
mock_separator.assert_called_once_with(**common_expected_args)
283-
mock_separator_instance.separate.assert_called_once_with("test_audio.mp3", custom_output_names=demucs_output_names)
284+
mock_separator_instance.separate.assert_called_once_with(["test_audio.mp3"], custom_output_names=demucs_output_names)

0 commit comments

Comments
 (0)