@@ -444,48 +444,109 @@ def __init__(self, model_path=None, weights_path=None, device=None):
444444 self .cents_mapping = 20 * np .arange (N_CLASS ) + 1997.3794084376191
445445 self .cents_mapping = np .pad (self .cents_mapping , (4 , 4 ))
446446
def _create_mel_filterbank(self, n_fft, n_mels, sr, fmin, fmax):
    """Build a triangular mel filterbank on the HTK mel scale.

    Args:
        n_fft: FFT size; the filterbank spans the n_fft // 2 + 1 rFFT bins.
        n_mels: number of mel bands (filterbank rows).
        sr: sample rate in Hz.
        fmin: lowest band edge in Hz.
        fmax: highest band edge in Hz.

    Returns:
        mx.array of shape (n_mels, n_fft // 2 + 1), float32.
    """
    # HTK-style Hz <-> mel conversions.
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    def to_hz(mel):
        return 700 * (10 ** (mel / 2595) - 1)

    # n_mels + 2 band edges, evenly spaced on the mel scale, back in Hz.
    edges_hz = to_hz(np.linspace(to_mel(fmin), to_mel(fmax), n_mels + 2))

    # Center frequency of each rFFT bin.
    bins = np.fft.rfftfreq(n_fft, 1.0 / sr)

    weights = np.zeros((n_mels, len(bins)))

    # Each band is a triangle: rising from edges_hz[i] to edges_hz[i+1],
    # falling from edges_hz[i+1] to edges_hz[i+2].
    for band, (lo, mid, hi) in enumerate(zip(edges_hz, edges_hz[1:], edges_hz[2:])):
        # Rising edge: 0 at lo -> ~1 at mid (epsilon guards a zero-width band).
        rising = (bins >= lo) & (bins <= mid)
        weights[band, rising] = (bins[rising] - lo) / (mid - lo + 1e-10)

        # Falling edge: ~1 at mid -> 0 at hi (overwrites the shared mid bin).
        falling = (bins >= mid) & (bins <= hi)
        weights[band, falling] = (hi - bins[falling]) / (hi - mid + 1e-10)

    return mx.array(weights.astype(np.float32))
482+
def _create_window(self, win_length):
    """Return a periodic Hann analysis window as a float32 mx.array.

    Uses the periodic form (denominator win_length, not win_length - 1),
    which matches torch.hann_window's default used for STFT analysis.
    """
    samples = np.arange(win_length)
    hann = 0.5 - 0.5 * np.cos(2 * np.pi * samples / win_length)
    return mx.array(hann.astype(np.float32))
488+
def mel_spectrogram(self, audio):
    """GPU-accelerated log-mel spectrogram using MLX FFT.

    Mirrors the reference PyTorch RMVPE front end: Hann-windowed STFT with
    center=True (reflect padding), linear *magnitude* (not power), mel
    projection, then log with a 1e-5 floor.

    Args:
        audio: 1-D numpy array of samples at 16 kHz.

    Returns:
        numpy array of shape (n_mels, num_frames) with log-mel values.

    Raises:
        ValueError: if `audio` has fewer than 2 samples (reflect padding
            needs at least one interior value to mirror).
    """
    # Parameters matching RMVPE.
    n_fft = 1024
    hop_length = 160
    win_length = 1024
    n_mels = 128
    sr = 16000
    fmin = 30
    fmax = 8000

    # np.pad(mode='reflect') raises an opaque error for empty/1-sample
    # input; fail early with a clear message instead.
    if len(audio) < 2:
        raise ValueError("audio must contain at least 2 samples")

    # Lazily build and cache the mel filterbank and analysis window.
    if not hasattr(self, '_mel_filterbank'):
        self._mel_filterbank = self._create_mel_filterbank(n_fft, n_mels, sr, fmin, fmax)
    if not hasattr(self, '_window'):
        self._window = self._create_window(win_length)

    # center=True: pad n_fft // 2 on each side with reflection.
    pad_len = n_fft // 2
    audio_padded = np.pad(audio, (pad_len, pad_len), mode='reflect')

    audio_mx = mx.array(audio_padded.astype(np.float32))

    # Frame the signal: one row per hop-aligned window of n_fft samples.
    # With the padding above this yields 1 + len(audio) // hop_length frames,
    # matching torch.stft(center=True).
    num_frames = 1 + (len(audio_padded) - n_fft) // hop_length
    frame_starts = np.arange(num_frames) * hop_length
    frame_indices = frame_starts[:, None] + np.arange(n_fft)  # (num_frames, n_fft)

    # Gather frames; index converted explicitly so the MLX take uses an
    # mx integer array rather than relying on implicit numpy conversion.
    frames = audio_mx[mx.array(frame_indices)]  # (num_frames, n_fft)

    # Window, rFFT, linear magnitude (reference uses magnitude, not power).
    frames = frames * self._window
    spectrum = mx.fft.rfft(frames, axis=-1)   # (num_frames, n_fft//2 + 1)
    magnitude = mx.abs(spectrum)

    # (n_mels, n_fft//2+1) @ (n_fft//2+1, num_frames) -> (n_mels, num_frames)
    mel = self._mel_filterbank @ magnitude.T

    # Log with floor, mirroring torch.log(torch.clamp(mel, min=1e-5)).
    log_mel = mx.log(mx.maximum(mel, 1e-5))

    # Force evaluation of the lazy MLX graph before converting to numpy.
    mx.eval(log_mel)
    return np.array(log_mel)
489550
490551 def mel2hidden (self , mel , chunk_size = 32000 ):
491552 # mel: (n_mels, T)
0 commit comments