Skip to content

Commit 3472132

Browse files
stephencox-ict, stephencox, and Blaizzy
authored
fix: Gemma 4 audio — mel preprocessing, weight loading, feature extractor (#931)
* fix: Gemma 4 audio — mel preprocessing, weight loading, feature extractor

  Fix four bugs preventing Gemma 4 audio from working:

  1. Missing semicausal left-padding in the audio feature extractor. The HF reference prepends frame_length//2 (160) zero samples before the unfold, centering the first frame at t=0. Without this, the mel spectrogram is misaligned and the frame count is wrong, which also causes the broadcast shapes error (issue #923).

  2. Wrong Hann window formula. The window was computed with cos(2*pi*(n+0.5)/N) instead of the cos(2*pi*n/N) term of the correct periodic Hann window. The +0.5 phase shift produces meaningfully different spectral values from what the model was trained on.

  3. sanitize() double-nests language_model weights (issue #912). HF keys like model.language_model.model.embed_tokens.weight become language_model.model.embed_tokens.weight after stripping the leading "model." prefix, which already matches the MLX path. The unconditional insertion of ".model." created language_model.model.model.*, so all LM weights loaded as zero.

  4. Feature extractor not instantiated (issue #903). It was only created when processor_config.json contains a "feature_extractor" key, which standard HF checkpoints don't include. It is now instantiated with USM defaults unconditionally.

  Fixes #903, #912, #923

* format

* format

* Update audio feature extractor in Gemma4 model to match hf

---------

Co-authored-by: Stephen Cox <stephencoxmail@gmail.com>
Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
1 parent b2cffea commit 3472132

3 files changed

Lines changed: 46 additions & 21 deletions

File tree

mlx_vlm/models/gemma4/audio_feature_extractor.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __init__(
126126
max_frequency: float = 8000.0,
127127
preemphasis: float = 0.0,
128128
preemphasis_htk_flavor: bool = True,
129-
fft_overdrive: bool = True,
129+
fft_overdrive: bool = False,
130130
dither: float = 0.0,
131131
input_scale_factor: float = 1.0,
132132
mel_floor: float = 1e-3,
@@ -153,12 +153,14 @@ def __init__(
153153
fft_length *= 2
154154
self.fft_length = fft_length
155155

156-
# Hanning window (non-zero at endpoints)
157-
arg = math.pi * 2.0 / self.frame_length
158-
window = 0.5 - (
159-
0.5 * np.cos(arg * (np.arange(self.frame_length, dtype=np.float32) + 0.5))
156+
# Periodic Hann window: w[n] = 0.5 - 0.5 * cos(2*pi*n / frame_length)
157+
# Matches HuggingFace Transformers (signal.hann_window with periodic=True)
158+
self.window = 0.5 - 0.5 * np.cos(
159+
2.0
160+
* np.pi
161+
* np.arange(self.frame_length, dtype=np.float32)
162+
/ self.frame_length
160163
)
161-
self.window = window.astype(np.float32)
162164

163165
# Mel filter bank
164166
try:
@@ -209,6 +211,14 @@ def _extract_spectrogram(
209211
if self.input_scale_factor != 1.0:
210212
waveform = waveform * self.input_scale_factor
211213

214+
# Semicausal left-padding: prepend frame_length // 2 zeros so that
215+
# the first frame is centered at t=0, matching HuggingFace Transformers
216+
pad_left = self.frame_length // 2
217+
waveform = np.pad(waveform, ((0, 0), (pad_left, 0)), mode="constant")
218+
attention_mask = np.pad(
219+
attention_mask, (pad_left, 0), mode="constant", constant_values=0
220+
)
221+
212222
frame_size_for_unfold = self.frame_length + 1
213223

214224
frames_to_process = _unfold(
@@ -239,7 +249,7 @@ def _extract_spectrogram(
239249

240250
magnitude_spec = np.abs(stft)
241251
mel_spec = np.matmul(magnitude_spec, self.mel_filters)
242-
log_mel_spec = np.log(np.maximum(mel_spec, self.mel_floor))
252+
log_mel_spec = np.log(mel_spec + self.mel_floor)
243253

244254
if self.per_bin_mean is not None:
245255
log_mel_spec = log_mel_spec - self.per_bin_mean
@@ -248,8 +258,13 @@ def _extract_spectrogram(
248258
log_mel_spec = log_mel_spec / self.per_bin_stddev
249259

250260
mel_spectrogram = log_mel_spec.squeeze(0)
251-
mask = attention_mask[:: self.hop_length].astype(bool)
252-
return mel_spectrogram, mask[: mel_spectrogram.shape[0]]
261+
num_mel_frames = mel_spectrogram.shape[0]
262+
263+
frame_end_indices = (
264+
np.arange(num_mel_frames) * self.hop_length + frame_size_for_unfold - 1
265+
)
266+
mask = attention_mask[frame_end_indices].astype(bool)
267+
return mel_spectrogram, mask
253268

254269
def _pad_waveforms(self, waveforms, max_length=None, pad_to_multiple_of=None):
255270
"""Pad a list of waveforms to equal length."""
@@ -341,6 +356,12 @@ def __call__(
341356
prepared_speech.append(spec.astype(np.float32))
342357
prepared_speech_mask.append(spec_mask)
343358

359+
# Zero out padded spectrogram positions, matching HuggingFace Transformers
360+
prepared_speech = [
361+
spec * m[..., None]
362+
for spec, m in zip(prepared_speech, prepared_speech_mask)
363+
]
364+
344365
return {
345366
"input_features": prepared_speech,
346367
"input_features_mask": prepared_speech_mask,

mlx_vlm/models/gemma4/gemma4.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,9 @@ def sanitize(self, weights):
195195
else:
196196
new_key = k
197197

198-
if new_key.startswith("language_model."):
198+
if new_key.startswith("language_model.") and not new_key.startswith(
199+
"language_model.model."
200+
):
199201
rest = new_key[len("language_model.") :]
200202
new_key = "language_model.model." + rest
201203

mlx_vlm/models/gemma4/processing_gemma4.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -548,20 +548,22 @@ def _load_json(path):
548548

549549
image_processor = Gemma4ImageProcessor(**ip_config)
550550

551-
# Load audio feature extractor if config available
551+
# Load audio feature extractor.
552+
# The standard HF checkpoint does not include a "feature_extractor" key
553+
# in processor_config.json, so we instantiate with defaults when the
554+
# config is missing — the USM parameters are fixed for all Gemma 4 models.
552555
feature_extractor = None
553-
if fe_config:
554-
try:
555-
from .audio_feature_extractor import Gemma4AudioFeatureExtractor
556+
try:
557+
from .audio_feature_extractor import Gemma4AudioFeatureExtractor
556558

557-
feature_extractor = Gemma4AudioFeatureExtractor(**fe_config)
558-
except ImportError:
559-
try:
560-
from transformers import Gemma4AudioFeatureExtractor
559+
feature_extractor = Gemma4AudioFeatureExtractor(**(fe_config or {}))
560+
except ImportError:
561+
try:
562+
from transformers import Gemma4AudioFeatureExtractor
561563

562-
feature_extractor = Gemma4AudioFeatureExtractor(**fe_config)
563-
except (ImportError, Exception):
564-
pass
564+
feature_extractor = Gemma4AudioFeatureExtractor(**(fe_config or {}))
565+
except (ImportError, Exception):
566+
pass
565567

566568
image_seq_length = ip_config.get("max_soft_tokens", 280)
567569
audio_seq_length = proc_config.get("audio_seq_length", 750)

0 commit comments

Comments
 (0)