@@ -50,8 +50,11 @@ def __init__(self, common_config, arch_config):
5050 # • Dropping the pitch may take more processing time but works well for tracks with high-pitched vocals.
5151 self .pitch_shift = arch_config .get ("pitch_shift" , 0 )
5252
53+ self .process_all_stems = arch_config .get ("process_all_stems" , True )
54+
5355 self .logger .debug (f"MDXC arch params: batch_size={ self .batch_size } , segment_size={ self .segment_size } , overlap={ self .overlap } " )
5456 self .logger .debug (f"MDXC arch params: override_model_segment_size={ self .override_model_segment_size } , pitch_shift={ self .pitch_shift } " )
57+ self .logger .debug (f"MDXC multi-stem params: process_all_stems={ self .process_all_stems } " )
5558
5659 self .is_roformer = "is_roformer" in self .model_data
5760
@@ -143,31 +146,74 @@ def separate(self, audio_file_path, custom_output_names=None):
143146
144147 if isinstance (source , dict ):
145148 self .logger .debug ("Source is a dict, processing each stem..." )
149+
150+ stem_list = []
151+ if self .model_data_cfgdict .training .target_instrument :
152+ stem_list = [self .model_data_cfgdict .training .target_instrument ]
153+ else :
154+ stem_list = self .model_data_cfgdict .training .instruments
155+
156+ self .logger .debug (f"Available stems: { stem_list } " )
157+
158+ is_multi_stem_model = len (stem_list ) > 2
159+ should_process_all_stems = self .process_all_stems and is_multi_stem_model
160+
161+ if should_process_all_stems :
162+ self .logger .debug ("Processing all stems from multi-stem model..." )
163+ for stem_name in stem_list :
164+ stem_output_path = self .get_stem_output_path (stem_name , custom_output_names )
165+ stem_source = spec_utils .normalize (
166+ wave = source [stem_name ],
167+ max_peak = self .normalization_threshold ,
168+ min_peak = self .amplification_threshold
169+ ).T
170+
171+ self .logger .info (f"Saving { stem_name } stem to { stem_output_path } ..." )
172+ self .final_process (stem_output_path , stem_source , stem_name )
173+ output_files .append (stem_output_path )
174+ else :
175+ # Standard processing for primary and secondary stems
176+ if not isinstance (self .primary_source , np .ndarray ):
177+ self .logger .debug (f"Normalizing primary source for primary stem { self .primary_stem_name } ..." )
178+ self .primary_source = spec_utils .normalize (
179+ wave = source [self .primary_stem_name ],
180+ max_peak = self .normalization_threshold ,
181+ min_peak = self .amplification_threshold
182+ ).T
183+
184+ if not isinstance (self .secondary_source , np .ndarray ):
185+ self .logger .debug (f"Normalizing secondary source for secondary stem { self .secondary_stem_name } ..." )
186+ self .secondary_source = spec_utils .normalize (
187+ wave = source [self .secondary_stem_name ],
188+ max_peak = self .normalization_threshold ,
189+ min_peak = self .amplification_threshold
190+ ).T
191+
192+ if not self .output_single_stem or self .output_single_stem .lower () == self .secondary_stem_name .lower ():
193+ self .secondary_stem_output_path = self .get_stem_output_path (self .secondary_stem_name , custom_output_names )
194+
195+ self .logger .info (f"Saving { self .secondary_stem_name } stem to { self .secondary_stem_output_path } ..." )
196+ self .final_process (self .secondary_stem_output_path , self .secondary_source , self .secondary_stem_name )
197+ output_files .append (self .secondary_stem_output_path )
198+
199+ if not self .output_single_stem or self .output_single_stem .lower () == self .primary_stem_name .lower ():
200+ self .primary_stem_output_path = self .get_stem_output_path (self .primary_stem_name , custom_output_names )
201+
202+ self .logger .info (f"Saving { self .primary_stem_name } stem to { self .primary_stem_output_path } ..." )
203+ self .final_process (self .primary_stem_output_path , self .primary_source , self .primary_stem_name )
204+ output_files .append (self .primary_stem_output_path )
146205
147- if not isinstance (self .primary_source , np .ndarray ):
148- self .logger .debug (f"Normalizing primary source for primary stem { self .primary_stem_name } ..." )
149- self .primary_source = spec_utils .normalize (wave = source [self .primary_stem_name ], max_peak = self .normalization_threshold , min_peak = self .amplification_threshold ).T
150-
151- if not isinstance (self .secondary_source , np .ndarray ):
152- self .logger .debug (f"Normalizing secondary source for secondary stem { self .secondary_stem_name } ..." )
153- self .secondary_source = spec_utils .normalize (wave = source [self .secondary_stem_name ], max_peak = self .normalization_threshold , min_peak = self .amplification_threshold ).T
154-
155- if not self .output_single_stem or self .output_single_stem .lower () == self .secondary_stem_name .lower ():
156- self .secondary_stem_output_path = self .get_stem_output_path (self .secondary_stem_name , custom_output_names )
157-
158- self .logger .info (f"Saving { self .secondary_stem_name } stem to { self .secondary_stem_output_path } ..." )
159- self .final_process (self .secondary_stem_output_path , self .secondary_source , self .secondary_stem_name )
160- output_files .append (self .secondary_stem_output_path )
161-
162- if not isinstance (source , dict ) or not self .output_single_stem or self .output_single_stem .lower () == self .primary_stem_name .lower ():
163- self .primary_stem_output_path = self .get_stem_output_path (self .primary_stem_name , custom_output_names )
206+ else :
207+ # Handle case when source is not a dictionary (single source model)
208+ if not self .output_single_stem or self .output_single_stem .lower () == self .primary_stem_name .lower ():
209+ self .primary_stem_output_path = self .get_stem_output_path (self .primary_stem_name , custom_output_names )
164210
165- if not isinstance (self .primary_source , np .ndarray ):
166- self .primary_source = source .T
211+ if not isinstance (self .primary_source , np .ndarray ):
212+ self .primary_source = source .T
167213
168- self .logger .info (f"Saving { self .primary_stem_name } stem to { self .primary_stem_output_path } ..." )
169- self .final_process (self .primary_stem_output_path , self .primary_source , self .primary_stem_name )
170- output_files .append (self .primary_stem_output_path )
214+ self .logger .info (f"Saving { self .primary_stem_name } stem to { self .primary_stem_output_path } ..." )
215+ self .final_process (self .primary_stem_output_path , self .primary_source , self .primary_stem_name )
216+ output_files .append (self .primary_stem_output_path )
171217
172218 return output_files
173219
0 commit comments