Skip to content

Commit d29dbcc

Browse files
committed
Remove librosa dependency from audio loading
Replaced librosa functions in load_audio_infer with numpy for mono conversion and scipy for resampling, reducing dependencies and improving performance. Also added mx.eval after MLX model weight loading in infer_mlx.py and rmvpe.py to ensure weights are cached. Updated context.md with new benchmarks and a detailed TODO list for future optimizations.
1 parent 17a0597 commit d29dbcc

4 files changed

Lines changed: 44 additions & 7 deletions

File tree

context.md

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,36 @@ Replaced librosa CPU-based mel spectrogram (645ms first call) with GPU-accelerat
3737
### Backend Selection
3838
| Backend | Description | Performance |
3939
|---------|-------------|-------------|
40-
| `torch` | PyTorch with MPS | 3.14s |
41-
| `mlx` | Full MLX inference | **3.12s** (-0.5%) |
40+
| `torch` | PyTorch with MPS | 2.81s |
41+
| `mlx` | Full MLX inference | **2.91s** |
42+
43+
## 🚀 TODO / Future Optimizations
44+
45+
### 1. Batch Processing in RMVPE mel2hidden
46+
- [ ] Optimize `mel2hidden` to process the mel spectrogram in chunks for better GPU cache utilization and throughput.
47+
48+
### 2. Fused Operations in Hubert
49+
- [ ] Profile and possibly fuse transformer blocks (Q/K/V projections, softmax, output projection) in Hubert using `mx.compile` more strategically or custom kernels.
50+
51+
### 3. Cache Warmup on Model Load
52+
- [ ] Run a single dummy inference iteration immediately after loading models to trigger all MLX kernel compilation (JIT). This shifts the one-time "first run" penalty to the startup phase.
53+
54+
### 4. Proper End-to-End float16 Support
55+
Currently, float16 caused a slowdown because of constant casting between float32 (audio/mel) and float16 (model). To fix:
56+
- [ ] Convert input audio to `float16` immediately after loading.
57+
- [ ] Update `mel_spectrogram` to output `float16`.
58+
- [ ] Implement `tree_map` to cast all model parameters to `float16` at load time.
59+
- [ ] Ensure the entire pipeline operates in `float16`, only casting back to `float32` for final storage.
60+
61+
### 5. Streaming Synthesis
62+
- [ ] Implement overlapping chunk processing for the Synthesizer to reduce peak memory usage and potentially enable real-time/streaming output.
63+
64+
### 6. Remove librosa Dependency Entirely
65+
- [x] Replaced librosa `to_mono` and `resample` with `scipy`/`numpy` in `load_audio_infer`.
66+
- [ ] Investigate moving audio loading entirely to a more lightweight solution if `soundfile`/`scipy` overhead is still noticeable.
67+
68+
### 7. Custom Metal Kernels
69+
- [ ] For absolute peak performance, write optimized Metal shaders for the most compute-intensive operations if `mx.compile` isn't sufficient.
70+
71+
### 8. Quantization (INT8/INT4)
72+
- [ ] Explore `mlx.nn.QuantizedLinear` for the Synthesizer model to reduce memory bandwidth requirements.

rvc/infer/infer_mlx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def get_vc(self, weight_root, sid):
208208
)
209209
# Use load_weights assuming flattened structure
210210
# self.mlx_model.load_weights(list(renamed_weights.items())) -- expects file
211-
self.mlx_model.update(renamed_weights)
211+
self.mlx_model.update(renamed_weights)
212212
# MX eval to ensure weights loaded/cached
213213
mx.eval(self.mlx_model.parameters())
214214

rvc/lib/mlx/rmvpe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ def __init__(self, model_path=None, weights_path=None, device=None):
438438
print(f"RMVPE MLX weights not found at {weights_path}")
439439
else:
440440
self.model.load_weights(weights_path)
441+
mx.eval(self.model.parameters())
441442

442443
# Constants for decode
443444
N_CLASS = 360

rvc/lib/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,17 @@ def load_audio_infer(
7171
if not os.path.isfile(file):
7272
raise FileNotFoundError(f"File not found: {file}")
7373
audio, sr = sf.read(file)
74+
75+
# Convert to mono using numpy (no librosa)
7476
if len(audio.shape) > 1:
75-
audio = librosa.to_mono(audio.T)
77+
audio = np.mean(audio, axis=1) # Average channels for mono
78+
79+
# Resample using scipy (no librosa)
7680
if sr != sample_rate:
77-
audio = librosa.resample(
78-
audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq"
79-
)
81+
from scipy import signal
82+
num_samples = int(len(audio) * sample_rate / sr)
83+
audio = signal.resample(audio, num_samples)
84+
8085
if formant_shifting:
8186
formant_qfrency = kwargs.get("formant_qfrency", 0.8)
8287
formant_timbre = kwargs.get("formant_timbre", 0.8)

0 commit comments

Comments
 (0)