Skip to content

Commit f499d67

Browse files
committed
Minor file synchronization with python-audio-separator
1 parent 0cfee0a commit f499d67

12 files changed

Lines changed: 1886 additions & 34 deletions

PolUVR/separator/architectures/mdxc_separator.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ def __init__(self, common_config, arch_config):
9797
self.audio_file_path = None
9898
self.audio_file_base = None
9999

100-
self.is_primary_stem_main_target = False
101-
if self.model_data_cfgdict.training.target_instrument == "Vocals" or len(self.model_data_cfgdict.training.instruments) > 1:
102-
self.is_primary_stem_main_target = True
100+
# Only mark primary stem as main target for single-target models.
101+
# Multi-stem models should not trigger residual subtraction logic.
102+
self.is_primary_stem_main_target = bool(self.model_data_cfgdict.training.target_instrument)
103103

104104
self.logger.debug(f"is_primary_stem_main_target: {self.is_primary_stem_main_target}")
105105

@@ -428,8 +428,8 @@ def demix(self, mix: np.ndarray) -> dict:
428428
self.logger.debug("Deleting accumulated outputs to free up memory")
429429
del accumulated_outputs
430430

431-
if num_stems > 1 or self.is_primary_stem_main_target:
432-
self.logger.debug("Number of stems is greater than 1 or vocals are main target, detaching individual sources and correcting pitch if necessary...")
431+
if num_stems > 1:
432+
self.logger.debug("Number of stems is greater than 1, detaching individual sources and correcting pitch if necessary...")
433433

434434
sources = {}
435435

@@ -445,7 +445,8 @@ def demix(self, mix: np.ndarray) -> dict:
445445
else:
446446
sources[key] = value
447447

448-
if self.is_primary_stem_main_target:
448+
# Residual subtraction is only applicable for single-target models (not multi-stem)
449+
if self.is_primary_stem_main_target and num_stems == 1:
449450
self.logger.debug(f"Primary stem: {self.primary_stem_name} is main target, detaching and matching array shapes if necessary...")
450451
if sources[self.primary_stem_name].shape[1] != orig_mix.shape[1]:
451452
sources[self.primary_stem_name] = spec_utils.match_array_shapes(sources[self.primary_stem_name], orig_mix)
@@ -456,6 +457,7 @@ def demix(self, mix: np.ndarray) -> dict:
456457

457458
self.logger.debug("Returning separated sources")
458459
return sources
460+
459461
self.logger.debug("Processing single source...")
460462

461463
if self.is_roformer:
@@ -469,8 +471,23 @@ def demix(self, mix: np.ndarray) -> dict:
469471
self.logger.debug("Deleting inferenced outputs to free up memory")
470472
del inferenced_outputs
471473

474+
# For single-target models (e.g., karaoke), also return the residual as secondary
472475
if self.pitch_shift != 0:
473476
self.logger.debug("Applying pitch correction for single instrument")
474-
return self.pitch_fix(inferenced_output, sample_rate, orig_mix)
477+
primary = self.pitch_fix(inferenced_output, sample_rate, orig_mix)
478+
else:
479+
primary = inferenced_output
480+
481+
if self.is_primary_stem_main_target:
482+
self.logger.debug("Single-target model detected; computing residual secondary stem from original mix")
483+
# Ensure shapes match before residual subtraction
484+
if primary.shape[1] != orig_mix.shape[1]:
485+
primary = spec_utils.match_array_shapes(primary, orig_mix)
486+
secondary = orig_mix - primary
487+
return {
488+
self.primary_stem_name: primary,
489+
self.secondary_stem_name: secondary,
490+
}
491+
475492
self.logger.debug("Returning inferenced output for single instrument")
476-
return inferenced_output
493+
return primary

PolUVR/separator/architectures/vr_separator.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ def __init__(self, common_config, arch_config: dict):
138138

139139
self.model_run = lambda *args, **kwargs: self.logger.error("Model run method is not initialised yet.")
140140

141-
# This should go away once we refactor to remove soundfile.write and replace with pydub like we did for the MDX rewrite
142-
self.wav_subtype = "PCM_16"
141+
# wav_subtype is set per input file in separate(), based on the detected input audio bit depth
142+
# Removed hardcoded "PCM_16" to allow bit depth preservation
143143

144144
self.logger.info("VR Separator initialisation complete")
145145

@@ -161,6 +161,32 @@ def separate(self, audio_file_path, custom_output_names=None):
161161
self.audio_file_path = audio_file_path
162162
self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
163163

164+
# Detect input audio bit depth for output preservation
165+
try:
166+
import soundfile as sf
167+
info = sf.info(audio_file_path)
168+
self.input_audio_subtype = info.subtype
169+
self.logger.info(f"Input audio subtype: {self.input_audio_subtype}")
170+
171+
# Map subtype to wav_subtype for soundfile and set input_bit_depth for pydub
172+
if "24" in self.input_audio_subtype:
173+
self.wav_subtype = "PCM_24"
174+
self.input_bit_depth = 24
175+
self.logger.info("Detected 24-bit input audio")
176+
elif "32" in self.input_audio_subtype:
177+
self.wav_subtype = "PCM_32"
178+
self.input_bit_depth = 32
179+
self.logger.info("Detected 32-bit input audio")
180+
else:
181+
self.wav_subtype = "PCM_16"
182+
self.input_bit_depth = 16
183+
self.logger.info("Detected 16-bit input audio")
184+
except Exception as e:
185+
self.logger.warning(f"Could not detect input audio bit depth: {e}. Defaulting to PCM_16")
186+
self.wav_subtype = "PCM_16"
187+
self.input_audio_subtype = None
188+
self.input_bit_depth = 16
189+
164190
self.logger.debug(f"Starting separation for input audio file {self.audio_file_path}...")
165191

166192
nn_arch_sizes = [31191, 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227] # default
@@ -177,7 +203,7 @@ def separate(self, audio_file_path, custom_output_names=None):
177203
self.logger.debug("Determining model capacity...")
178204
self.model_run = nets.determine_model_capacity(self.model_params.param["bins"] * 2, nn_arch_size)
179205

180-
self.model_run.load_state_dict(torch.load(self.model_path, map_location=self.torch_device_cpu))
206+
self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
181207
self.model_run.to(self.torch_device)
182208
self.logger.debug("Model loaded and moved to device.")
183209

PolUVR/separator/audio_chunking.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""Audio chunking utilities for processing large audio files to prevent OOM errors."""
2+
3+
import os
4+
import logging
5+
from typing import List
6+
from pydub import AudioSegment
7+
8+
9+
class AudioChunker:
    """Handles splitting and merging of large audio files.

    This class provides utilities to:
    - Split large audio files into fixed-duration chunks
    - Merge processed chunks back together with simple concatenation
    - Determine if a file should be chunked based on its duration

    Example:
        >>> chunker = AudioChunker(chunk_duration_seconds=600)  # 10-minute chunks
        >>> chunk_paths = chunker.split_audio("long_audio.wav", "/tmp/chunks")
        >>> # Process each chunk...
        >>> output_path = chunker.merge_chunks(processed_chunks, "output.wav")
    """

    def __init__(self, chunk_duration_seconds: float, logger: logging.Logger = None):
        """Initialize the AudioChunker.

        Args:
            chunk_duration_seconds: Duration of each chunk in seconds. Must be
                at least one millisecond once converted to whole milliseconds.
            logger: Optional logger instance for logging operations

        Raises:
            ValueError: If the chunk duration truncates to less than 1 ms.
        """
        chunk_duration_ms = int(chunk_duration_seconds * 1000)
        # A zero duration would divide by zero in split_audio() and a negative
        # one would silently produce no chunks, so reject both up front.
        if chunk_duration_ms < 1:
            raise ValueError(
                f"chunk_duration_seconds must be at least 0.001, got {chunk_duration_seconds!r}"
            )

        self.chunk_duration_ms = chunk_duration_ms
        self.logger = logger or logging.getLogger(__name__)

    @staticmethod
    def _export_format(ext: str) -> str:
        """Return the export format name for a file extension such as ".WAV".

        The leading dot is dropped and the name is lower-cased so that an
        upper-case extension still selects the intended format; an empty
        extension (or a bare ".") falls back to "wav".
        """
        return ext.lstrip(".").lower() or "wav"

    def split_audio(self, input_path: str, output_dir: str) -> List[str]:
        """Split audio file into fixed-size chunks.

        Args:
            input_path: Path to the input audio file
            output_dir: Directory where chunk files will be saved (created if missing)

        Returns:
            List of paths to the created chunk files, in playback order

        Raises:
            FileNotFoundError: If input file doesn't exist
            IOError: If there's an error reading or writing audio files
        """
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"Input file not found: {input_path}")

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(output_dir, exist_ok=True)

        self.logger.debug(f"Loading audio file: {input_path}")
        audio = AudioSegment.from_file(input_path)

        total_duration_ms = len(audio)

        # Ceiling division: a trailing partial chunk still gets its own file.
        num_chunks = (total_duration_ms + self.chunk_duration_ms - 1) // self.chunk_duration_ms
        self.logger.info(f"Splitting {total_duration_ms / 1000:.1f}s audio into {num_chunks} chunks of {self.chunk_duration_ms / 1000:.1f}s each")

        # Chunks keep the input's extension; default to WAV when there is none.
        _, ext = os.path.splitext(input_path)
        if len(ext) <= 1:
            ext = ".wav"
        export_format = self._export_format(ext)

        chunk_paths = []
        for i in range(num_chunks):
            start_ms = i * self.chunk_duration_ms
            end_ms = min(start_ms + self.chunk_duration_ms, total_duration_ms)

            chunk = audio[start_ms:end_ms]
            chunk_path = os.path.join(output_dir, f"chunk_{i:04d}{ext}")

            self.logger.debug(f"Exporting chunk {i + 1}/{num_chunks}: {start_ms / 1000:.1f}s - {end_ms / 1000:.1f}s to {chunk_path}")
            chunk.export(chunk_path, format=export_format)
            chunk_paths.append(chunk_path)

        return chunk_paths

    def merge_chunks(self, chunk_paths: List[str], output_path: str) -> str:
        """Merge processed chunks with simple concatenation.

        Args:
            chunk_paths: List of paths to chunk files to merge, in playback order
            output_path: Path where the merged output will be saved

        Returns:
            Path to the merged output file

        Raises:
            ValueError: If chunk_paths is empty
            FileNotFoundError: If any chunk file doesn't exist
            IOError: If there's an error reading or writing audio files
        """
        if not chunk_paths:
            raise ValueError("Cannot merge empty list of chunks")

        # Verify every chunk exists before loading anything, so a missing file
        # fails fast instead of after part of the audio has been decoded.
        for chunk_path in chunk_paths:
            if not os.path.exists(chunk_path):
                raise FileNotFoundError(f"Chunk file not found: {chunk_path}")

        self.logger.info(f"Merging {len(chunk_paths)} chunks into {output_path}")

        # Start from an empty segment and append each chunk in order.
        combined = AudioSegment.empty()
        for i, chunk_path in enumerate(chunk_paths):
            self.logger.debug(f"Loading chunk {i + 1}/{len(chunk_paths)}: {chunk_path}")
            combined += AudioSegment.from_file(chunk_path)  # Simple concatenation

        # The output format follows the output file's extension.
        _, ext = os.path.splitext(output_path)
        output_format = self._export_format(ext)

        self.logger.info(f"Exporting merged audio ({len(combined) / 1000:.1f}s) to {output_path}")
        combined.export(output_path, format=output_format)

        return output_path

    def should_chunk(self, audio_duration_seconds: float) -> bool:
        """Determine if file is large enough to benefit from chunking.

        Args:
            audio_duration_seconds: Duration of the audio file in seconds

        Returns:
            True if the audio is strictly longer than one chunk, False otherwise
        """
        return audio_duration_seconds > (self.chunk_duration_ms / 1000)

0 commit comments

Comments
 (0)