Skip to content

Commit a0eb0d8

Browse files
authored
change mdxc to process all stems (#210)
1 parent: aa09eb8 · commit: a0eb0d8

1 file changed

Lines changed: 68 additions & 22 deletions

File tree

audio_separator/separator/architectures/mdxc_separator.py

Lines changed: 68 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -50,8 +50,11 @@ def __init__(self, common_config, arch_config):
5050
# • Dropping the pitch may take more processing time but works well for tracks with high-pitched vocals.
5151
self.pitch_shift = arch_config.get("pitch_shift", 0)
5252

53+
self.process_all_stems = arch_config.get("process_all_stems", True)
54+
5355
self.logger.debug(f"MDXC arch params: batch_size={self.batch_size}, segment_size={self.segment_size}, overlap={self.overlap}")
5456
self.logger.debug(f"MDXC arch params: override_model_segment_size={self.override_model_segment_size}, pitch_shift={self.pitch_shift}")
57+
self.logger.debug(f"MDXC multi-stem params: process_all_stems={self.process_all_stems}")
5558

5659
self.is_roformer = "is_roformer" in self.model_data
5760

@@ -143,31 +146,74 @@ def separate(self, audio_file_path, custom_output_names=None):
143146

144147
if isinstance(source, dict):
145148
self.logger.debug("Source is a dict, processing each stem...")
149+
150+
stem_list = []
151+
if self.model_data_cfgdict.training.target_instrument:
152+
stem_list = [self.model_data_cfgdict.training.target_instrument]
153+
else:
154+
stem_list = self.model_data_cfgdict.training.instruments
155+
156+
self.logger.debug(f"Available stems: {stem_list}")
157+
158+
is_multi_stem_model = len(stem_list) > 2
159+
should_process_all_stems = self.process_all_stems and is_multi_stem_model
160+
161+
if should_process_all_stems:
162+
self.logger.debug("Processing all stems from multi-stem model...")
163+
for stem_name in stem_list:
164+
stem_output_path = self.get_stem_output_path(stem_name, custom_output_names)
165+
stem_source = spec_utils.normalize(
166+
wave=source[stem_name],
167+
max_peak=self.normalization_threshold,
168+
min_peak=self.amplification_threshold
169+
).T
170+
171+
self.logger.info(f"Saving {stem_name} stem to {stem_output_path}...")
172+
self.final_process(stem_output_path, stem_source, stem_name)
173+
output_files.append(stem_output_path)
174+
else:
175+
# Standard processing for primary and secondary stems
176+
if not isinstance(self.primary_source, np.ndarray):
177+
self.logger.debug(f"Normalizing primary source for primary stem {self.primary_stem_name}...")
178+
self.primary_source = spec_utils.normalize(
179+
wave=source[self.primary_stem_name],
180+
max_peak=self.normalization_threshold,
181+
min_peak=self.amplification_threshold
182+
).T
183+
184+
if not isinstance(self.secondary_source, np.ndarray):
185+
self.logger.debug(f"Normalizing secondary source for secondary stem {self.secondary_stem_name}...")
186+
self.secondary_source = spec_utils.normalize(
187+
wave=source[self.secondary_stem_name],
188+
max_peak=self.normalization_threshold,
189+
min_peak=self.amplification_threshold
190+
).T
191+
192+
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
193+
self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)
194+
195+
self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
196+
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
197+
output_files.append(self.secondary_stem_output_path)
198+
199+
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
200+
self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
201+
202+
self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
203+
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
204+
output_files.append(self.primary_stem_output_path)
146205

147-
if not isinstance(self.primary_source, np.ndarray):
148-
self.logger.debug(f"Normalizing primary source for primary stem {self.primary_stem_name}...")
149-
self.primary_source = spec_utils.normalize(wave=source[self.primary_stem_name], max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T
150-
151-
if not isinstance(self.secondary_source, np.ndarray):
152-
self.logger.debug(f"Normalizing secondary source for secondary stem {self.secondary_stem_name}...")
153-
self.secondary_source = spec_utils.normalize(wave=source[self.secondary_stem_name], max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T
154-
155-
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
156-
self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)
157-
158-
self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
159-
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
160-
output_files.append(self.secondary_stem_output_path)
161-
162-
if not isinstance(source, dict) or not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
163-
self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
206+
else:
207+
# Handle case when source is not a dictionary (single source model)
208+
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
209+
self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
164210

165-
if not isinstance(self.primary_source, np.ndarray):
166-
self.primary_source = source.T
211+
if not isinstance(self.primary_source, np.ndarray):
212+
self.primary_source = source.T
167213

168-
self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
169-
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
170-
output_files.append(self.primary_stem_output_path)
214+
self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
215+
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
216+
output_files.append(self.primary_stem_output_path)
171217

172218
return output_files
173219

0 commit comments

Comments (0)