1515
1616
1717class CommonSeparator :
18- """This class contains the common methods and attributes common to all architecture-specific Separator classes.
19- """
18+ """This class contains the common methods and attributes common to all architecture-specific Separator classes."""
2019
2120 ALL_STEMS = "All Stems"
2221 VOCAL_STEM = "Vocals"
@@ -84,6 +83,12 @@ def __init__(self, config):
8483 self .invert_using_spec = config .get ("invert_using_spec" )
8584 self .sample_rate = config .get ("sample_rate" )
8685 self .use_soundfile = config .get ("use_soundfile" )
86+
87+ # Roformer-specific loading support
88+ self .roformer_loader = None
89+ self .is_roformer_model = self ._detect_roformer_model ()
90+ if self .is_roformer_model :
91+ self ._initialize_roformer_loader ()
8792
8893 # Model specific properties
8994
@@ -138,13 +143,11 @@ def secondary_stem(self, primary_stem: str):
138143 return secondary_stem
139144
def separate(self, audio_file_path):
    """Separate the audio sources of a file; placeholder to be overridden by subclasses.

    This base implementation performs no work and always raises.

    Args:
        audio_file_path: Path of the audio file to separate (unused here).

    Raises:
        NotImplementedError: Always, because the base class provides no separation logic.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
144148
145149 def final_process (self , stem_path , source , stem_name ):
146- """Finalizes the processing of a stem by writing the audio to a file and returning the processed source.
147- """
150+ """Finalizes the processing of a stem by writing the audio to a file and returning the processed source."""
148151 self .logger .debug (f"Finalizing { stem_name } stem processing and writing audio..." )
149152 self .write_audio (stem_path , source )
150153
@@ -189,7 +192,7 @@ def cached_model_source_holder(self, model_architecture, sources, model_name=Non
189192 """Update the dictionary for the given model_architecture with the new model name and its sources.
190193 Use the model_architecture as a key to access the corresponding cache source mapper dictionary.
191194 """
192- self .cached_sources_map [model_architecture ] = {** self .cached_sources_map .get (model_architecture , {}), model_name : sources }
195+ self .cached_sources_map [model_architecture ] = {** self .cached_sources_map .get (model_architecture , {}), ** { model_name : sources } }
193196
194197 def prepare_mix (self , mix ):
195198 """Prepares the mix for processing. This includes loading the audio from a file if necessary,
@@ -246,8 +249,7 @@ def write_audio(self, stem_path: str, stem_source):
246249 self .write_audio_pydub (stem_path , stem_source )
247250
248251 def write_audio_pydub (self , stem_path : str , stem_source ):
249- """Writes the separated audio source to a file using pydub (ffmpeg)
250- """
252+ """Writes the separated audio source to a file using pydub (ffmpeg)."""
251253 self .logger .debug (f"Entering write_audio_pydub with stem_path: { stem_path } " )
252254
253255 stem_source = spec_utils .normalize (wave = stem_source , max_peak = self .normalization_threshold , min_peak = self .amplification_threshold )
@@ -305,10 +307,21 @@ def write_audio_pydub(self, stem_path: str, stem_source):
305307 self .logger .error (f"Error exporting audio file: { e } " )
306308
307309 def write_audio_soundfile (self , stem_path : str , stem_source ):
308- """Writes the separated audio source to a file using soundfile library.
309- """
310+ """Writes the separated audio source to a file using soundfile library."""
310311 self .logger .debug (f"Entering write_audio_soundfile with stem_path: { stem_path } " )
311312
313+ stem_source = spec_utils .normalize (wave = stem_source , max_peak = self .normalization_threshold , min_peak = self .amplification_threshold )
314+
315+ # Check if the numpy array is empty or contains very low values
316+ if np .max (np .abs (stem_source )) < 1e-6 :
317+ self .logger .warning ("Warning: stem_source array is near-silent or empty." )
318+ return
319+
320+ # If output_dir is specified, create it and join it with stem_path
321+ if self .output_dir :
322+ os .makedirs (self .output_dir , exist_ok = True )
323+ stem_path = os .path .join (self .output_dir , stem_path )
324+
312325 # Correctly interleave stereo channels if needed
313326 if stem_source .shape [1 ] == 2 :
314327 # If the audio is already interleaved, ensure it's in the correct order
@@ -327,9 +340,7 @@ def write_audio_soundfile(self, stem_path: str, stem_source):
327340
328341 self .logger .debug (f"Interleaved audio data shape: { stem_source .shape } " )
329342
330- """
331- Write audio using soundfile (for formats other than M4A).
332- """
343+ """Write audio using soundfile (for formats other than M4A)."""
333344 # Save audio using soundfile
334345 try :
335346 # Specify the subtype to define the sample width
@@ -339,8 +350,7 @@ def write_audio_soundfile(self, stem_path: str, stem_source):
339350 self .logger .error (f"Error exporting audio file: { e } " )
340351
341352 def clear_gpu_cache (self ):
342- """This method clears the GPU cache to free up memory.
343- """
353+ """This method clears the GPU cache to free up memory."""
344354 self .logger .debug ("Running garbage collection..." )
345355 gc .collect ()
346356 if self .torch_device == torch .device ("mps" ):
@@ -351,8 +361,7 @@ def clear_gpu_cache(self):
351361 torch .cuda .empty_cache ()
352362
353363 def clear_file_specific_paths (self ):
354- """Clears the file-specific variables which need to be cleared between processing different audio inputs.
355- """
364+ """Clears the file-specific variables which need to be cleared between processing different audio inputs."""
356365 self .logger .info ("Clearing input audio file paths, sources and stems..." )
357366
358367 self .audio_file_path = None
@@ -365,16 +374,14 @@ def clear_file_specific_paths(self):
365374 self .secondary_stem_output_path = None
366375
def sanitize_filename(self, filename):
    """Clean a filename by replacing invalid characters with underscores.

    The characters ``<``, ``>``, ``:``, ``"``, ``/``, backslash, ``|``, ``?`` and ``*``
    are each replaced by an underscore, consecutive underscores are collapsed into
    one, and leading/trailing underscores, dots and spaces are removed.

    Args:
        filename: The raw name to clean.

    Returns:
        str: The cleaned name. May be an empty string when nothing remains after stripping.
    """
    # Replace every disallowed character with a placeholder underscore.
    sanitized = re.sub(r'[<>:"/\\|?*]', "_", filename)
    # Collapse the underscore runs produced by adjacent replacements.
    sanitized = re.sub(r"_+", "_", sanitized)
    # Trim separator noise from both ends.
    return sanitized.strip("_. ")
374382
375383 def get_stem_output_path (self , stem_name , custom_output_names ):
376- """Gets the output path for a stem based on the stem name and custom output names.
377- """
384+ """Gets the output path for a stem based on the stem name and custom output names."""
378385 # Convert custom_output_names keys to lowercase for case-insensitive comparison
379386 if custom_output_names :
380387 custom_output_names_lower = {k .lower (): v for k , v in custom_output_names .items ()}
@@ -389,3 +396,60 @@ def get_stem_output_path(self, stem_name, custom_output_names):
389396
390397 filename = f"{ sanitized_audio_base } _({ sanitized_stem_name } )_{ sanitized_model_name } .{ self .output_format .lower ()} "
391398 return os .path .join (filename )
399+
def _detect_roformer_model(self):
    """Detect if the current model is a Roformer model.

    A model counts as Roformer when its model data carries a truthy
    ``is_roformer`` entry, or when the substring "roformer" appears
    (case-insensitively) in the model path or the model name.

    The path/name checks do not depend on ``model_data``, so they are still
    applied when ``model_data`` is empty or ``None``.

    Returns:
        bool: True if this is a Roformer model, False otherwise.
    """
    # Explicit flag: only consult model_data when there is something to read.
    if self.model_data and self.model_data.get("is_roformer", False):
        return True

    # Fallback heuristics: look for the marker in the path and the name,
    # skipping whichever of them is unset or empty.
    return any(
        bool(text) and "roformer" in text.lower()
        for text in (self.model_path, self.model_name)
    )
422+
def _initialize_roformer_loader(self):
    """Initialize the Roformer loader for this model.

    The loader class is imported inside the method and instantiated into
    ``self.roformer_loader``. This is best-effort: if an ``ImportError`` is
    raised, a warning is logged and ``self.roformer_loader`` is set to ``None``
    instead of propagating the error.
    """
    try:
        from .roformer.roformer_loader import RoformerLoader

        self.roformer_loader = RoformerLoader()
    except ImportError as e:
        self.logger.warning(f"Could not import RoformerLoader: {e}")
        self.roformer_loader = None
    else:
        # Only report success once the loader instance actually exists.
        self.logger.debug("Initialized Roformer loader for CommonSeparator")
432+
def get_roformer_loading_stats(self):
    """Get Roformer loading statistics if available.

    Delegates to ``self.roformer_loader.get_loading_stats()`` when a loader
    has been set.

    Returns:
        dict: Whatever the loader reports, or an empty dict when no loader is set.
    """
    # Compare against None explicitly so the decision does not depend on the
    # loader object's truthiness.
    if self.roformer_loader is not None:
        return self.roformer_loader.get_loading_stats()
    return {}
442+
def validate_roformer_config(self, config, model_type):
    """Validate Roformer configuration if a loader is available.

    Delegates to ``self.roformer_loader.validate_configuration(config, model_type)``
    when a loader has been set.

    Args:
        config: Configuration dictionary to validate.
        model_type: Type of model to validate for.

    Returns:
        bool: The loader's verdict, or True when no loader is set (validation
        is not available, so the configuration is assumed valid).
    """
    # Without a loader there is nothing to validate against; assume valid.
    if self.roformer_loader is None:
        return True
    return self.roformer_loader.validate_configuration(config, model_type)
0 commit comments