Skip to content

Commit 798148b

Browse files
Attempt at fixing issues with noisy output of MDX models (#203)
* Attempt at fixing issues with noisy output of MDX models. This is still a work in progress.
* Fix issue
1 parent 6a76cd7 commit 798148b

1 file changed

Lines changed: 17 additions & 18 deletions

File tree

audio_separator/separator/architectures/mdx_separator.py

Lines changed: 17 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -152,34 +152,34 @@ def separate(self, audio_file_path, custom_output_names=None):
152152
mix = self.prepare_mix(self.audio_file_path)
153153

154154
self.logger.debug("Normalizing mix before demixing...")
155+
peak = np.abs(mix).max()
155156
mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)
156157

157158
# Start the demixing process
158-
source = self.demix(mix)
159+
source = self.demix(mix) * peak
159160
self.logger.debug("Demixing completed.")
160161

162+
163+
if not isinstance(self.primary_source, np.ndarray):
164+
self.primary_source = source.T
165+
161166
# In UVR, the source is cached here if it's a vocal split model, but we're not supporting that yet
162167

163168
# Initialize the list for output files
164169
output_files = []
165170
self.logger.debug("Processing output files...")
166171

167-
# Normalize and transpose the primary source if it's not already an array
168-
if not isinstance(self.primary_source, np.ndarray):
169-
self.logger.debug("Normalizing primary source...")
170-
self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T
171-
172172
# Process the secondary source if not already an array
173173
if not isinstance(self.secondary_source, np.ndarray):
174174
self.logger.debug("Producing secondary source: demixing in match_mix mode")
175175
raw_mix = self.demix(mix, is_match_mix=True)
176176

177177
if self.invert_using_spec:
178178
self.logger.debug("Inverting secondary stem using spectogram as invert_using_spec is set to True")
179-
self.secondary_source = spec_utils.invert_stem(raw_mix, source)
179+
self.secondary_source = spec_utils.invert_stem(raw_mix, self.primary_source * self.compensate)
180180
else:
181181
self.logger.debug("Inverting secondary stem by subtracting of transposed demixed stem from transposed original mix")
182-
self.secondary_source = mix.T - source.T
182+
self.secondary_source = (-self.primary_source * self.compensate) + mix.T
183183

184184
# Save and process the secondary stem if needed
185185
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
@@ -192,10 +192,6 @@ def separate(self, audio_file_path, custom_output_names=None):
192192
# Save and process the primary stem if needed
193193
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
194194
self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
195-
196-
if not isinstance(self.primary_source, np.ndarray):
197-
self.primary_source = source.T
198-
199195
self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
200196
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
201197
output_files.append(self.primary_stem_output_path)
@@ -254,7 +250,15 @@ def initialize_mix(self, mix, is_ckpt=False):
254250
pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
255251
self.logger.debug(f"Padding calculated: {pad}")
256252
# Add padding at the beginning and the end of the mix
257-
mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
253+
mixture = np.concatenate(
254+
(
255+
np.zeros((2, self.trim), dtype="float32"), # Pad at the start
256+
mix,
257+
np.zeros((2, pad), dtype="float32"), # Pad in the middle (to match chunk size)
258+
np.zeros((2, self.trim), dtype="float32"), # Pad at the end
259+
),
260+
1
261+
)
258262
# Determine the number of chunks based on the mixture's length
259263
num_chunks = mixture.shape[-1] // self.gen_size
260264
self.logger.debug(f"Mixture shape after padding: {mixture.shape}, Number of chunks: {num_chunks}")
@@ -402,11 +406,6 @@ def demix(self, mix, is_match_mix=False):
402406

403407
# TODO: In UVR, pitch changing happens here. Consider implementing this as a feature.
404408

405-
# Compensates the source if not matching the mix.
406-
if not is_match_mix:
407-
source *= self.compensate
408-
self.logger.debug("Match mix mode; compensate multiplier applied.")
409-
410409
# TODO: In UVR, VR denoise model gets applied here. Consider implementing this as a feature.
411410

412411
self.logger.debug("Demixing process completed.")

0 commit comments

Comments
 (0)