@@ -13,8 +13,7 @@ import MLXNN
1313 public var onLog : ( ( String ) -> Void ) ?
1414
1515 var hubertModel : HubertModel ?
16- var featureProjection : MLXNN . Linear ? // 768 -> 192
17- var generator : Generator ?
16+ var synthesizer : Synthesizer ?
1817 var rmvpe : RMVPE ?
1918
2019 private func log( _ message: String ) {
@@ -46,52 +45,53 @@ import MLXNN
4645 }
4746 self . hubertModel? . update ( parameters: ModuleParameters . unflattened ( newParams) )
4847
49- // 2. Load Generator (simplified approach - just Generator + feature projection )
50- log ( " RVCInference: Loading Generator from \( modelURL. lastPathComponent) " )
48+ // 2. Load Synthesizer (TextEncoder + Flow + Generator )
49+ log ( " RVCInference: Loading Synthesizer from \( modelURL. lastPathComponent) " )
5150 let modelWeights = try MLX . loadArrays ( url: modelURL)
51+
52+ // Note: The Python conversion script already transposes all weights to MLX format
53+ // No additional transposition needed here!
54+ log ( " RVCInference: Loaded \( modelWeights. count) weights (already in MLX format) " )
5255
53- // Auto-transpose Conv1d weights from PyTorch to MLX format
54- var transposedWeights : [ String : MLXArray ] = [ : ]
55- var transposedCount = 0
56- for (key, value) in modelWeights {
57- if key. contains ( " .weight " ) && value. ndim == 3 {
58- let shape = value. shape
59- if shape [ 2 ] < shape [ 1 ] {
60- transposedWeights [ key] = value. transposed ( 0 , 2 , 1 )
61- transposedCount += 1
62- } else {
63- transposedWeights [ key] = value
64- }
65- } else {
66- transposedWeights [ key] = value
67- }
68- }
69- log ( " RVCInference: Transposed \( transposedCount) Conv1d weights " )
70-
71- // Create feature projection (768 -> 192 using enc_p.emb_phone weights)
72- self . featureProjection = MLXNN . Linear ( 768 , 192 )
73- if let weight = transposedWeights [ " enc_p.emb_phone.weight " ] ,
74- let bias = transposedWeights [ " enc_p.emb_phone.bias " ] {
75- self . featureProjection? . update ( parameters: ModuleParameters . unflattened ( [
76- " weight " : weight,
77- " bias " : bias
78- ] ) )
79- log ( " RVCInference: Loaded feature projection (768->192) " )
80- }
56+ // Initialize Synthesizer (V2 defaults)
57+ self . synthesizer = Synthesizer (
58+ interChannels: 192 ,
59+ hiddenChannels: 192 ,
60+ filterChannels: 768 ,
61+ nHeads: 2 ,
62+ nLayers: 6 ,
63+ kernelSize: 3 ,
64+ pDropout: 0.0 ,
65+ embeddingDim: 768 , // Model weights expect 768-dim HuBERT features
66+ speakerEmbedDim: 256 ,
67+ ginChannels: 256 ,
68+ useF0: true
69+ )
8170
82- // Create Generator (192 input, with gin_channels for conditioning)
83- self . generator = Generator ( inputChannels: 192 , ginChannels: 256 )
71+ // Remap keys for Synthesizer
72+ // RVC V2 PyTorch Weights:
73+ // enc_p.* -> TextEncoder
74+ // dec.* -> Generator
75+ // flow.flows.0, 2, 4, 6 -> Flow (indices 0, 1, 2, 3)
76+ // emb_g.weight -> Speaker Embedding
8477
85- // Load Generator weights (dec.* prefix)
86- var genParams : [ String : MLXArray ] = [ : ]
87- for (k, v) in transposedWeights {
88- if k. hasPrefix ( " dec. " ) {
89- let newK = String ( k. dropFirst ( 4 ) )
90- genParams [ newK] = v
78+ var synthParams : [ String : MLXArray ] = [ : ]
79+ for (k, v) in modelWeights {
80+ var newK = k
81+ if k. hasPrefix ( " flow.flows. " ) {
82+ let parts = k. components ( separatedBy: " . " )
83+ if parts. count >= 3 , let oldIdx = Int ( parts [ 2 ] ) {
84+ // PyTorch indices: 0, 2, 4, 6
85+ // Swift indices: 0, 1, 2, 3
86+ let newIdx = oldIdx / 2
87+ newK = ( [ " flow " , " flows " , String ( newIdx) ] + parts. dropFirst ( 3 ) ) . joined ( separator: " . " )
88+ }
9189 }
90+ synthParams [ newK] = v
9291 }
93- self . generator? . update ( parameters: ModuleParameters . unflattened ( genParams) )
94- log ( " RVCInference: Loaded Generator with \( genParams. count) weight keys " )
92+
93+ self . synthesizer? . update ( parameters: ModuleParameters . unflattened ( synthParams) )
94+ log ( " RVCInference: Loaded Synthesizer with \( synthParams. count) weight keys " )
9595
9696 // 3. Load RMVPE (Optional)
9797 if let rmvpeURL = rmvpeURL {
@@ -223,58 +223,81 @@ import MLXNN
223223 }
224224
225225 private func inferChunk( chunk: MLXArray ) async throws -> MLXArray {
226- // chunk: [T]
226+ // chunk: [T] - 16kHz
227227
228- // 1. Hubert Feature Extraction
228+ // 1. Hubert Feature Extraction (16kHz -> 50fps)
229229 let audioInput = chunk. expandedDimensions ( axis: 0 ) // [1, T]
230230 guard let hubertModel = hubertModel else {
231231 throw NSError ( domain: " RVCInference " , code: 1 , userInfo: [ NSLocalizedDescriptionKey: " Hubert model missing " ] )
232232 }
233- let features = hubertModel ( audioInput) // [1, Frames, 768]
234- MLX . eval ( features)
235-
236- // 2. Feature Projection (768 -> 192)
237- guard let projection = featureProjection else {
238- throw NSError ( domain: " RVCInference " , code: 1 , userInfo: [ NSLocalizedDescriptionKey: " Feature projection missing " ] )
239- }
240- var projectedFeatures = projection ( features) // [1, Frames, 192]
241- projectedFeatures = leakyRelu ( projectedFeatures, negativeSlope: 0.1 )
242- MLX . eval ( projectedFeatures)
243-
244- // 3. F0 Estimation
233+ let hubertFeatures = hubertModel ( audioInput) // [1, Frames, 768]
234+ MLX . eval ( hubertFeatures)
235+ log ( " DEBUG: HuBERT output shape: \( hubertFeatures. shape) " )
236+
237+ // 2. F0 Estimation (16kHz -> 100fps)
245238 var f0 : MLXArray
246239 if let rmvpe = rmvpe {
247240 f0 = rmvpe. infer ( audio: chunk, thred: 0.03 ) // [1, Frames_rmvpe, 1]
248241 } else {
249- // Fallback: constant F0
250- let frames = features. shape [ 1 ] * 2
242+ // Fallback: constant F0 (200Hz)
243+ // RMVPE produces 100fps, Hubert 50fps. So RMVPE has 2x frames.
244+ let frames = hubertFeatures. shape [ 1 ] * 2
251245 f0 = MLX . full ( [ 1 , frames, 1 ] , values: MLXArray ( 200.0 ) )
252246 }
253247 MLX . eval ( f0)
254248
255- // 4 . Upsample Features (2x) - simple repeat
256- let N = projectedFeatures . shape [ 0 ]
257- let L = projectedFeatures . shape [ 1 ]
258- let C = projectedFeatures . shape [ 2 ]
249+ // 3 . Upsample Hubert Features to match F0 (100fps)
250+ let N = hubertFeatures . shape [ 0 ]
251+ let L = hubertFeatures . shape [ 1 ]
252+ let C = hubertFeatures . shape [ 2 ]
259253
260- let expanded = projectedFeatures. expandedDimensions ( axis: 2 )
254+ // Simple repeat upsampling: [1, L, 768] -> [1, L, 2, 768] -> [1, L*2, 768]
255+ let expanded = hubertFeatures. expandedDimensions ( axis: 2 )
261256 let broadcasted = MLX . broadcast ( expanded, to: [ N, L, 2 , C] )
262- var phone = broadcasted. reshaped ( [ N, L * 2 , C] ) // [N, L*2, 192]
257+ var phone = broadcasted. reshaped ( [ N, L * 2 , C] )
258+ log ( " DEBUG: Upsampled phone shape: \( phone. shape) " )
259+
260+ // 4. Coarse Pitch calculation (Hz -> Bucket 1-255)
261+ let f0Hz = f0. squeezed ( axes: [ 2 ] ) // [1, L_f0]
262+ let f0_min : Float = 50.0
263+ let f0_max : Float = 1100.0
264+ let f0_mel_min = 1127.0 * Darwin. log ( 1.0 + Double( f0_min) / 700.0 )
265+ let f0_mel_max = 1127.0 * Darwin. log ( 1.0 + Double( f0_max) / 700.0 )
263266
264- // 5. Sync lengths
265- let lenFeat = phone. shape [ 1 ]
266- let lenF0 = f0. shape [ 1 ]
267- let minLen = min ( lenFeat, lenF0)
267+ // MLX Mel calculation: 1127 * ln(1 + f/700)
268+ let f0_mel = 1127.0 * MLX. log ( 1.0 + f0Hz / 700.0 )
268269
269- phone = phone [ 0 ... , 0 ..< minLen, 0 ... ]
270- f0 = f0 [ 0 ... , 0 ..< minLen, 0 ... ]
270+ // Bucket quantization
271+ var pitch = ( f0_mel - f0_mel_min) * ( 254.0 / ( f0_mel_max - f0_mel_min) ) + 1.0
272+ pitch = MLX . where ( f0Hz .<= f0_min, MLXArray ( 1.0 ) , pitch)
273+ pitch = MLX . maximum ( pitch, 1.0 )
274+ pitch = MLX . minimum ( pitch, 255.0 )
275+ let pitchBuckets = pitch. asType ( Int32 . self)
271276
272- // 6. Generator
273- guard let generator = generator else {
274- throw NSError ( domain: " RVCInference " , code: 1 , userInfo: [ NSLocalizedDescriptionKey: " Generator missing " ] )
277+ // 5. Sync lengths
278+ let p_len_val = min ( phone. shape [ 1 ] , f0Hz. shape [ 1 ] )
279+ phone = phone [ 0 ... , 0 ..< p_len_val, 0 ... ]
280+ let nsff0 = f0Hz [ 0 ... , 0 ..< p_len_val] . expandedDimensions ( axis: 2 )
281+ let pitchFinal = pitchBuckets [ 0 ... , 0 ..< p_len_val]
282+ let phoneLengths = MLXArray ( [ Int32 ( p_len_val) ] )
283+
284+ // 6. Synthesizer Inference
285+ guard let synthesizer = synthesizer else {
286+ throw NSError ( domain: " RVCInference " , code: 1 , userInfo: [ NSLocalizedDescriptionKey: " Synthesizer missing " ] )
275287 }
276288
277- let audioOut = generator ( phone, f0: f0)
278- return audioOut
289+ let sid = MLXArray ( [ Int32 ( 0 ) ] ) // Default speaker 0
290+
291+ log ( " DEBUG: Before Synthesizer - phone: \( phone. shape) , pitch: \( pitchFinal. shape) , nsff0: \( nsff0. shape) " )
292+
293+ let audioOut = synthesizer. infer (
294+ phone: phone,
295+ phoneLengths: phoneLengths,
296+ pitch: pitchFinal,
297+ nsff0: nsff0,
298+ sid: sid
299+ )
300+
301+ return audioOut // [1, T_out, 1]
279302 }
280303 }
0 commit comments