Skip to content

Commit 6185a0e

Browse files
committed
big update roformer and minor other
1 parent f1c2c74 commit 6185a0e

31 files changed

Lines changed: 6507 additions & 1090 deletions

PolUVR/separator/architectures/mdxc_separator.py

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99

1010
from PolUVR.separator.common_separator import CommonSeparator
1111
from PolUVR.separator.uvr_lib_v5 import spec_utils
12-
from PolUVR.separator.uvr_lib_v5.roformer.bs_roformer import BSRoformer
13-
from PolUVR.separator.uvr_lib_v5.roformer.mel_band_roformer import MelBandRoformer
1412
from PolUVR.separator.uvr_lib_v5.tfc_tdf_v3 import TFC_TDF_net
13+
# Roformer direct constructors removed; loading handled via RoformerLoader in CommonSeparator.
1514

1615

1716
class MDXCSeparator(CommonSeparator):
@@ -88,7 +87,8 @@ def __init__(self, common_config, arch_config):
8887
self.logger.debug(f"MDXC arch params: override_model_segment_size={self.override_model_segment_size}, pitch_shift={self.pitch_shift}")
8988
self.logger.debug(f"MDXC multi-stem params: process_all_stems={self.process_all_stems}")
9089

91-
self.is_roformer = "is_roformer" in self.model_data
90+
# Align Roformer detection flag with CommonSeparator to ensure consistent stats/logging
91+
self.is_roformer = getattr(self, "is_roformer_model", False)
9292

9393
self.load_model()
9494

@@ -115,28 +115,29 @@ def load_model(self):
115115

116116
try:
117117
if self.is_roformer:
118-
self.logger.debug("Loading Roformer model...")
119-
120-
# Determine the model type based on the configuration and instantiate it
121-
if "num_bands" in self.model_data_cfgdict.model:
122-
self.logger.debug("Loading MelBandRoformer model...")
123-
model = MelBandRoformer(**self.model_data_cfgdict.model)
124-
elif "freqs_per_bands" in self.model_data_cfgdict.model:
125-
self.logger.debug("Loading BSRoformer model...")
126-
model = BSRoformer(**self.model_data_cfgdict.model)
118+
# Use the RoformerLoader exclusively; no legacy fallback
119+
self.logger.debug("Loading Roformer model via RoformerLoader...")
120+
result = self.roformer_loader.load_model(
121+
model_path=self.model_path,
122+
config=self.model_data,
123+
device=str(self.torch_device),
124+
)
125+
126+
if getattr(result, "success", False) and getattr(result, "model", None) is not None:
127+
self.model_run = result.model
128+
self.model_run.to(self.torch_device).eval()
127129
else:
128-
raise ValueError("Unknown Roformer model type in the configuration.")
129-
130-
# Load model checkpoint
131-
checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=True)
132-
self.model_run = model if not isinstance(model, torch.nn.DataParallel) else model.module
133-
self.model_run.load_state_dict(checkpoint)
134-
self.model_run.to(self.torch_device).eval()
130+
error_msg = getattr(result, "error_message", "RoformerLoader unsuccessful")
131+
self.logger.error(f"Failed to load Roformer model: {error_msg}")
132+
raise RuntimeError(error_msg)
135133

136134
else:
137135
self.logger.debug("Loading TFC_TDF_net model...")
138136
self.model_run = TFC_TDF_net(self.model_data_cfgdict, device=self.torch_device)
139-
self.model_run.load_state_dict(torch.load(self.model_path, map_location=self.torch_device))
137+
self.logger.debug("Loading model onto cpu")
138+
# For some reason loading the state onto a hardware accelerated devices causes issues,
139+
# so we load it onto CPU first then move it to the device
140+
self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
140141
self.model_run.to(self.torch_device).eval()
141142

142143
except RuntimeError as e:
@@ -273,9 +274,12 @@ def pitch_fix(self, source, sr_pitched, orig_mix):
273274
return source
274275

275276
def overlap_add(self, result, x, weights, start, length):
276-
"""Adds the overlapping part of the result to the result tensor.
277-
"""
278-
result[..., start : start + length] += x[..., :length] * weights[:length]
277+
"""Adds the overlapping part of the result to the result tensor."""
278+
# Guard against minor shape mismatches from model output length
279+
# Use the minimum of provided lengths to avoid broadcasting errors
280+
safe_len = min(length, x.shape[-1], weights.shape[0])
281+
if safe_len > 0:
282+
result[..., start : start + safe_len] += x[..., :safe_len] * weights[:safe_len]
279283
return result
280284

281285
def demix(self, mix: np.ndarray) -> dict:
@@ -311,11 +315,21 @@ def demix(self, mix: np.ndarray) -> dict:
311315
self.logger.debug(f"Number of stems: {num_stems}")
312316

313317
# chunk_size aka "C" in UVR
314-
chunk_size = self.model_data_cfgdict.audio.hop_length * (mdx_segment_size - 1)
315-
self.logger.debug(f"Chunk size: {chunk_size}")
316-
317-
step = int(self.overlap * self.model_data_cfgdict.audio.sample_rate)
318-
self.logger.debug(f"Step: {step}")
318+
# IMPORTANT: For Roformer models, use the model's STFT hop length to derive the temporal chunk size
319+
stft_hop_len = getattr(self.model_data_cfgdict.model, "stft_hop_length", None)
320+
if stft_hop_len is None:
321+
# Fallback to audio.hop_length if not present, but log for visibility
322+
stft_hop_len = self.model_data_cfgdict.audio.hop_length
323+
self.logger.debug(f"Model.stft_hop_length missing; falling back to audio.hop_length={stft_hop_len}")
324+
325+
chunk_size = int(stft_hop_len) * (int(mdx_segment_size) - 1)
326+
self.logger.debug(f"Chunk size: {chunk_size} (using stft_hop_length={stft_hop_len} and dim_t={mdx_segment_size})")
327+
328+
# Align step to chunk_size by default for Roformer to avoid stride mismatches
329+
# If a user-specified overlap (in seconds) results in a step larger than chunk_size, clamp it
330+
desired_step = int(self.overlap * self.model_data_cfgdict.audio.sample_rate)
331+
step = chunk_size if desired_step <= 0 else min(desired_step, chunk_size)
332+
self.logger.debug(f"Step: {step} (desired={desired_step})")
319333

320334
# Create a weighting table and convert it to a PyTorch tensor
321335
window = torch.tensor(signal.windows.hamming(chunk_size), dtype=torch.float32)
@@ -340,11 +354,16 @@ def demix(self, mix: np.ndarray) -> dict:
340354
# Perform overlap_add on CPU
341355
if i + chunk_size > mix.shape[1]:
342356
# Fixed to correctly add to the end of the tensor
343-
result = self.overlap_add(result, x, window, result.shape[-1] - chunk_size, length)
344-
counter[..., result.shape[-1] - chunk_size :] += window[:length]
357+
start_idx = result.shape[-1] - chunk_size
358+
result = self.overlap_add(result, x, window, start_idx, length)
359+
safe_len = min(length, x.shape[-1], window.shape[0])
360+
if safe_len > 0:
361+
counter[..., start_idx : start_idx + safe_len] += window[:safe_len]
345362
else:
346363
result = self.overlap_add(result, x, window, i, length)
347-
counter[..., i : i + length] += window[:length]
364+
safe_len = min(length, x.shape[-1], window.shape[0])
365+
if safe_len > 0:
366+
counter[..., i : i + safe_len] += window[:safe_len]
348367

349368
inferenced_outputs = result / counter.clamp(min=1e-10)
350369

PolUVR/separator/common_separator.py

Lines changed: 86 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515

1616

1717
class CommonSeparator:
18-
"""This class contains the common methods and attributes common to all architecture-specific Separator classes.
19-
"""
18+
"""This class contains the common methods and attributes common to all architecture-specific Separator classes."""
2019

2120
ALL_STEMS = "All Stems"
2221
VOCAL_STEM = "Vocals"
@@ -84,6 +83,12 @@ def __init__(self, config):
8483
self.invert_using_spec = config.get("invert_using_spec")
8584
self.sample_rate = config.get("sample_rate")
8685
self.use_soundfile = config.get("use_soundfile")
86+
87+
# Roformer-specific loading support
88+
self.roformer_loader = None
89+
self.is_roformer_model = self._detect_roformer_model()
90+
if self.is_roformer_model:
91+
self._initialize_roformer_loader()
8792

8893
# Model specific properties
8994

@@ -138,13 +143,11 @@ def secondary_stem(self, primary_stem: str):
138143
return secondary_stem
139144

140145
def separate(self, audio_file_path):
141-
"""Placeholder method for separating audio sources. Should be overridden by subclasses.
142-
"""
146+
"""Placeholder method for separating audio sources. Should be overridden by subclasses."""
143147
raise NotImplementedError("This method should be overridden by subclasses.")
144148

145149
def final_process(self, stem_path, source, stem_name):
146-
"""Finalizes the processing of a stem by writing the audio to a file and returning the processed source.
147-
"""
150+
"""Finalizes the processing of a stem by writing the audio to a file and returning the processed source."""
148151
self.logger.debug(f"Finalizing {stem_name} stem processing and writing audio...")
149152
self.write_audio(stem_path, source)
150153

@@ -189,7 +192,7 @@ def cached_model_source_holder(self, model_architecture, sources, model_name=Non
189192
"""Update the dictionary for the given model_architecture with the new model name and its sources.
190193
Use the model_architecture as a key to access the corresponding cache source mapper dictionary.
191194
"""
192-
self.cached_sources_map[model_architecture] = {**self.cached_sources_map.get(model_architecture, {}), model_name: sources}
195+
self.cached_sources_map[model_architecture] = {**self.cached_sources_map.get(model_architecture, {}), **{model_name: sources}}
193196

194197
def prepare_mix(self, mix):
195198
"""Prepares the mix for processing. This includes loading the audio from a file if necessary,
@@ -246,8 +249,7 @@ def write_audio(self, stem_path: str, stem_source):
246249
self.write_audio_pydub(stem_path, stem_source)
247250

248251
def write_audio_pydub(self, stem_path: str, stem_source):
249-
"""Writes the separated audio source to a file using pydub (ffmpeg)
250-
"""
252+
"""Writes the separated audio source to a file using pydub (ffmpeg)."""
251253
self.logger.debug(f"Entering write_audio_pydub with stem_path: {stem_path}")
252254

253255
stem_source = spec_utils.normalize(wave=stem_source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)
@@ -305,10 +307,21 @@ def write_audio_pydub(self, stem_path: str, stem_source):
305307
self.logger.error(f"Error exporting audio file: {e}")
306308

307309
def write_audio_soundfile(self, stem_path: str, stem_source):
308-
"""Writes the separated audio source to a file using soundfile library.
309-
"""
310+
"""Writes the separated audio source to a file using soundfile library."""
310311
self.logger.debug(f"Entering write_audio_soundfile with stem_path: {stem_path}")
311312

313+
stem_source = spec_utils.normalize(wave=stem_source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)
314+
315+
# Check if the numpy array is empty or contains very low values
316+
if np.max(np.abs(stem_source)) < 1e-6:
317+
self.logger.warning("Warning: stem_source array is near-silent or empty.")
318+
return
319+
320+
# If output_dir is specified, create it and join it with stem_path
321+
if self.output_dir:
322+
os.makedirs(self.output_dir, exist_ok=True)
323+
stem_path = os.path.join(self.output_dir, stem_path)
324+
312325
# Correctly interleave stereo channels if needed
313326
if stem_source.shape[1] == 2:
314327
# If the audio is already interleaved, ensure it's in the correct order
@@ -327,9 +340,7 @@ def write_audio_soundfile(self, stem_path: str, stem_source):
327340

328341
self.logger.debug(f"Interleaved audio data shape: {stem_source.shape}")
329342

330-
"""
331-
Write audio using soundfile (for formats other than M4A).
332-
"""
343+
"""Write audio using soundfile (for formats other than M4A)."""
333344
# Save audio using soundfile
334345
try:
335346
# Specify the subtype to define the sample width
@@ -339,8 +350,7 @@ def write_audio_soundfile(self, stem_path: str, stem_source):
339350
self.logger.error(f"Error exporting audio file: {e}")
340351

341352
def clear_gpu_cache(self):
342-
"""This method clears the GPU cache to free up memory.
343-
"""
353+
"""This method clears the GPU cache to free up memory."""
344354
self.logger.debug("Running garbage collection...")
345355
gc.collect()
346356
if self.torch_device == torch.device("mps"):
@@ -351,8 +361,7 @@ def clear_gpu_cache(self):
351361
torch.cuda.empty_cache()
352362

353363
def clear_file_specific_paths(self):
354-
"""Clears the file-specific variables which need to be cleared between processing different audio inputs.
355-
"""
364+
"""Clears the file-specific variables which need to be cleared between processing different audio inputs."""
356365
self.logger.info("Clearing input audio file paths, sources and stems...")
357366

358367
self.audio_file_path = None
@@ -365,16 +374,14 @@ def clear_file_specific_paths(self):
365374
self.secondary_stem_output_path = None
366375

367376
def sanitize_filename(self, filename):
368-
"""Cleans the filename by replacing invalid characters with underscores.
369-
"""
377+
"""Cleans the filename by replacing invalid characters with underscores."""
370378
sanitized = re.sub(r'[<>:"/\\|?*]', "_", filename)
371379
sanitized = re.sub(r"_+", "_", sanitized)
372380
sanitized = sanitized.strip("_. ")
373381
return sanitized
374382

375383
def get_stem_output_path(self, stem_name, custom_output_names):
376-
"""Gets the output path for a stem based on the stem name and custom output names.
377-
"""
384+
"""Gets the output path for a stem based on the stem name and custom output names."""
378385
# Convert custom_output_names keys to lowercase for case-insensitive comparison
379386
if custom_output_names:
380387
custom_output_names_lower = {k.lower(): v for k, v in custom_output_names.items()}
@@ -389,3 +396,60 @@ def get_stem_output_path(self, stem_name, custom_output_names):
389396

390397
filename = f"{sanitized_audio_base}_({sanitized_stem_name})_{sanitized_model_name}.{self.output_format.lower()}"
391398
return os.path.join(filename)
399+
400+
def _detect_roformer_model(self):
401+
"""Detect if the current model is a Roformer model.
402+
403+
Returns:
404+
bool: True if this is a Roformer model, False otherwise
405+
"""
406+
if not self.model_data:
407+
return False
408+
409+
# Check for explicit Roformer flag
410+
if self.model_data.get("is_roformer", False):
411+
return True
412+
413+
# Check model path for Roformer indicators
414+
if self.model_path and "roformer" in self.model_path.lower():
415+
return True
416+
417+
# Check model name for Roformer indicators
418+
if self.model_name and "roformer" in self.model_name.lower():
419+
return True
420+
421+
return False
422+
423+
def _initialize_roformer_loader(self):
424+
"""Initialize the Roformer loader for this model."""
425+
try:
426+
from .roformer.roformer_loader import RoformerLoader
427+
self.roformer_loader = RoformerLoader()
428+
self.logger.debug("Initialized Roformer loader for CommonSeparator")
429+
except ImportError as e:
430+
self.logger.warning(f"Could not import RoformerLoader: {e}")
431+
self.roformer_loader = None
432+
433+
def get_roformer_loading_stats(self):
434+
"""Get Roformer loading statistics if available.
435+
436+
Returns:
437+
dict: Loading statistics or empty dict if not available
438+
"""
439+
if self.roformer_loader:
440+
return self.roformer_loader.get_loading_stats()
441+
return {}
442+
443+
def validate_roformer_config(self, config, model_type):
444+
"""Validate Roformer configuration if loader is available.
445+
446+
Args:
447+
config: Configuration dictionary to validate
448+
model_type: Type of model to validate for
449+
450+
Returns:
451+
bool: True if valid or validation not available, False if invalid
452+
"""
453+
if self.roformer_loader:
454+
return self.roformer_loader.validate_configuration(config, model_type)
455+
return True # Assume valid if no loader available

0 commit comments

Comments
 (0)