|
#!/usr/bin/env python3
# Convert Parakeet TDT model from NeMo format to ggml format
#
# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --out-dir output-dir [--use-f32] [--out-name NAME]
#
# The NeMo file is a tar archive containing:
# - model_weights.ckpt (PyTorch checkpoint)
# - model_config.yaml (model configuration)
# - tokenizer files (BPE tokenizer)
#
# This script extracts the NeMo archive, loads the model weights and configuration,
# and saves them in ggml format compatible with whisper.cpp.
#
| 14 | + |
| 15 | +import torch |
| 16 | +import argparse |
| 17 | +import io |
| 18 | +import os |
| 19 | +import sys |
| 20 | +import struct |
| 21 | +import tarfile |
| 22 | +import tempfile |
| 23 | +import shutil |
| 24 | +import yaml |
| 25 | +import numpy as np |
| 26 | +from pathlib import Path |
| 27 | +from typing import Optional |
| 28 | + |
def hz_to_mel(freq):
    """Convert a frequency in Hz to the mel scale.

    Uses the formula ``mel = 2595 * log10(1 + f / 700)``.

    Args:
        freq: Frequency in Hz; a scalar or a NumPy array (applied elementwise).

    Returns:
        The corresponding mel value(s) as a NumPy scalar or array.
    """
    return 2595.0 * np.log10(1.0 + freq / 700.0)
| 32 | + |
def mel_to_hz(mel):
    """Convert a mel-scale value back to a frequency in Hz.

    Inverse of ``mel = 2595 * log10(1 + f / 700)``.

    Args:
        mel: Mel value; a scalar or a NumPy array (applied elementwise).

    Returns:
        The corresponding frequency/frequencies in Hz.
    """
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
| 36 | + |
def create_mel_filterbank(
    sample_rate: int = 16000,
    n_fft: int = 512,
    n_mels: int = 128,
    fmin: float = 0.0,
    fmax: Optional[float] = None
) -> np.ndarray:
    """
    Create a triangular mel filterbank matrix.

    Filter edges are spaced evenly on the mel scale (via ``hz_to_mel`` /
    ``mel_to_hz``), snapped down to FFT bin indices, and each filter is
    scaled by ``2 / (f_right - f_left)`` (area normalisation).

    NOTE(review): an earlier docstring described this as compatible with
    Whisper's implementation; that has not been verified here.

    Args:
        sample_rate: Audio sample rate (Hz)
        n_fft: FFT size
        n_mels: Number of mel bands
        fmin: Minimum frequency (Hz)
        fmax: Maximum frequency (Hz), defaults to sample_rate/2

    Returns:
        Mel filterbank matrix of shape (n_mels, n_fft//2 + 1), dtype float32

    Raises:
        ValueError: if the band limits do not satisfy
            ``0 <= fmin < fmax <= sample_rate / 2``.
    """
    if fmax is None:
        fmax = float(sample_rate / 2)

    # Reject limits that would otherwise index past the last FFT bin
    # (fmax above Nyquist) or divide by zero in the normalisation (fmin == fmax).
    if not 0.0 <= fmin < fmax <= sample_rate / 2:
        raise ValueError(
            f"Invalid band limits: need 0 <= fmin < fmax <= sample_rate/2, "
            f"got fmin={fmin}, fmax={fmax}, sample_rate={sample_rate}"
        )

    # Number of FFT frequency bins
    n_freqs = n_fft // 2 + 1

    # n_mels filters need n_mels + 2 edge points, equally spaced in mel
    mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    hz_points = mel_to_hz(mel_points)

    # Convert Hz to FFT bin indices
    bin_points = np.floor((n_fft + 1) * hz_points / sample_rate).astype(int)

    filterbank = np.zeros((n_mels, n_freqs))

    for m in range(n_mels):
        # Left, center, right bins for this filter
        left, center, right = bin_points[m:m + 3]

        # Rising slope over [left, center); empty when the bins coincide
        if center > left:
            rising = np.arange(left, center)
            filterbank[m, rising] = (rising - left) / (center - left)

        # Falling slope over [center, right); empty when the bins coincide
        if right > center:
            falling = np.arange(center, right)
            filterbank[m, falling] = (right - falling) / (right - center)

    # Scale each filter by 2 / bandwidth (in Hz) so wider filters are lower
    enorm = 2.0 / (hz_points[2:n_mels + 2] - hz_points[:n_mels])
    filterbank *= enorm[:, np.newaxis]

    return filterbank.astype(np.float32)
| 101 | + |
def extract_nemo_archive(nemo_path, extract_dir):
    """Extract a .nemo tar archive into ``extract_dir``.

    The archive is treated as untrusted: members that would land outside
    ``extract_dir`` (absolute paths, ``..`` components, escaping links) are
    refused instead of being written.

    Args:
        nemo_path: Path to the .nemo archive (any tar variant ``tarfile`` can read).
        extract_dir: Existing directory that receives the archive contents.

    Raises:
        tarfile.TarError: if the archive is unreadable or contains a member
            that would escape ``extract_dir``.
    """
    print(f"Extracting {nemo_path} to {extract_dir}")
    with tarfile.open(nemo_path, 'r') as tar:
        if hasattr(tarfile, 'data_filter'):
            # Extraction filters are available: let tarfile enforce containment.
            tar.extractall(path=extract_dir, filter='data')
        else:
            # Older interpreter without extraction filters: check every member
            # (and link target) ourselves before extracting anything.
            root = os.path.realpath(extract_dir)

            def is_inside_root(candidate):
                resolved = os.path.realpath(candidate)
                try:
                    return os.path.commonpath([root, resolved]) == root
                except ValueError:
                    # e.g. different drives on Windows
                    return False

            for member in tar.getmembers():
                target = os.path.join(root, member.name)
                if not is_inside_root(target):
                    raise tarfile.TarError(
                        f"Refusing to extract {member.name!r} outside {extract_dir}"
                    )
                if member.issym():
                    # Symlink targets are relative to the link's own directory
                    link_target = os.path.join(os.path.dirname(target), member.linkname)
                elif member.islnk():
                    # Hard-link targets are relative to the archive root
                    link_target = os.path.join(root, member.linkname)
                else:
                    continue
                if not is_inside_root(link_target):
                    raise tarfile.TarError(
                        f"Refusing to extract link {member.name!r} pointing outside {extract_dir}"
                    )
            tar.extractall(path=extract_dir)
    print("Extraction complete")
| 108 | + |
def load_model_config(config_path):
    """Load the model configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration as a dict.

    Raises:
        ValueError: if the file is empty or its top level is not a mapping,
            so callers fail here rather than on the first key lookup.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's default encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Expected a YAML mapping in {config_path}, got {type(config).__name__}"
        )
    return config
| 114 | + |
def load_tokenizer(extract_dir, config):
    """Load the tokenizer vocabulary from an extracted NeMo archive.

    Looks in ``extract_dir`` for a ``*_tokenizer.model`` file (required to be
    present, but not parsed) and a ``*tokenizer.vocab`` file, then reads the
    vocab file where each line is ``token<TAB>score`` and the line number is
    the token id.

    Args:
        extract_dir: Directory holding the extracted archive contents.
        config: Parsed model configuration. When it is a dict carrying
            ``decoder.vocab_size`` that value is the expected token count;
            otherwise 8192 is assumed. Only used for a sanity warning.

    Returns:
        dict mapping UTF-8 encoded token bytes to the token's line index.

    Raises:
        FileNotFoundError: if either tokenizer file is missing.
    """
    tokenizer_model_path = None
    tokenizer_vocab_path = None

    # Sorted so the choice is deterministic if several candidates exist
    # (os.listdir order is arbitrary).
    for file in sorted(os.listdir(extract_dir)):
        if file.endswith('_tokenizer.model'):
            tokenizer_model_path = os.path.join(extract_dir, file)
        elif file.endswith('tokenizer.vocab'):
            tokenizer_vocab_path = os.path.join(extract_dir, file)

    if not tokenizer_model_path:
        raise FileNotFoundError("Tokenizer model file not found")

    if not tokenizer_vocab_path:
        raise FileNotFoundError("Tokenizer vocab file not found")

    tokens = {}
    with open(tokenizer_vocab_path, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f):
            # Strip only the line terminator: a plain strip() would also eat
            # leading whitespace/tabs and corrupt whitespace tokens.
            token = line.rstrip('\r\n').split('\t')[0]
            tokens[token.encode('utf-8')] = idx

    print(f"Loaded {len(tokens)} tokens from {os.path.basename(tokenizer_vocab_path)}")

    # Prefer the vocabulary size declared by the model; fall back to 8192.
    expected_count = 8192
    if isinstance(config, dict):
        decoder_cfg = config.get('decoder')
        if isinstance(decoder_cfg, dict) and isinstance(decoder_cfg.get('vocab_size'), int):
            expected_count = decoder_cfg['vocab_size']

    if len(tokens) != expected_count:
        print(f"WARNING: Expected {expected_count} tokens, got {len(tokens)}")

    return tokens
| 150 | + |
def _first_matching_key(state_dict, predicate):
    """Return the first key of ``state_dict`` accepted by ``predicate``, else None."""
    return next((key for key in state_dict if predicate(key)), None)


def _tensor_to_numpy(tensor):
    """Convert a torch tensor to a NumPy array, tolerating grad and bfloat16."""
    tensor = tensor.detach()
    if tensor.dtype == torch.bfloat16:
        # NumPy has no bfloat16; widen to float32 first.
        tensor = tensor.float()
    return tensor.numpy()


def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None):
    """Convert a Parakeet .nemo archive into a single ggml binary file.

    Output layout, in order (all integers are native-endian 32-bit):
      1. magic 0x67676d6c
      2. hyperparameters (n_vocab ... n_pred_layers, with the f16 flag after n_mels)
      3. mel filterbank: n_mels, n_freqs, then n_mels*n_freqs float32 values (row-major)
      4. window function: n_window, then n_window float32 values
      5. vocabulary: token count, then (byte length, bytes) per token in id order
      6. tensors: (n_dims, name length, ftype), dims in reversed order, name, raw data

    Args:
        nemo_path: Path to the .nemo archive.
        output_dir: Directory for the output file (created if missing).
        use_f16: Store tensors with >= 2 dims as float16, except those whose
            name contains 'bias' or 'norm'; everything else is float32.
        out_name: Output file name; defaults to ``ggml-model.bin`` (f16) or
            ``ggml-model-f32.bin``.

    Raises:
        ValueError: if the mel filterbank or window tensor is missing or has
            an unexpected rank. This is checked before the output file is
            created so a failed run does not leave a truncated file.
    """
    nemo_path = Path(nemo_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Everything that reads from the extracted archive happens inside this block
    with tempfile.TemporaryDirectory() as temp_dir:
        extract_nemo_archive(nemo_path, temp_dir)

        config_path = os.path.join(temp_dir, 'model_config.yaml')
        config = load_model_config(config_path)

        print("Model configuration:")
        print(f" Sample rate: {config['sample_rate']}")
        print(f" Encoder layers: {config['encoder']['n_layers']}")
        print(f" Encoder d_model: {config['encoder']['d_model']}")
        print(f" Mel features: {config['preprocessor']['features']}")

        weights_path = os.path.join(temp_dir, 'model_weights.ckpt')
        print(f"\nLoading model weights from {weights_path}")
        # NOTE: torch.load unpickles the checkpoint; only convert archives
        # from sources you trust.
        checkpoint = torch.load(weights_path, map_location='cpu')

        # The checkpoint is either a bare state dict or wraps one under 'state_dict'
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

        print(f"Loaded {len(state_dict)} tensors")

        # Load tokenizer
        print("\nLoading tokenizer...")
        tokens = load_tokenizer(temp_dir, config)
        print(f"Loaded {len(tokens)} tokens")

        # Map Parakeet config to a Whisper-like hyperparameter structure
        hparams = {
            'n_audio_ctx': 1500,  # Parakeet uses different context, approximate
            'n_audio_state': config['encoder']['d_model'],
            'n_audio_head': config['encoder']['n_heads'],
            'n_audio_layer': config['encoder']['n_layers'],
            'n_text_ctx': 448,  # Placeholder - Parakeet TDT doesn't have decoder
            'n_text_state': config.get('model_defaults', {}).get('pred_hidden', 640),
            'n_text_head': 8,  # Placeholder
            'n_text_layer': 0,  # No text decoder layers
            'n_mels': config['preprocessor']['features'],
            'n_fft': config['preprocessor']['n_fft'],
            'subsampling_factor': config['encoder']['subsampling_factor'],
            'n_subsampling_channels': config['encoder']['subsampling_conv_channels'],
            'n_pos_max_len': config['encoder']['pos_emb_max_len'],

            'n_pred_dim': config['decoder']['prednet']['pred_hidden'],
            'n_pred_layers': config['decoder']['prednet']['pred_rnn_layers'],
            'n_vocab': config['decoder']['vocab_size'],
        }

        print("\nGGML hyperparameters:")
        for key, value in hparams.items():
            print(f" {key}: {value}")

        # Locate and validate the preprocessing tensors BEFORE creating the
        # output file, so a bad checkpoint cannot leave a half-written .bin.
        fb_key = _first_matching_key(
            state_dict,
            lambda key: 'featurizer.fb' in key or 'filterbank' in key.lower(),
        )
        if not fb_key:
            print("\nERROR: Mel filterbank not found in model!")
            print("Expected tensor with 'featurizer.fb' or 'filterbank' in name")
            print("\nAvailable preprocessor tensors:")
            for key in sorted(state_dict.keys()):
                if 'preprocessor' in key or 'featurizer' in key:
                    print(f" {key}: {state_dict[key].shape}")
            raise ValueError("Mel filterbank tensor not found in model")

        print(f"\nUsing model's mel filterbank from: {fb_key}")
        mel_filters = _tensor_to_numpy(state_dict[fb_key].squeeze()).astype(np.float32)
        print(f" Filterbank shape: {mel_filters.shape}")

        if len(mel_filters.shape) != 2:
            raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}")

        # Parentheses make the existing 'and'-before-'or' grouping explicit
        window_key = _first_matching_key(
            state_dict,
            lambda key: 'featurizer.window' in key or ('preproc' in key and 'window' in key),
        )
        if not window_key:
            print("\nERROR: Window function not found in model!")
            print("Expected tensor with 'featurizer.window' in name")
            raise ValueError("Window function tensor not found in model")

        print(f"\nUsing model's window function from: {window_key}")
        window = _tensor_to_numpy(state_dict[window_key].squeeze()).astype(np.float32)
        print(f" Window shape: {window.shape}")

        if len(window.shape) != 1:
            raise ValueError(f"Expected 1D window, got shape {window.shape}")

        # Choose the output file
        if out_name:
            fname_out = output_dir / out_name
        else:
            fname_out = output_dir / ("ggml-model-f32.bin" if not use_f16 else "ggml-model.bin")
        print(f"\nWriting to {fname_out}")

        # Header integers in the exact order they appear in the file
        header_values = [
            0x67676d6c,  # magic: 'ggml' in hex
            hparams['n_vocab'],
            hparams['n_audio_ctx'],
            hparams['n_audio_state'],
            hparams['n_audio_head'],
            hparams['n_audio_layer'],
            hparams['n_text_ctx'],
            hparams['n_text_state'],
            hparams['n_text_head'],
            hparams['n_text_layer'],
            hparams['n_mels'],
            1 if use_f16 else 0,
            hparams['n_fft'],
            hparams['subsampling_factor'],
            hparams['n_subsampling_channels'],
            hparams['n_pos_max_len'],
            hparams['n_pred_dim'],
            hparams['n_pred_layers'],
        ]

        with open(fname_out, 'wb') as fout:
            for value in header_values:
                fout.write(struct.pack("i", value))

            # Mel filterbank: dimensions, then all values row-major as float32.
            # tobytes() emits the same native-endian bytes as packing each
            # element with struct "f", in one write.
            n_mels, n_freqs = mel_filters.shape
            fout.write(struct.pack("i", n_mels))   # n_mel
            fout.write(struct.pack("i", n_freqs))  # n_fb (frequency bins)
            fout.write(mel_filters.tobytes())

            # Window function: length, then float32 values
            n_window = window.shape[0]
            fout.write(struct.pack("i", n_window))
            fout.write(window.tobytes())

            # Vocabulary, ordered by token id
            fout.write(struct.pack("i", len(tokens)))
            for token_bytes, idx in sorted(tokens.items(), key=lambda x: x[1]):
                fout.write(struct.pack("i", len(token_bytes)))
                fout.write(token_bytes)

            print("\nConverting model weights...")
            for name, tensor in state_dict.items():
                # Skip the filterbank and window - already written in preprocessing section
                if name == fb_key:
                    print(f"Skipping {name} (already written as mel filterbank)")
                    continue
                if name == window_key:
                    print(f"Skipping {name} (already written as window function)")
                    continue

                # Don't squeeze Conv2d weights - they need to preserve all 4 dimensions
                if 'conv' in name and 'weight' in name and len(tensor.shape) == 4:
                    data = _tensor_to_numpy(tensor)
                else:
                    data = _tensor_to_numpy(tensor.squeeze())

                # Reshape Conv2d bias from [out_channels] to [1, out_channels, 1, 1] for broadcasting
                # This will be written reversed as [1, 1, out_channels, 1] in the file
                # which matches ggml conv2d output layout [W, H, C, batch]
                if 'pre_encode.conv' in name and 'bias' in name and len(data.shape) == 1:
                    data = data.reshape(1, -1, 1, 1)
                    print(f" Reshaped conv bias {name} to {data.shape}")

                n_dims = len(data.shape)

                ftype = 1 if use_f16 else 0
                if use_f16:
                    # Keep some tensors in f32 for better accuracy
                    if n_dims < 2 or 'bias' in name or 'norm' in name:
                        data = data.astype(np.float32)
                        ftype = 0
                    else:
                        data = data.astype(np.float16)
                else:
                    data = data.astype(np.float32)

                # Dimensions are stored innermost-first (reverse of the NumPy shape)
                dims_reversed = [data.shape[n_dims - 1 - i] for i in range(n_dims)]
                print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {n_dims}, reversed: {dims_reversed}")
                name_bytes = name.encode('utf-8')
                fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
                for dim in dims_reversed:
                    fout.write(struct.pack("i", dim))
                fout.write(name_bytes)

                data.tofile(fout)

        print("\nConversion complete!")
        print(f"Output file: {fname_out}")
        print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB")
| 354 | + |
# Command-line entry point: parse arguments, check the input exists, convert.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Convert Parakeet TDT model from NeMo format to ggml format'
    )
    parser.add_argument('--model', type=str, required=True,
                        help='Path to Parakeet .nemo model file')
    parser.add_argument('--out-dir', type=str, required=True,
                        help='Directory to write ggml model file')
    parser.add_argument('--use-f32', action='store_true', default=False,
                        help='Use f32 instead of f16 (default: f16)')
    parser.add_argument('--out-name', type=str, default=None,
                        help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)')

    args = parser.parse_args()

    # Fail early with a readable message instead of a tarfile traceback
    if not os.path.exists(args.model):
        print(f"Error: {args.model} not found")
        sys.exit(1)

    # The CLI flag is opt-in f32; the converter takes the inverse (f16 by default)
    use_f16 = not args.use_f32
    convert_parakeet_to_ggml(args.model, args.out_dir, use_f16, args.out_name)
0 commit comments