|
| 1 | +import os |
| 2 | +import sys |
| 3 | +import torch |
| 4 | +import numpy as np |
| 5 | +import mlx.core as mx |
| 6 | +from torch.nn import functional as F |
| 7 | + |
| 8 | +from rvc.infer.infer import VoiceConverter as VoiceConverterTorch |
| 9 | +from rvc.infer.pipeline import Pipeline |
| 10 | +from rvc.lib.mlx.synthesizers import Synthesizer |
| 11 | +from rvc.lib.mlx.convert import convert_weights |
| 12 | + |
class PipelineMLX(Pipeline):
    """Pipeline variant that synthesizes audio with an MLX model.

    Feature extraction (HuBERT) and pitch handling stay in Torch, exactly as
    in the parent ``Pipeline``; only the final synthesis step is delegated to
    an MLX ``Synthesizer`` instead of the Torch ``net_g``.
    """

    def __init__(self, tgt_sr, config, mlx_model):
        super().__init__(tgt_sr, config)
        # MLX Synthesizer used in place of the Torch generator.
        self.mlx_model = mlx_model

    def voice_conversion(
        self,
        model,
        net_g,
        sid,
        audio0,
        pitch,
        pitchf,
        index,
        big_npy,
        index_rate,
        version,
        protect,
    ):
        """MLX override of ``voice_conversion``.

        Uses Torch for HuBERT feature extraction, then MLX for synthesis.
        ``net_g`` (the Torch generator) is accepted for signature
        compatibility with the parent pipeline but is ignored; synthesis
        runs on ``self.mlx_model``.

        Returns the converted audio as a 1-D numpy array.
        """
        with torch.no_grad():
            pitch_guidance = pitch is not None and pitchf is not None

            feats = torch.from_numpy(audio0).float()
            # Downmix stereo to mono if needed; HuBERT expects (1, T).
            feats = feats.mean(-1) if feats.dim() == 2 else feats
            assert feats.dim() == 1, feats.dim()
            feats = feats.view(1, -1).to(self.device)

            # HuBERT feature extraction (Torch). v1 checkpoints project the
            # hidden state through final_proj; v2 uses it directly.
            feats = model(feats)["last_hidden_state"]
            feats = (
                model.final_proj(feats[0]).unsqueeze(0) if version == "v1" else feats
            )

            # Keep a pre-index copy for the "protect" blend below.
            feats0 = feats.clone() if pitch_guidance else None

            if index is not None and index_rate > 0:
                feats = self._retrieve_speaker_embeddings(
                    feats, index, big_npy, index_rate
                )

            # Upsample features 2x along time to match the model frame rate.
            feats = F.interpolate(
                feats.permute(0, 2, 1), scale_factor=2
            ).permute(0, 2, 1)

            p_len = min(audio0.shape[0] // self.window, feats.shape[1])

            if pitch_guidance:
                feats0 = F.interpolate(
                    feats0.permute(0, 2, 1), scale_factor=2
                ).permute(0, 2, 1)
                pitch, pitchf = pitch[:, :p_len], pitchf[:, :p_len]

                if protect < 0.5:
                    # Blend indexed features back toward the originals on
                    # unvoiced frames (pitchf == 0 keeps weight `protect`).
                    pitchff = pitchf.clone()
                    pitchff[pitchf > 0] = 1
                    pitchff[pitchf < 1] = protect
                    feats = feats * pitchff.unsqueeze(-1) + feats0 * (
                        1 - pitchff.unsqueeze(-1)
                    )
                    feats = feats.to(feats0.dtype)
            else:
                pitch, pitchf = None, None

            # --- MLX inference ---
            # Torch tensors -> MLX arrays. MLX layers are channels-last, so
            # feats keeps its (1, L, C) layout unchanged.
            feats_mx = mx.array(feats.cpu().numpy())
            pitch_mx = mx.array(pitch.cpu().numpy()) if pitch is not None else None
            pitchf_mx = mx.array(pitchf.cpu().numpy()) if pitchf is not None else None
            sid_mx = mx.array(sid.cpu().numpy())

            # MLX port signature: infer(phone, phone_lengths, pitch, nsff0, sid).
            # NOTE: the extra leading ``None`` argument from the draft has been
            # removed — phone_lengths is the second positional argument.
            out_audio, _, _ = self.mlx_model.infer(
                feats_mx,
                mx.array([p_len]),
                pitch_mx,
                pitchf_mx,
                sid_mx,
            )

            # MLX Conv1d output is (N, L, C) with C == 1 -> squeeze to (T,).
            audio1 = np.array(out_audio[0, :, 0])

        del feats, feats0
        return audio1
| 121 | + |
| 122 | + |
class VoiceConverterMLX(VoiceConverterTorch):
    """VoiceConverter that swaps the Torch generator for an MLX Synthesizer.

    The Torch side (HuBERT, f0 extraction, index lookup) is reused from the
    parent; ``get_vc`` additionally builds an MLX model from the same
    checkpoint and ``pipeline`` routes inference through ``PipelineMLX``.
    """

    # Torch ModuleList attribute names -> flat attribute prefixes used by
    # the MLX port (e.g. ``enc_p.attn_layers.0`` -> ``enc_p.attn_0``).
    _LIST_RENAMES = {
        "attn_layers": "attn",
        "norm_layers_1": "norm1",
        "ffn_layers": "ffn",
        "norm_layers_2": "norm2",
        "ups": "up",
        "noise_convs": "noise_conv",
        "resblocks": "resblock",
        "flows": "flow",
        "in_layers": "in_layer",
        "res_skip_layers": "res_skip_layer",
    }

    def __init__(self):
        super().__init__()
        # Lazily populated by get_vc().
        self.mlx_model = None
        self.mlx_model_path = None
        self.mlx_pipeline = None

    @classmethod
    def _rename_mlx_key(cls, key):
        """Map a Torch state-dict key to the MLX port's attribute naming.

        Torch stores ModuleList entries as ``name.<idx>.…``; the MLX port
        stores them as flat attributes ``name_<idx>`` (some lists are also
        renamed, see ``_LIST_RENAMES``). Any other ``name.<digit>`` pair
        falls back to ``name_<digit>``; digits not preceded by a name part
        are kept as-is.
        """
        parts = key.split(".")
        out = []
        i = 0
        while i < len(parts):
            part = parts[i]
            if not part.isdigit() and i + 1 < len(parts) and parts[i + 1].isdigit():
                base = cls._LIST_RENAMES.get(part, part)
                out.append(f"{base}_{parts[i + 1]}")
                i += 2  # consume the index part too
            else:
                out.append(part)
                i += 1
        return ".".join(out)

    def get_vc(self, weight_root, sid):
        """Load the Torch model via the parent, then (re)build the MLX model.

        The MLX model is rebuilt only when the checkpoint path changes.
        """
        super().get_vc(weight_root, sid)

        if not self.mlx_model or self.mlx_model_path != weight_root:
            print(f"Loading MLX model: {weight_root}")
            weights, config = convert_weights(weight_root)

            renamed_weights = {
                self._rename_mlx_key(k): v for k, v in weights.items()
            }

            self.mlx_model = Synthesizer(
                *config,
                use_f0=self.use_f0,
                text_enc_hidden_dim=self.text_enc_hidden_dim,
                vocoder=self.vocoder,
            )
            # load_weights accepts a list of (name, array) pairs with dotted
            # keys; Module.update() would need a nested parameter tree, so it
            # is the wrong API for this flat dict.
            self.mlx_model.load_weights(list(renamed_weights.items()))
            # Force materialization so loading cost is paid here, not on the
            # first inference call.
            mx.eval(self.mlx_model.parameters())

            self.mlx_model_path = weight_root

            # Inject the MLX model into a pipeline instance.
            self.mlx_pipeline = PipelineMLX(self.tgt_sr, self.config, self.mlx_model)

    def pipeline(self, model, net_g, sid, audio, pitch, f0_method, file_index, index_rate, pitch_guidance, volume_envelope, version, protect, f0_autotune, f0_autotune_strength, proposed_pitch, proposed_pitch_threshold):
        """Delegate to the MLX pipeline instead of the Torch one.

        ``model`` is the HuBERT feature extractor (still used, in Torch);
        ``net_g`` is the Torch generator and is ignored — synthesis runs on
        ``self.mlx_model`` inside ``PipelineMLX``.
        """
        return self.mlx_pipeline.pipeline(
            model,
            None,  # net_g ignored; MLX model is used instead
            sid,
            audio,
            pitch,
            f0_method,
            file_index,
            index_rate,
            pitch_guidance,
            volume_envelope,
            version,
            protect,
            f0_autotune,
            f0_autotune_strength,
            proposed_pitch,
            proposed_pitch_threshold,
        )
0 commit comments