Skip to content

Commit 94c2a7a

Browse files
committed
parakeet : add support for NVIDIA Parakeet (wip)
This is a work in progress to support the Parakeet model.
1 parent dc96116 commit 94c2a7a

11 files changed

Lines changed: 5666 additions & 0 deletions

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ target_compile_definitions(whisper PRIVATE
185185
WHISPER_VERSION="${PROJECT_VERSION}"
186186
)
187187

188+
target_compile_definitions(parakeet PRIVATE
189+
PARAKEET_VERSION="${PROJECT_VERSION}"
190+
)
191+
188192
configure_package_config_file(
189193
${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in
190194
${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake

include/parakeet.h

Lines changed: 455 additions & 0 deletions
Large diffs are not rendered by default.
11.8 KB
Binary file not shown.

models/convert-parakeet-to-ggml.py

Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
#!/usr/bin/env python3
2+
# Convert Parakeet TDT model from NeMo format to ggml format
3+
#
4+
# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --out-dir output-dir [--use-f32] [--out-name name.bin]
5+
#
6+
# The NeMo file is a tar archive containing:
7+
# - model_weights.ckpt (PyTorch checkpoint)
8+
# - model_config.yaml (model configuration)
9+
# - tokenizer files (BPE tokenizer)
10+
#
11+
# This script extracts the NeMo archive, loads the model weights and configuration,
12+
# and saves them in ggml format compatible with whisper.cpp.
13+
#
14+
15+
import torch
16+
import argparse
17+
import io
18+
import os
19+
import sys
20+
import struct
21+
import tarfile
22+
import tempfile
23+
import shutil
24+
import yaml
25+
import numpy as np
26+
from pathlib import Path
27+
from typing import Optional
28+
29+
def hz_to_mel(freq):
    """Map a frequency in Hz onto the mel scale.

    Applies ``2595 * log10(1 + freq / 700)``. Because the work is done by
    NumPy, ``freq`` may be a scalar or an array (converted element-wise).
    """
    ratio = 1.0 + freq / 700.0
    return 2595.0 * np.log10(ratio)
32+
33+
def mel_to_hz(mel):
    """Map a mel-scale value back to a frequency in Hz.

    Applies ``700 * (10 ** (mel / 2595) - 1)``, the inverse of the formula
    used by ``hz_to_mel``. ``mel`` may be a scalar or a NumPy array.
    """
    exponent = mel / 2595.0
    return 700.0 * (10.0**exponent - 1.0)
36+
37+
def create_mel_filterbank(
    sample_rate: int = 16000,
    n_fft: int = 512,
    n_mels: int = 128,
    fmin: float = 0.0,
    fmax: Optional[float] = None
) -> np.ndarray:
    """
    Create a triangular mel filterbank matrix.

    Filter edges are placed at equally spaced points on the mel scale
    between ``fmin`` and ``fmax`` and snapped down to FFT bin indices.
    Each filter rises linearly from its left edge to its center bin and
    falls linearly to its right edge, then is scaled by
    ``2 / (right_hz - left_hz)`` (area normalization, in the style of
    librosa's Slaney norm).

    Args:
        sample_rate: Audio sample rate (Hz)
        n_fft: FFT size
        n_mels: Number of mel bands
        fmin: Minimum frequency (Hz)
        fmax: Maximum frequency (Hz), defaults to sample_rate/2

    Returns:
        Mel filterbank matrix of shape (n_mels, n_fft//2 + 1), dtype float32

    Raises:
        ValueError: if the band does not satisfy
            ``0 <= fmin < fmax <= sample_rate / 2``.
    """
    nyquist = sample_rate / 2
    if fmax is None:
        fmax = float(nyquist)

    # Reject bands that would otherwise index outside the matrix (fmax above
    # Nyquist), wrap around via negative indices (fmin < 0) or divide by zero
    # in the normalization (fmin == fmax).
    if not 0.0 <= fmin < fmax <= nyquist:
        raise ValueError(
            f"Invalid frequency band: need 0 <= fmin < fmax <= {nyquist}, "
            f"got fmin={fmin}, fmax={fmax}"
        )

    # Number of FFT frequency bins
    n_freqs = n_fft // 2 + 1

    # Equally spaced mel points (n_mels filters need n_mels + 2 edges)
    mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    hz_points = mel_to_hz(mel_points)

    # Convert Hz to FFT bin indices
    bin_points = np.floor((n_fft + 1) * hz_points / sample_rate).astype(int)

    filterbank = np.zeros((n_mels, n_freqs))

    for m in range(n_mels):
        # Left, center, right bins for this filter
        left, center, right = bin_points[m:m + 3]

        # Rising slope over [left, center); skipped when the edges coincide
        if center > left:
            filterbank[m, left:center] = (np.arange(left, center) - left) / (center - left)

        # Falling slope over [center, right); skipped when the edges coincide
        if right > center:
            filterbank[m, center:right] = (right - np.arange(center, right)) / (right - center)

    # Scale every filter by 2 / (bandwidth in Hz)
    enorm = 2.0 / (hz_points[2:n_mels + 2] - hz_points[:n_mels])
    filterbank *= enorm[:, np.newaxis]

    return filterbank.astype(np.float32)
101+
102+
def extract_nemo_archive(nemo_path, extract_dir):
    """Extract a .nemo tar archive into ``extract_dir``.

    Args:
        nemo_path: Path to the .nemo archive.
        extract_dir: Existing directory that receives the archive contents.

    Raises:
        tarfile.TarError: if the archive cannot be read, or (on interpreters
            with extraction filters) if a member is rejected as unsafe.
    """
    print(f"Extracting {nemo_path} to {extract_dir}")
    with tarfile.open(nemo_path, 'r') as tar:
        if hasattr(tarfile, 'data_filter'):
            # The 'data' filter rejects members that would land outside
            # extract_dir (absolute paths, '..' components, escaping links).
            tar.extractall(path=extract_dir, filter='data')
        else:
            # Older interpreters without extraction filters: legacy behavior.
            tar.extractall(path=extract_dir)
    print("Extraction complete")
108+
109+
def load_model_config(config_path):
    """Load the model configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed document as returned by ``yaml.safe_load``.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
114+
115+
def load_tokenizer(extract_dir, config):
    """Load the tokenizer vocabulary from an extracted NeMo archive.

    The directory must contain a ``*_tokenizer.model`` file and a
    ``*tokenizer.vocab`` file. Only the ``.vocab`` file is read: one token
    per line, followed by a tab and a score. The token id is the zero-based
    line number.

    Args:
        extract_dir: Directory holding the extracted archive contents.
        config: Model configuration; currently unused, kept for the callers.

    Returns:
        Dict mapping each token, UTF-8 encoded to bytes, to its integer id.

    Raises:
        FileNotFoundError: if the tokenizer model or vocab file is missing.
    """
    tokenizer_model_path = None
    tokenizer_vocab_path = None

    # Sorted so the choice is deterministic if several files match.
    for file_name in sorted(os.listdir(extract_dir)):
        if file_name.endswith('_tokenizer.model'):
            tokenizer_model_path = os.path.join(extract_dir, file_name)
        elif file_name.endswith('tokenizer.vocab'):
            tokenizer_vocab_path = os.path.join(extract_dir, file_name)

    if not tokenizer_model_path:
        raise FileNotFoundError("Tokenizer model file not found")

    if not tokenizer_vocab_path:
        raise FileNotFoundError("Tokenizer vocab file not found")

    tokens = {}
    with open(tokenizer_vocab_path, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f):
            # Drop only the line terminator. A full strip() would also remove
            # a whitespace token together with the tab, leaving the score to
            # be stored as the token text.
            token = line.rstrip('\n').split('\t', 1)[0]
            tokens[token.encode('utf-8')] = idx

    print(f"Loaded {len(tokens)} tokens from {os.path.basename(tokenizer_vocab_path)}")

    if len(tokens) != 8192:
        print(f"WARNING: Expected 8192 tokens, got {len(tokens)}")

    return tokens
150+
151+
def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None):
    """Convert a Parakeet .nemo archive into a single ggml binary file.

    The archive is unpacked into a temporary directory. The output file is
    then written in this order: magic number, integer hyperparameters, mel
    filterbank, window function, vocabulary, and finally every remaining
    tensor of the checkpoint's state dict. All integers and floats are
    packed with the platform's native byte order.

    Args:
        nemo_path: Path to the .nemo archive.
        output_dir: Directory for the output file (created if missing).
        use_f16: Store tensors as float16, except tensors with fewer than
            two dimensions or with 'bias' or 'norm' in their name, which
            stay float32. When False everything is stored as float32.
        out_name: Output file name; defaults to ``ggml-model.bin`` (f16) or
            ``ggml-model-f32.bin`` (f32).

    Raises:
        ValueError: if the mel filterbank or window tensor is missing from
            the checkpoint or does not have the expected rank.
        FileNotFoundError: if the tokenizer files are missing.
    """
    nemo_path = Path(nemo_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create temporary directory for extraction
    with tempfile.TemporaryDirectory() as temp_dir:
        extract_nemo_archive(nemo_path, temp_dir)

        config_path = os.path.join(temp_dir, 'model_config.yaml')
        config = load_model_config(config_path)
        encoder_cfg = config['encoder']
        preproc_cfg = config['preprocessor']
        prednet_cfg = config['decoder']['prednet']

        print("Model configuration:")
        print(f" Sample rate: {config['sample_rate']}")
        print(f" Encoder layers: {encoder_cfg['n_layers']}")
        print(f" Encoder d_model: {encoder_cfg['d_model']}")
        print(f" Mel features: {preproc_cfg['features']}")

        weights_path = os.path.join(temp_dir, 'model_weights.ckpt')
        print(f"\nLoading model weights from {weights_path}")
        # NOTE(security): torch.load unpickles the checkpoint, which can run
        # arbitrary code. Only convert archives from a trusted source
        # (recent PyTorch versions offer weights_only=True to restrict this).
        checkpoint = torch.load(weights_path, map_location='cpu')

        # Extract state dict
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        print(f"Loaded {len(state_dict)} tensors")

        # Load tokenizer
        print("\nLoading tokenizer...")
        tokens = load_tokenizer(temp_dir, config)
        print(f"Loaded {len(tokens)} tokens")

        # Prepare hyperparameters for ggml format
        # Map Parakeet config to Whisper-like structure
        hparams = {
            'n_audio_ctx': 1500,  # Parakeet uses different context, approximate
            'n_audio_state': encoder_cfg['d_model'],
            'n_audio_head': encoder_cfg['n_heads'],
            'n_audio_layer': encoder_cfg['n_layers'],
            'n_text_ctx': 448,  # Placeholder - Parakeet TDT doesn't have decoder
            'n_text_state': config.get('model_defaults', {}).get('pred_hidden', 640),
            'n_text_head': 8,  # Placeholder
            'n_text_layer': 0,  # No text decoder layers
            'n_mels': preproc_cfg['features'],
            'n_fft': preproc_cfg['n_fft'],
            'subsampling_factor': encoder_cfg['subsampling_factor'],
            'n_subsampling_channels': encoder_cfg['subsampling_conv_channels'],
            'n_pos_max_len': encoder_cfg['pos_emb_max_len'],

            'n_pred_dim': prednet_cfg['pred_hidden'],
            'n_pred_layers': prednet_cfg['pred_rnn_layers'],
            'n_vocab': config['decoder']['vocab_size'],
        }

        print("\nGGML hyperparameters:")
        for key, value in hparams.items():
            print(f" {key}: {value}")

        # Create output file
        if out_name:
            fname_out = output_dir / out_name
        else:
            fname_out = output_dir / ("ggml-model.bin" if use_f16 else "ggml-model-f32.bin")
        print(f"\nWriting to {fname_out}")

        with open(fname_out, 'wb') as fout:

            def write_i32(value):
                """Write one native-endian 32-bit signed integer."""
                fout.write(struct.pack("i", value))

            # Write magic number
            write_i32(0x67676d6c)  # 'ggml' in hex

            # Write hyperparameters. The order below defines the file layout
            # and intentionally differs from the dict order above.
            header_values = (
                hparams['n_vocab'],
                hparams['n_audio_ctx'],
                hparams['n_audio_state'],
                hparams['n_audio_head'],
                hparams['n_audio_layer'],
                hparams['n_text_ctx'],
                hparams['n_text_state'],
                hparams['n_text_head'],
                hparams['n_text_layer'],
                hparams['n_mels'],
                1 if use_f16 else 0,  # file-level ftype flag
                hparams['n_fft'],
                hparams['subsampling_factor'],
                hparams['n_subsampling_channels'],
                hparams['n_pos_max_len'],
                hparams['n_pred_dim'],
                hparams['n_pred_layers'],
            )
            for value in header_values:
                write_i32(value)

            # Extract mel filterbank from model (first matching tensor name)
            fb_key = next(
                (key for key in state_dict
                 if 'featurizer.fb' in key or 'filterbank' in key.lower()),
                None,
            )

            if not fb_key:
                print("\nERROR: Mel filterbank not found in model!")
                print("Expected tensor with 'featurizer.fb' or 'filterbank' in name")
                print("\nAvailable preprocessor tensors:")
                for key in sorted(state_dict.keys()):
                    if 'preprocessor' in key or 'featurizer' in key:
                        print(f" {key}: {state_dict[key].shape}")
                raise ValueError("Mel filterbank tensor not found in model")

            print(f"\nUsing model's mel filterbank from: {fb_key}")
            mel_filters = state_dict[fb_key].squeeze().numpy().astype(np.float32)
            print(f" Filterbank shape: {mel_filters.shape}")

            if len(mel_filters.shape) != 2:
                raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}")

            n_mels, n_freqs = mel_filters.shape
            write_i32(n_mels)   # n_mel
            write_i32(n_freqs)  # n_fb (frequency bins)

            # Write mel filterbank in one call: tofile emits the float32
            # values in C (row-major) order, i.e. the same bytes as packing
            # each element row by row.
            mel_filters.tofile(fout)

            # Extract window function from model (first matching tensor name).
            # Parentheses spell out the precedence of the original condition.
            window_key = next(
                (key for key in state_dict
                 if 'featurizer.window' in key
                 or ('preproc' in key and 'window' in key)),
                None,
            )

            if not window_key:
                print("\nERROR: Window function not found in model!")
                print("Expected tensor with 'featurizer.window' in name")
                raise ValueError("Window function tensor not found in model")

            print(f"\nUsing model's window function from: {window_key}")
            window = state_dict[window_key].squeeze().numpy().astype(np.float32)
            print(f" Window shape: {window.shape}")

            if len(window.shape) != 1:
                raise ValueError(f"Expected 1D window, got shape {window.shape}")

            # Write window length followed by the float32 samples
            write_i32(window.shape[0])
            window.tofile(fout)

            # Write vocabulary: count, then length-prefixed token bytes in id order
            write_i32(len(tokens))
            for token_bytes in sorted(tokens, key=tokens.get):
                write_i32(len(token_bytes))
                fout.write(token_bytes)

            print("\nConverting model weights...")
            for name, tensor in state_dict.items():
                # Skip the filterbank and window - already written in preprocessing section
                if name == fb_key:
                    print(f"Skipping {name} (already written as mel filterbank)")
                    continue
                if name == window_key:
                    print(f"Skipping {name} (already written as window function)")
                    continue

                # Don't squeeze Conv2d weights - they need to preserve all 4 dimensions
                if 'conv' in name and 'weight' in name and len(tensor.shape) == 4:
                    data = tensor.numpy()
                else:
                    data = tensor.squeeze().numpy()

                # Reshape Conv2d bias from [out_channels] to [1, out_channels, 1, 1] for broadcasting
                # This will be written reversed as [1, 1, out_channels, 1] in the file
                # which matches ggml conv2d output layout [W, H, C, batch]
                if 'pre_encode.conv' in name and 'bias' in name and len(data.shape) == 1:
                    data = data.reshape(1, -1, 1, 1)
                    print(f" Reshaped conv bias {name} to {data.shape}")

                n_dims = len(data.shape)

                # Keep some tensors in f32 for better accuracy even in f16 mode
                keep_f32 = n_dims < 2 or 'bias' in name or 'norm' in name
                if use_f16 and not keep_f32:
                    data = data.astype(np.float16)
                    ftype = 1
                else:
                    data = data.astype(np.float32)
                    ftype = 0

                # Dimensions are stored in reverse order in the file
                dims_reversed = list(reversed(data.shape))
                print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {n_dims}, reversed: {dims_reversed}")
                name_bytes = name.encode('utf-8')
                fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
                for dim in dims_reversed:
                    write_i32(dim)
                fout.write(name_bytes)

                data.tofile(fout)

    # Report after the output file is closed so the size on disk is final
    print("\nConversion complete!")
    print(f"Output file: {fname_out}")
    print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB")
354+
355+
if __name__ == '__main__':
    # Command-line entry point: parse options, check the input, run the conversion.
    cli = argparse.ArgumentParser(
        description='Convert Parakeet TDT model from NeMo format to ggml format'
    )
    cli.add_argument(
        '--model', type=str, required=True,
        help='Path to Parakeet .nemo model file',
    )
    cli.add_argument(
        '--out-dir', type=str, required=True,
        help='Directory to write ggml model file',
    )
    cli.add_argument(
        '--use-f32', action='store_true', default=False,
        help='Use f32 instead of f16 (default: f16)',
    )
    cli.add_argument(
        '--out-name', type=str, default=None,
        help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)',
    )

    options = cli.parse_args()

    # Bail out early when the input archive does not exist
    if not os.path.exists(options.model):
        print(f"Error: {options.model} not found")
        sys.exit(1)

    # f16 is the default; --use-f32 switches it off
    convert_parakeet_to_ggml(
        options.model,
        options.out_dir,
        not options.use_f32,
        options.out_name,
    )

models/requirements-parakeet.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pyyaml

0 commit comments

Comments
 (0)