
Commit ff61b5a

Add MLX backend for native Apple Silicon inference

Introduces an MLX-based inference pipeline for RVC CLI, including ported model modules, a custom ConvTranspose1d, and weight-conversion logic. Updates the CLI to support '--backend mlx' and the requirements for macOS. Includes context documentation and debug logs from development and integration.

1 parent 01a8ba5 commit ff61b5a

18 files changed: 1,543 additions & 1 deletion

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -22,3 +22,5 @@ rvc/models
 env
 venv
 .venv
+
+.DS_Store

context.md

Lines changed: 56 additions & 0 deletions
# Project Context & Session Summary

**Date:** 2026-01-05
**Objective:** Add native Apple Silicon (MLX) inference support to RVC CLI.

## Accomplishments

1. **MLX Core Integration**:
    * Added `mlx` dependency for macOS.
    * Created `rvc/lib/mlx/` package containing ported modules:
        * `modules.py`: WaveNet
        * `attentions.py`: MultiHeadAttention, FFN
        * `residuals.py`: ResBlock, ResidualCouplingBlock
        * `generators.py`: HiFiGANNSFGenerator, SineGenerator
        * `encoders.py`: TextEncoder, PosteriorEncoder
        * `synthesizers.py`: Synthesizer (the main generator model)
    * **Architecture Choice**: Adopted a **hybrid pipeline**. We rely on the existing PyTorch implementation for the complex feature-extraction stages (HuBERT, RMVPE) to ensure compatibility and stability, and use MLX solely for the computationally expensive HiFi-GAN synthesis step.

2. **Inference Pipeline**:
    * Implemented `VoiceConverterMLX` and `PipelineMLX` in `rvc/infer/infer_mlx.py`.
    * Implemented on-the-fly weight conversion in `rvc/lib/mlx/convert.py`, which loads a standard RVC `.pth`, fuses `weight_norm` layers, and transposes weights to match MLX's (N, L, C) layout.
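The `weight_norm` fusion step can be sketched for a single tensor as follows (a minimal NumPy sketch, assuming PyTorch's parametrization `w = g * v / ||v||` with the norm taken per output channel; the helper name `fuse_weight_norm` is illustrative, not the actual `convert.py` API):

```python
import numpy as np

def fuse_weight_norm(weight_g, weight_v):
    """Collapse a weight_norm (g, v) pair into a plain weight tensor.

    PyTorch stores weight_norm layers as weight_g (per-output-channel
    magnitude, shape (Out, 1, 1)) and weight_v (direction, shape
    (Out, In, K)); the effective weight is w = g * v / ||v||.
    """
    norm = np.sqrt((weight_v ** 2).sum(axis=(1, 2), keepdims=True))
    return weight_g * weight_v / norm

# Sanity check: after fusion, each output channel's norm equals g.
v = np.random.randn(4, 3, 5)
g = np.abs(np.random.randn(4, 1, 1))
w = fuse_weight_norm(g, v)
```

Once fused, the layer needs only the single `w` tensor, which is what the MLX modules expect.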
3. **CLI Integration**:
    * Modified `rvc_cli.py` to accept `--backend mlx`.
    * Standard usage: `python rvc_cli.py infer ... --backend mlx`.

## Critical "Tidbits" for Future Sessions

### 1. Model Locations

The user's test models are located at:

> **`/Users/mcruz/Library/Application Support/Replay/com.replay.Replay/models`**

Verify that models are available there before running tests.

### 2. Environment Variables

* **`export OMP_NUM_THREADS=1`**: This is **MANDATORY** on macOS to prevent `faiss` from crashing the process with a segmentation fault.
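If the variable cannot be set in the shell, the same guard can be applied in-process (a sketch; the key point is that the assignment must happen before `faiss` is imported, since OpenMP reads the environment at library load time):

```python
import os

# Must run before `import faiss`; otherwise faiss may spin up OpenMP
# threads and segfault on macOS.
os.environ["OMP_NUM_THREADS"] = "1"

# import faiss  # safe to import after the variable is set
```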
### 3. Runtime Environment

* **Conda Environment**: All commands must be run within the `rvc` Conda environment.
* Example: `conda run -n rvc python rvc_cli.py ...`, or `conda activate rvc` before running.

### 4. Model Compatibility

* **Config Required**: The MLX converter expects the `.pth` file to contain a `config` key (a list of hyperparameters) alongside the `weight` key.
* **No Pretrained-Only**: Raw training checkpoints (such as `f0G40k.pth`) often lack the `config` key and will fail to load in the current MLX implementation. Use fully trained/exported RVC models.
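The compatibility requirement implies a simple up-front check; a hedged sketch (the helper name and error text are illustrative, and a real checkpoint dict would come from `torch.load`):

```python
def validate_rvc_checkpoint(ckpt: dict) -> list:
    """Return the hyperparameter list, or raise if the checkpoint looks
    like a raw training checkpoint without the keys the MLX converter needs."""
    if "config" not in ckpt or "weight" not in ckpt:
        raise ValueError(
            "Checkpoint lacks 'config'/'weight' keys; use a fully "
            "trained/exported RVC model, not a raw training checkpoint."
        )
    return ckpt["config"]

# Stand-in for torch.load(path, map_location="cpu")
exported = {"weight": {"dec.conv_pre.weight": None}, "config": [1025, 32, 192]}
config = validate_rvc_checkpoint(exported)
```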
### 5. Implementation Details

* **Data Layout**: PyTorch uses `(N, C, L)` (channels-first). The MLX components were ported to use `(N, L, C)` (channels-last), which is more native to MLX. The converter handles this transposition.
* **Missing Layers**: `mlx.nn` does not yet have a `ConvTranspose1d` layer. We implemented a custom `ConvTranspose1d` in `rvc/lib/mlx/generators.py` using an upsample-and-convolve approach.
* **Weight Transposition**:
    * Regular Conv1d: PyTorch `(Out, In, K)` -> MLX `(Out, K, In)`; transpose `(0, 2, 1)`.
    * ConvTranspose1d: PyTorch `(In, Out, K)` -> MLX `(Out, K, In)` (effectively); transpose `(1, 2, 0)`.
* **Performance**: The current implementation converts weights *every time* inference runs. For production, we should save/load converted MLX weights as `.npz` or `.safetensors` files.
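The two transposition rules above can be checked with a shape-only sketch (NumPy stands in for `mx`; the channel sizes are arbitrary examples):

```python
import numpy as np

# Regular Conv1d: PyTorch (Out, In, K) -> MLX (Out, K, In)
w_conv = np.zeros((256, 128, 7))          # (Out, In, K)
w_conv_mlx = w_conv.transpose(0, 2, 1)    # (Out, K, In)
assert w_conv_mlx.shape == (256, 7, 128)

# ConvTranspose1d: PyTorch (In, Out, K) -> MLX (Out, K, In)
w_tconv = np.zeros((128, 64, 16))         # (In, Out, K)
w_tconv_mlx = w_tconv.transpose(1, 2, 0)  # (Out, K, In)
assert w_tconv_mlx.shape == (64, 16, 128)
```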
## Next Steps

* **Final Verification**: Run a full end-to-end test using a model from the Replay directory.
* **Optimization**: Cache converted MLX weights to disk.
* **Benchmarks**: Compare MPS (PyTorch) vs. MLX performance.
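The weight-caching idea could look like the following (a sketch using NumPy's `.npz` format; the cache-path convention of placing the file next to the `.pth` is an assumption, and the helper names are illustrative):

```python
import os
import numpy as np

def cache_path(pth_path: str) -> str:
    # Assumption: the cache sits next to the original .pth
    return os.path.splitext(pth_path)[0] + ".mlx.npz"

def save_converted(weights: dict, pth_path: str) -> None:
    """Persist converted (fused, transposed) weights so future runs skip conversion."""
    np.savez(cache_path(pth_path), **weights)

def load_converted(pth_path: str):
    """Return cached weights as a dict of arrays, or None to trigger conversion."""
    path = cache_path(pth_path)
    if not os.path.exists(path):
        return None  # fall back to on-the-fly conversion
    with np.load(path) as f:
        return {k: f[k] for k in f.files}
```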

debug_mlx_2.log

Lines changed: 19 additions & 0 deletions
Traceback (most recent call last):
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc_cli.py", line 2196, in main
    run_infer_script(
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc_cli.py", line 183, in run_infer_script
    infer_pipeline.convert_audio(
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc/infer/infer.py", line 250, in convert_audio
    self.get_vc(model_path, sid)
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc/infer/infer_mlx.py", line 203, in get_vc
    self.mlx_model = Synthesizer(
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc/lib/mlx/synthesizers.py", line 51, in __init__
    self.dec = HiFiGANNSFGenerator(
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc/lib/mlx/generators.py", line 182, in __init__
    self.ups.append(nn.ConvTranspose1d(in_ch, out_ch, k, stride=u, padding=p))
AttributeError: module 'mlx.nn' has no attribute 'ConvTranspose1d'

Using HiFi-GAN vocoder
Loading MLX model: /Users/mcruz/Library/Application Support/Replay/com.replay.Replay/models/Slim Shady/model.pth
An error occurred during execution: module 'mlx.nn' has no attribute 'ConvTranspose1d'
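The fix for this traceback was a hand-rolled `ConvTranspose1d`. The core operation can be sketched channel-free in NumPy (a scatter-add formulation, mathematically equivalent to the upsample-and-convolve approach used in `generators.py`; this is an illustrative sketch, not the actual implementation):

```python
import numpy as np

def conv_transpose1d(x, w, stride=1, padding=0):
    """Single-channel transposed 1-D convolution.

    Each input sample scatters a scaled copy of the kernel into the
    output at stride-spaced offsets; padding then crops the edges.
    Output length: (L - 1) * stride - 2 * padding + K.
    """
    L, K = len(x), len(w)
    out_len = (L - 1) * stride - 2 * padding + K
    y = np.zeros((L - 1) * stride + K)
    for i, v in enumerate(x):
        y[i * stride : i * stride + K] += v * w
    return y[padding : padding + out_len]

# Stride-2 upsampling with a length-3 kernel: overlapping kernel copies
out = conv_transpose1d(np.array([1.0, 2.0]), np.array([1.0, 1.0, 1.0]), stride=2)
assert np.allclose(out, [1.0, 1.0, 3.0, 2.0, 2.0])
```

The real layer additionally handles the `(In, Out, K)` weight layout and batched channels-last inputs.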

debug_mlx_3.log

Lines changed: 21 additions & 0 deletions
Traceback (most recent call last):
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc_cli.py", line 2196, in main
    run_infer_script(
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc_cli.py", line 183, in run_infer_script
    infer_pipeline.convert_audio(
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc/infer/infer.py", line 250, in convert_audio
    self.get_vc(model_path, sid)
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc/infer/infer_mlx.py", line 203, in get_vc
    self.mlx_model = Synthesizer(
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc/lib/mlx/synthesizers.py", line 51, in __init__
    self.dec = HiFiGANNSFGenerator(
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc/lib/mlx/generators.py", line 210, in __init__
    self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)
  File "/Users/mcruz/Developer/Retrieval-based-Voice-Conversion-MLX/rvc/lib/mlx/generators.py", line 183, in __init__
    self.l_sin_gen = SineGenerator(sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
NameError: name 'SineGenerator' is not defined

Using HiFi-GAN vocoder
Loading MLX model: /Users/mcruz/Library/Application Support/Replay/com.replay.Replay/models/Slim Shady/model.pth
An error occurred during execution: name 'SineGenerator' is not defined

debug_mlx_4.log

Lines changed: 5 additions & 0 deletions
Using HiFi-GAN vocoder
Loading MLX model: /Users/mcruz/Library/Application Support/Replay/com.replay.Replay/models/Slim Shady/model.pth
Converting audio 'TestAudio/coder_audio_stock.wav'...
Conversion completed at 'TestAudio/coder_mlx_slim.wav' in 3.29 seconds.

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -40,3 +40,4 @@ certifi; sys_platform == 'darwin'
 antlr4-python3-runtime
 edge-tts
 webrtcvad
+mlx; sys_platform == 'darwin'

rvc/.DS_Store

0 Bytes
Binary file not shown.

rvc/infer/infer_mlx.py

Lines changed: 242 additions & 0 deletions
import torch
import numpy as np
import mlx.core as mx
from torch.nn import functional as F

from rvc.infer.infer import VoiceConverter as VoiceConverterTorch
from rvc.infer.pipeline import Pipeline
from rvc.lib.mlx.synthesizers import Synthesizer
from rvc.lib.mlx.convert import convert_weights


class PipelineMLX(Pipeline):
    def __init__(self, tgt_sr, config, mlx_model):
        super().__init__(tgt_sr, config)
        self.mlx_model = mlx_model

    def voice_conversion(
        self,
        model,
        net_g,
        sid,
        audio0,
        pitch,
        pitchf,
        index,
        big_npy,
        index_rate,
        version,
        protect,
    ):
        """
        MLX override of voice_conversion.
        Uses Torch for feature extraction, then MLX for synthesis.
        """
        # Feature extraction logic from the original pipeline.py
        with torch.no_grad():
            pitch_guidance = pitch is not None and pitchf is not None

            feats = torch.from_numpy(audio0).float()
            feats = feats.mean(-1) if feats.dim() == 2 else feats
            assert feats.dim() == 1, feats.dim()
            feats = feats.view(1, -1).to(self.device)

            # Extract features (HuBERT)
            feats = model(feats)["last_hidden_state"]
            feats = (
                model.final_proj(feats[0]).unsqueeze(0) if version == "v1" else feats
            )

            feats0 = feats.clone() if pitch_guidance else None

            if index is not None and index_rate > 0:
                feats = self._retrieve_speaker_embeddings(
                    feats, index, big_npy, index_rate
                )

            feats = F.interpolate(
                feats.permute(0, 2, 1), scale_factor=2
            ).permute(0, 2, 1)

            p_len = min(audio0.shape[0] // self.window, feats.shape[1])

            if pitch_guidance:
                feats0 = F.interpolate(
                    feats0.permute(0, 2, 1), scale_factor=2
                ).permute(0, 2, 1)
                pitch, pitchf = pitch[:, :p_len], pitchf[:, :p_len]

                if protect < 0.5:
                    pitchff = pitchf.clone()
                    pitchff[pitchf > 0] = 1
                    pitchff[pitchf < 1] = protect
                    feats = feats * pitchff.unsqueeze(-1) + feats0 * (
                        1 - pitchff.unsqueeze(-1)
                    )
                    feats = feats.to(feats0.dtype)
            else:
                pitch, pitchf = None, None

        # --- MLX INFERENCE START ---
        # Convert tensors to MLX arrays. feats is already (1, L, C), which
        # matches the channels-last layout the MLX port expects.
        feats_mx = mx.array(feats.cpu().numpy())
        pitch_mx = mx.array(pitch.cpu().numpy()) if pitch is not None else None
        pitchf_mx = mx.array(pitchf.cpu().numpy()) if pitchf is not None else None
        sid_mx = mx.array(sid.cpu().numpy())

        # Synthesizer.infer(phone, phone_lengths, pitch, nsff0, sid) mirrors
        # the Torch call net_g.infer(feats, p_len, pitch, pitchf, sid). The
        # MLX generator's conv_post is a Conv1d with out_channels=1 operating
        # on (N, L, C), so the output shape is (1, T, 1).
        out_audio, _, _ = self.mlx_model.infer(
            feats_mx,
            mx.array([p_len]),
            pitch_mx,
            pitchf_mx,
            sid_mx,
        )

        # (1, T, 1) -> (T,)
        audio1 = np.array(out_audio[0, :, 0])
        # --- MLX INFERENCE END ---

        del feats, feats0
        return audio1


class VoiceConverterMLX(VoiceConverterTorch):
    def __init__(self):
        super().__init__()
        self.mlx_model = None
        self.mlx_model_path = None
        self.mlx_pipeline = None

    def get_vc(self, weight_root, sid):
        super().get_vc(weight_root, sid)

        if not self.mlx_model or self.mlx_model_path != weight_root:
            print(f"Loading MLX model: {weight_root}")
            weights, config = convert_weights(weight_root)

            # Rename Torch ModuleList keys ("layer.N") to the flat attribute
            # names used by the MLX port ("layer_N"). The mapping must be
            # precise and mirrors the ported module definitions:
            #   encoders:  .attn_layers.X -> .attn_X, .norm_layers_1.X -> .norm1_X,
            #              .ffn_layers.X -> .ffn_X,   .norm_layers_2.X -> .norm2_X
            #   generator: .ups.X -> .up_X, .noise_convs.X -> .noise_conv_X,
            #              .resblocks.X -> .resblock_X
            #   flows:     .flows.X -> .flow_X; WaveNet: .in_layers.X -> .in_layer_X,
            #              .res_skip_layers.X -> .res_skip_layer_X
            rename_rules = {
                "attn_layers": "attn",
                "norm_layers_1": "norm1",
                "ffn_layers": "ffn",
                "norm_layers_2": "norm2",
                "ups": "up",
                "noise_convs": "noise_conv",
                "resblocks": "resblock",
                "flows": "flow",
                "in_layers": "in_layer",
                "res_skip_layers": "res_skip_layer",
            }

            renamed_weights = {}
            for k, v in weights.items():
                parts = k.split(".")
                new_parts = []
                skip_next = False
                for i, p in enumerate(parts):
                    if skip_next:
                        skip_next = False
                        continue
                    if i + 1 < len(parts) and parts[i + 1].isdigit():
                        # List-access pattern, e.g. "ups" followed by "0".
                        # Unlisted names fall back to "name_idx".
                        idx = parts[i + 1]
                        base = rename_rules.get(p, p)
                        new_parts.append(f"{base}_{idx}")
                        skip_next = True
                    else:
                        new_parts.append(p)
                renamed_weights[".".join(new_parts)] = v

            self.mlx_model = Synthesizer(
                *config,
                use_f0=self.use_f0,
                text_enc_hidden_dim=self.text_enc_hidden_dim,
                vocoder=self.vocoder,
            )
            # load_weights expects a file path, so update() with the flattened
            # dict instead, then force evaluation so the weights are realized.
            self.mlx_model.update(renamed_weights)
            mx.eval(self.mlx_model.parameters())

            self.mlx_model_path = weight_root

            # Inject the MLX model into a dedicated pipeline instance
            self.mlx_pipeline = PipelineMLX(self.tgt_sr, self.config, self.mlx_model)

    def pipeline(
        self,
        model,
        net_g,
        sid,
        audio,
        pitch,
        f0_method,
        file_index,
        index_rate,
        pitch_guidance,
        volume_envelope,
        version,
        protect,
        f0_autotune,
        f0_autotune_strength,
        proposed_pitch,
        proposed_pitch_threshold,
    ):
        # Delegate to the MLX pipeline instead of self.vc (the Torch pipeline).
        # self.vc is still initialized in super().get_vc, which is fine.
        # `model` is HuBERT (still used for feature extraction); `net_g` is the
        # Torch generator, which the MLX pipeline ignores.
        return self.mlx_pipeline.pipeline(
            model,
            None,  # net_g ignored
            sid,
            audio,
            pitch,
            f0_method,
            file_index,
            index_rate,
            pitch_guidance,
            volume_envelope,
            version,
            protect,
            f0_autotune,
            f0_autotune_strength,
            proposed_pitch,
            proposed_pitch_threshold,
        )

rvc/lib/mlx/__init__.py

Whitespace-only changes.
