From bad936d00322076f2cf851eca114ee32d1174126 Mon Sep 17 00:00:00 2001 From: antmanler Date: Tue, 14 Apr 2026 16:12:15 +0800 Subject: [PATCH 1/5] feat(mlxlmcommon): add LMInput.ProcessedAudio + UserInput.audios Adds the shared MLXLMCommon hooks required to carry audio inputs from UserInput through the processor to LanguageModel.prepare without introducing any Hub or Tokenizers dependencies on the core module. - LMInput.ProcessedAudio { features: MLXArray, mask: MLXArray? } - LMInput.audio optional field + init parameter - UserInput.audios: [[Float]] for raw 16kHz mono PCM --- Libraries/MLXLMCommon/LanguageModel.swift | 19 ++++++++++++++++++- Libraries/MLXLMCommon/UserInput.swift | 3 +++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/Libraries/MLXLMCommon/LanguageModel.swift b/Libraries/MLXLMCommon/LanguageModel.swift index 01b551fca..572af280d 100644 --- a/Libraries/MLXLMCommon/LanguageModel.swift +++ b/Libraries/MLXLMCommon/LanguageModel.swift @@ -61,6 +61,7 @@ public struct LMInput { public let text: Text public let image: ProcessedImage? public let video: ProcessedVideo? + public let audio: ProcessedAudio? /// Representation of tokenized input text. public struct Text { @@ -120,17 +121,33 @@ public struct LMInput { } } + /// Representation of prepared input audio. + public struct ProcessedAudio { + + /// Mel spectrogram features, shape [batch, frames, melBins] or [frames, melBins]. + public let features: MLXArray + /// Optional attention mask indicating padding frames (True = padding). + public let mask: MLXArray? + + public init(features: MLXArray, mask: MLXArray? = nil) { + self.features = features + self.mask = mask + } + } + public init(tokens: MLXArray, mask: MLXArray? = nil) { self.init(text: .init(tokens: tokens, mask: mask)) } public init( text: LMInput.Text, image: LMInput.ProcessedImage? = nil, - video: LMInput.ProcessedVideo? = nil + video: LMInput.ProcessedVideo? = nil, + audio: LMInput.ProcessedAudio? = nil ) { self.text = text self.image = image self.video = video + self.audio = audio } } diff --git a/Libraries/MLXLMCommon/UserInput.swift b/Libraries/MLXLMCommon/UserInput.swift index 8aac26e71..eae3f00d7 100644 --- a/Libraries/MLXLMCommon/UserInput.swift +++ b/Libraries/MLXLMCommon/UserInput.swift @@ -177,6 +177,9 @@ public struct UserInput { /// collect the videos from the chat messages, otherwise these are the stored videos with the ``UserInput``. public var videos = [Video]() + /// Audio inputs as raw PCM float arrays (16 kHz mono expected by Gemma 4). + public var audios = [[Float]]() + public var tools: [ToolSpec]? /// Additional values provided for the chat template rendering context From aa453fae65b27876c4af22402c57e19cdbb71288 Mon Sep 17 00:00:00 2001 From: antmanler Date: Tue, 14 Apr 2026 16:12:41 +0800 Subject: [PATCH 2/5] feat(gemma4): add mel spectrogram feature extractor Implements the audio preprocessing pipeline used by Gemma 4 audio tower: raw 16kHz mono PCM -> framing -> Hanning window -> vDSP RFFT -> HTK mel filter bank (128 bins, 0-8kHz) -> log mel -> MLXArray + mask. Matches the Python mlx-vlm reference implementation within 1e-3 absolute error on the mel fixture bundled with the tests. Key correctness fix: - vDSP FFT output is scaled by 0.5 relative to numpy.fft.rfft so we apply an explicit 0.5 scale when constructing the magnitude spectrum (line ~238). Without this the mel values are exactly 2x the Python reference. Not marked Sendable because MLXArray (the cached mel filter bank) is not Sendable under swift 6 strict concurrency. --- .../Models/Gemma4AudioFeatureExtractor.swift | 263 ++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 Libraries/MLXVLM/Models/Gemma4AudioFeatureExtractor.swift diff --git a/Libraries/MLXVLM/Models/Gemma4AudioFeatureExtractor.swift b/Libraries/MLXVLM/Models/Gemma4AudioFeatureExtractor.swift new file mode 100644 index 000000000..5f68def1c --- /dev/null +++ b/Libraries/MLXVLM/Models/Gemma4AudioFeatureExtractor.swift @@ -0,0 +1,263 @@ +// +// Gemma4AudioFeatureExtractor.swift +// MLXVLM +// +// Audio feature extractor for Gemma 4 — extracts log-mel spectrograms +// from raw audio waveforms using USM preprocessing pipeline. +// +// Ported from: mlx_vlm/models/gemma4/audio_feature_extractor.py +// + +import Accelerate +import Foundation +import MLX + +// MARK: - Mel Filter Bank + +/// Create a mel filter bank matrix [numFrequencyBins, numMelFilters] using HTK scale. +public func gemma4MelFilterBank( + numFrequencyBins: Int, + numMelFilters: Int, + minFrequency: Float, + maxFrequency: Float, + samplingRate: Int +) -> MLXArray { + func hzToMel(_ freq: Float) -> Float { + 2595.0 * log10(1.0 + freq / 700.0) + } + func melToHz(_ mel: Float) -> Float { + 700.0 * (pow(10.0, mel / 2595.0) - 1.0) + } + + let melMin = hzToMel(minFrequency) + let melMax = hzToMel(maxFrequency) + + // Linearly spaced mel points + var melPoints = [Float](repeating: 0, count: numMelFilters + 2) + for i in 0 ..< (numMelFilters + 2) { + melPoints[i] = melMin + Float(i) * (melMax - melMin) / Float(numMelFilters + 1) + } + let freqPoints = melPoints.map { melToHz($0) } + + // All frequency bins + var allFreqs = [Float](repeating: 0, count: numFrequencyBins) + let freqStep = Float(samplingRate) / Float(2 * (numFrequencyBins - 1)) + for i in 0 ..< numFrequencyBins { + allFreqs[i] = Float(i) * freqStep + } + + // Build triangular filter bank + var filterBank = [Float](repeating: 0, count: numFrequencyBins * numMelFilters) + for i in 0 ..< numMelFilters { + let lower = freqPoints[i] + let center = freqPoints[i + 1] + let upper = freqPoints[i + 2] + + for j in 0 ..< numFrequencyBins { + let rising = (allFreqs[j] - lower) / max(center - lower, 1e-10) + let falling = (upper - allFreqs[j]) / max(upper - center, 1e-10) + filterBank[j * numMelFilters + i] = max(0, min(rising, falling)) + } + } + + return MLXArray(filterBank, [numFrequencyBins, numMelFilters]) +} + +// MARK: - Feature Extractor + +/// Gemma4 audio feature extractor — converts raw waveform to log-mel spectrogram. +public struct Gemma4AudioFeatureExtractor { + public let featureSize: Int + public let samplingRate: Int + public let frameLength: Int + public let hopLength: Int + public let fftLength: Int + public let melFloor: Float + public let preemphasis: Float + public let preemphasisHTKFlavor: Bool + public let inputScaleFactor: Float + + /// Hanning window [frameLength] + private let window: [Float] + /// Mel filter bank [fftLength/2+1, featureSize] + private let melFilters: MLXArray + + public init( + featureSize: Int = 128, + samplingRate: Int = 16000, + frameLengthMs: Float = 20.0, + hopLengthMs: Float = 10.0, + minFrequency: Float = 0.0, + maxFrequency: Float = 8000.0, + preemphasis: Float = 0.0, + preemphasisHTKFlavor: Bool = true, + fftOverdrive: Bool = true, + inputScaleFactor: Float = 1.0, + melFloor: Float = 1e-3 + ) { + self.featureSize = featureSize + self.samplingRate = samplingRate + self.preemphasis = preemphasis + self.preemphasisHTKFlavor = preemphasisHTKFlavor + self.inputScaleFactor = inputScaleFactor + self.melFloor = melFloor + + self.frameLength = Int(round(Float(samplingRate) * frameLengthMs / 1000.0)) + self.hopLength = Int(round(Float(samplingRate) * hopLengthMs / 1000.0)) + + var fftLen = 1 + while fftLen < frameLength { fftLen *= 2 } + if fftOverdrive { fftLen *= 2 } + self.fftLength = fftLen + + // Hanning window (non-zero at endpoints, matching Python) + let arg = Float.pi * 2.0 / Float(frameLength) + var win = [Float](repeating: 0, count: frameLength) + for i in 0 ..< frameLength { + win[i] = 0.5 - 0.5 * cos(arg * (Float(i) + 0.5)) + } + self.window = win + + self.melFilters = gemma4MelFilterBank( + numFrequencyBins: fftLen / 2 + 1, + numMelFilters: featureSize, + minFrequency: minFrequency, + maxFrequency: maxFrequency, + samplingRate: samplingRate + ) + } + + /// Extract log-mel spectrogram from raw audio samples. + /// - Parameters: + /// - audio: Raw waveform [numSamples] as Float array + /// - maxLength: Maximum number of samples (default 480000 = 30s at 16kHz) + /// - Returns: (melSpectrogram: [frames, featureSize], mask: [frames]) + public func extract(audio: [Float], maxLength: Int = 480_000) -> (MLXArray, MLXArray) { + var waveform = audio + if waveform.count > maxLength { + waveform = Array(waveform.prefix(maxLength)) + } + + // Pad to multiple of 128 + let padTarget = ((waveform.count + 127) / 128) * 128 + var mask = [Float](repeating: 1.0, count: padTarget) + if waveform.count < padTarget { + mask.replaceSubrange( + waveform.count ..< padTarget, + with: repeatElement(0.0, count: padTarget - waveform.count)) + waveform.append(contentsOf: repeatElement(0.0, count: padTarget - waveform.count)) + } + + // Scale + if inputScaleFactor != 1.0 { + for i in 0 ..< waveform.count { + waveform[i] *= inputScaleFactor + } + } + + // Frame extraction (unfold) + let frameSizeForUnfold = frameLength + 1 + let numFrames = (waveform.count - frameSizeForUnfold) / hopLength + 1 + guard numFrames > 0 else { + return (MLXArray.zeros([0, featureSize]), MLXArray.zeros([0])) + } + + // Extract frames with preemphasis + var frames = [Float](repeating: 0, count: numFrames * frameLength) + for f in 0 ..< numFrames { + let start = f * hopLength + if preemphasis > 0 && preemphasisHTKFlavor { + frames[f * frameLength] = waveform[start] * (1.0 - preemphasis) + for j in 1 ..< frameLength { + frames[f * frameLength + j] = + waveform[start + j] - preemphasis * waveform[start + j - 1] + } + } else { + for j in 0 ..< frameLength { + frames[f * frameLength + j] = waveform[start + j] + } + } + } + + // Apply window + for f in 0 ..< numFrames { + for j in 0 ..< frameLength { + frames[f * frameLength + j] *= window[j] + } + } + + // RFFT using Accelerate + let melSpec = computeMelSpectrogram(frames: frames, numFrames: numFrames) + + // Build frame-level mask + var frameMask = [Float](repeating: 0, count: numFrames) + for f in 0 ..< numFrames { + let idx = f * hopLength + if idx < mask.count { + frameMask[f] = mask[idx] + } + } + + return (melSpec, MLXArray(frameMask)) + } + + /// Compute mel spectrogram from windowed frames using Accelerate FFT. + private func computeMelSpectrogram(frames: [Float], numFrames: Int) -> MLXArray { + let halfFFT = fftLength / 2 + + // Use vDSP for FFT + let log2n = vDSP_Length(log2(Double(fftLength))) + guard let fftSetup = vDSP_create_fftsetup(log2n, FFTRadix(kFFTRadix2)) else { + // Fallback: zero output + return MLXArray.zeros([numFrames, featureSize]) + } + defer { vDSP_destroy_fftsetup(fftSetup) } + + var allMagnitudes = [Float](repeating: 0, count: numFrames * (halfFFT + 1)) + + for f in 0 ..< numFrames { + // Zero-pad frame to fftLength + var paddedFrame = [Float](repeating: 0, count: fftLength) + for j in 0 ..< frameLength { + paddedFrame[j] = frames[f * frameLength + j] + } + + // Split complex + var realPart = [Float](repeating: 0, count: halfFFT) + var imagPart = [Float](repeating: 0, count: halfFFT) + + // Pack into split complex (even/odd interleave) + for i in 0 ..< halfFFT { + realPart[i] = paddedFrame[2 * i] + imagPart[i] = paddedFrame[2 * i + 1] + } + + var splitComplex = DSPSplitComplex(realp: &realPart, imagp: &imagPart) + vDSP_fft_zrip(fftSetup, &splitComplex, 1, log2n, FFTDirection(kFFTDirection_Forward)) + + // Extract magnitudes + // vDSP_fft_zrip output is scaled by 2x compared to standard DFT. + // Divide by 2 to match numpy.fft.rfft normalization. + let scale: Float = 0.5 + // DC component + allMagnitudes[f * (halfFFT + 1)] = abs(splitComplex.realp[0]) * scale + // Nyquist + allMagnitudes[f * (halfFFT + 1) + halfFFT] = abs(splitComplex.imagp[0]) * scale + // Other bins + for i in 1 ..< halfFFT { + let re = splitComplex.realp[i] + let im = splitComplex.imagp[i] + allMagnitudes[f * (halfFFT + 1) + i] = sqrt(re * re + im * im) * scale + } + } + + // Apply mel filter bank: [numFrames, halfFFT+1] @ [halfFFT+1, featureSize] + let magnitudeArray = MLXArray(allMagnitudes, [numFrames, halfFFT + 1]) + let melSpec = matmul(magnitudeArray, melFilters) + + // Log mel + let logMelSpec = log(maximum(melSpec, MLXArray(melFloor))) + + return logMelSpec.asType(.float32) + } +} From 3b19134e5dccb3644cbfef211e4b8e176c2c17e1 Mon Sep 17 00:00:00 2001 From: antmanler Date: Tue, 14 Apr 2026 16:13:12 +0800 Subject: [PATCH 3/5] =?UTF-8?q?feat(gemma4):=20add=20audio=20tower=20?= =?UTF-8?q?=E2=80=94=20Conformer=20encoder=20+=20model=20wiring?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds Gemma 4 audio input support on top of the text + vision + MoE implementation from #180. New components: - Gemma4AudioConfiguration — decodes config.json's audio_config block, optional so models without audio still load unchanged. - Audio encoder (~700 lines) — USM-style Conformer with: * SubSampleConvProjection (2x Conv2d, 4x temporal reduction) * 12 ConformerBlock layers (FFW -> attention -> light conv1d -> FFW) * Relative positional embedding with sinusoidal projection * Chunked local self-attention with logit softcap * Cumulative group norm for streaming compatibility - Gemma4 model integration: * audio_tower and embed_audio modules, conditional on audio config * getInputEmbeddings extended with audioFeatures/audioMask, scatters encoder output into <|audio|> placeholder positions * prepare() routes LMInput.audio through the new path * sanitize() preserves audio weights only when audio is configured - Processor integration: * Gemma4Processor.prepare extracts mel features, computes the audio token count from actual subsampling math (not a duration estimate), injects <|audio|> placeholders into the prompt, and inverts the mel mask to True=padding convention expected by the encoder * Gemma4ProcessorConfiguration gains audioTokenId with default 258881 for models whose processor_config.json omits it Correctness fixes preserved from platx-ai/mlx-swift-lm@668347f: 1. FFT 0.5 scaling (in Gemma4AudioFeatureExtractor, separate commit) 2. Audio mask polarity inversion (Processor.prepare) 3. Depthwise conv1d groups=hiddenSize in Gemma4ConformerLightConv1d 4. relPosInvTimescales stored as [Float] not MLXArray to avoid MLX Module weight lookup 5. Audio token count from actual mel subsampling, not duration 6. audioTokenId default 258881 when absent 7. Audio placeholder injection into chat template prompt Fix #5 (EOS [1, 106, 50]) lives in the caller's generation loop, not in the library. File size: Gemma4.swift 1910 -> 2740 lines. Build: `swift build` passes on upstream/main base. --- Libraries/MLXVLM/Models/Gemma4.swift | 909 ++++++++++++++++++++++++++- 1 file changed, 880 insertions(+), 29 deletions(-) diff --git a/Libraries/MLXVLM/Models/Gemma4.swift b/Libraries/MLXVLM/Models/Gemma4.swift index cac757f98..b306a935b 100644 --- a/Libraries/MLXVLM/Models/Gemma4.swift +++ b/Libraries/MLXVLM/Models/Gemma4.swift @@ -384,9 +384,69 @@ public struct Gemma4VisionConfiguration: Codable, Sendable { } } +public struct Gemma4AudioConfiguration: Codable, Sendable { + public let hiddenSize: Int + public let numHiddenLayers: Int + public let numAttentionHeads: Int + public let subsamplingConvChannels: [Int] + public let convKernelSize: Int + public let residualWeight: Float + public let attentionChunkSize: Int + public let attentionContextLeft: Int + public let attentionContextRight: Int + public let attentionLogitCap: Float + public let attentionInvalidLogitsValue: Float + public let useClippedLinears: Bool + public let rmsNormEps: Float + public let gradientClipping: Float + public let outputProjDims: Int? + + enum CodingKeys: String, CodingKey { + case hiddenSize = "hidden_size" + case numHiddenLayers = "num_hidden_layers" + case numAttentionHeads = "num_attention_heads" + case subsamplingConvChannels = "subsampling_conv_channels" + case convKernelSize = "conv_kernel_size" + case residualWeight = "residual_weight" + case attentionChunkSize = "attention_chunk_size" + case attentionContextLeft = "attention_context_left" + case attentionContextRight = "attention_context_right" + case attentionLogitCap = "attention_logit_cap" + case attentionInvalidLogitsValue = "attention_invalid_logits_value" + case useClippedLinears = "use_clipped_linears" + case rmsNormEps = "rms_norm_eps" + case gradientClipping = "gradient_clipping" + case outputProjDims = "output_proj_dims" + } + + public init(from decoder: any Swift.Decoder) throws { + let c = try decoder.container(keyedBy: CodingKeys.self) + hiddenSize = try c.decodeIfPresent(Int.self, forKey: .hiddenSize) ?? 1024 + numHiddenLayers = try c.decodeIfPresent(Int.self, forKey: .numHiddenLayers) ?? 12 + numAttentionHeads = try c.decodeIfPresent(Int.self, forKey: .numAttentionHeads) ?? 8 + subsamplingConvChannels = + try c.decodeIfPresent([Int].self, forKey: .subsamplingConvChannels) ?? [128, 32] + convKernelSize = try c.decodeIfPresent(Int.self, forKey: .convKernelSize) ?? 5 + residualWeight = try c.decodeIfPresent(Float.self, forKey: .residualWeight) ?? 0.5 + attentionChunkSize = try c.decodeIfPresent(Int.self, forKey: .attentionChunkSize) ?? 12 + attentionContextLeft = try c.decodeIfPresent(Int.self, forKey: .attentionContextLeft) ?? 13 + attentionContextRight = try c.decodeIfPresent(Int.self, forKey: .attentionContextRight) ?? 0 + attentionLogitCap = try c.decodeIfPresent(Float.self, forKey: .attentionLogitCap) ?? 50.0 + attentionInvalidLogitsValue = + try c.decodeIfPresent(Float.self, forKey: .attentionInvalidLogitsValue) ?? -1e9 + useClippedLinears = + try c.decodeIfPresent(Bool.self, forKey: .useClippedLinears) ?? true + rmsNormEps = try c.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1e-6 + gradientClipping = + try c.decodeIfPresent(Float.self, forKey: .gradientClipping) ?? 1e10 + outputProjDims = try c.decodeIfPresent(Int.self, forKey: .outputProjDims) + } +} + public struct Gemma4Configuration: Codable, Sendable { public let textConfiguration: Gemma4TextConfiguration public let visionConfiguration: Gemma4VisionConfiguration + public let audioConfiguration: Gemma4AudioConfiguration? public let modelType: String public let quantization: BaseConfiguration.Quantization? public let imageTokenId: Int @@ -407,6 +467,7 @@ public struct Gemma4Configuration: Codable, Sendable { enum CodingKeys: String, CodingKey { case textConfiguration = "text_config" case visionConfiguration = "vision_config" + case audioConfiguration = "audio_config" case modelType = "model_type" case quantization case imageTokenId = "image_token_id" @@ -426,6 +487,8 @@ public struct Gemma4Configuration: Codable, Sendable { Gemma4TextConfiguration.self, forKey: CodingKeys.textConfiguration) visionConfiguration = try c.decode( Gemma4VisionConfiguration.self, forKey: CodingKeys.visionConfiguration) + audioConfiguration = try c.decodeIfPresent( + Gemma4AudioConfiguration.self, forKey: CodingKeys.audioConfiguration) modelType = try c.decodeIfPresent(String.self, forKey: CodingKeys.modelType) ?? "gemma4" quantization = try c.decodeIfPresent( BaseConfiguration.Quantization.self, forKey: CodingKeys.quantization) @@ -1641,12 +1704,719 @@ private final class Gemma4MultimodalEmbedder: Module, UnaryLayer { } } +// MARK: - Audio + +private final class Gemma4AudioRMSNorm: Module, UnaryLayer { + let eps: Float + @ModuleInfo var weight: MLXArray + + init(dimensions: Int, eps: Float = 1e-6) { + self.eps = eps + self._weight.wrappedValue = MLXArray.ones([dimensions]) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + MLXFast.rmsNorm(x, weight: weight, eps: eps) + } +} + +private final class Gemma4AudioClippableLinear: Module, UnaryLayer { + @ModuleInfo(key: "linear") var linear: Linear + @ModuleInfo(key: "input_min") var inputMin: MLXArray? + @ModuleInfo(key: "input_max") var inputMax: MLXArray? + @ModuleInfo(key: "output_min") var outputMin: MLXArray? + @ModuleInfo(key: "output_max") var outputMax: MLXArray? + let useClipping: Bool + + init(inFeatures: Int, outFeatures: Int, bias: Bool = false, useClipping: Bool = true) { + self.useClipping = useClipping + self._linear.wrappedValue = Linear(inFeatures, outFeatures, bias: bias) + if useClipping { + self._inputMin.wrappedValue = MLXArray(-Float.infinity) + self._inputMax.wrappedValue = MLXArray(Float.infinity) + self._outputMin.wrappedValue = MLXArray(-Float.infinity) + self._outputMax.wrappedValue = MLXArray(Float.infinity) + } + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let clippedInput: MLXArray + if let inputMin, let inputMax { + clippedInput = clip(x, min: inputMin, max: inputMax) + } else { + clippedInput = x + } + let projected = linear(clippedInput) + if let outputMin, let outputMax { + return clip(projected, min: outputMin, max: outputMax) + } + return projected + } +} + +/// LayerNorm without bias, matching `nn.LayerNorm(dims, bias=False)` in the Python model. +/// The checkpoint stores a single `weight` parameter at the `norm` key. +private final class Gemma4AudioLayerNorm: Module, UnaryLayer { + @ModuleInfo var weight: MLXArray + let eps: Float + + init(dimensions: Int, eps: Float = 1e-6) { + self.eps = eps + self._weight.wrappedValue = MLXArray.ones([dimensions]) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let xFloat = x.asType(.float32) + let meanVal = MLX.mean(xFloat, axis: -1, keepDims: true) + let variance = MLX.mean((xFloat - meanVal).square(), axis: -1, keepDims: true) + let normalized = (xFloat - meanVal) * rsqrt(variance + eps) + return (normalized * weight.asType(.float32)).asType(x.dtype) + } +} + +private final class Gemma4SSCPConvBlock: Module { + let timeStride: Int = 2 + let padding: (Int, Int, Int, Int) = (1, 1, 1, 1) + + @ModuleInfo(key: "conv") var conv: Conv2d + @ModuleInfo(key: "norm") var norm: Gemma4AudioLayerNorm + + init(config: Gemma4AudioConfiguration, idx: Int) { + let inChannels = idx == 0 ? 1 : config.subsamplingConvChannels[idx - 1] + let outChannels = config.subsamplingConvChannels[idx] + + // Conv2d: MLX expects [B, H, W, C], weight [C_out, kH, kW, C_in] + self._conv.wrappedValue = Conv2d( + inputChannels: inChannels, + outputChannels: outChannels, + kernelSize: 3, + stride: 2, + padding: 0, + bias: false + ) + + self._norm.wrappedValue = Gemma4AudioLayerNorm( + dimensions: outChannels, eps: config.rmsNormEps) + super.init() + } + + func callAsFunction(_ x: MLXArray, mask: MLXArray) -> (MLXArray, MLXArray) { + // x: [B, T, F, C] (MLX channel-last) + // mask: [B, T] (True = invalid/padding) + + // Zero out invalid positions + var x = MLX.where( + expandedDimensions(expandedDimensions(mask, axis: -1), axis: -1), + MLXArray(0.0, dtype: x.dtype), x) + + // Manual padding on T and F dims + x = MLX.padded( + x, + widths: [ + .init((0, 0)), .init((padding.0, padding.1)), + .init((padding.2, padding.3)), .init((0, 0)), + ]) + + x = conv(x) // [B, T_out, F_out, C_out] + + // Downsample mask by time stride + let tOut = x.dim(1) + let downsampled = mask[0..., .stride(by: timeStride)] + let outputMask = downsampled[0..., .. (MLXArray, MLXArray) { + // audioMel: [B, T, F_in] + // Add channel dim: [B, T, F, 1] + var x = expandedDimensions(audioMel, axis: -1) + + var currentMask = mask + (x, currentMask) = layer0(x, mask: currentMask) + (x, currentMask) = layer1(x, mask: currentMask) + + // Flatten F*C -> [B, T, F*C] + let batchSize = x.dim(0) + let timeSteps = x.dim(1) + let freqBins = x.dim(2) + let channels = x.dim(3) + x = x.reshaped(batchSize, timeSteps, freqBins * channels) + + // Project to hidden_size + x = inputProjLinear(x) + + return (x, currentMask) + } +} + +private final class Gemma4ConformerFeedForward: Module { + let gradientClipping: Float + let residualWeight: Float + + @ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma4AudioRMSNorm + @ModuleInfo(key: "ffw_layer_1") var ffwLayer1: Gemma4AudioClippableLinear + @ModuleInfo(key: "ffw_layer_2") var ffwLayer2: Gemma4AudioClippableLinear + @ModuleInfo(key: "post_layer_norm") var postLayerNorm: Gemma4AudioRMSNorm + + init(config: Gemma4AudioConfiguration) { + self.gradientClipping = config.gradientClipping + self.residualWeight = config.residualWeight + + self._preLayerNorm.wrappedValue = Gemma4AudioRMSNorm(dimensions: config.hiddenSize) + self._ffwLayer1.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize * 4, + useClipping: config.useClippedLinears) + self._ffwLayer2.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize * 4, outFeatures: config.hiddenSize, + useClipping: config.useClippedLinears) + self._postLayerNorm.wrappedValue = Gemma4AudioRMSNorm(dimensions: config.hiddenSize) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let residual = x + var h = clip(x, min: -gradientClipping, max: gradientClipping) + h = preLayerNorm(h) + h = ffwLayer1(h) + h = silu(h) + h = ffwLayer2(h) + h = clip(h, min: -gradientClipping, max: gradientClipping) + h = postLayerNorm(h) + return residual + h * residualWeight + } +} + +private final class Gemma4AudioRelativePositionEmbedding: Module { + let numHeads: Int + let channels: Int + let headDim: Int + let maxBackward: Int + let maxForward: Int + let invTimescales: MLXArray + + @ModuleInfo(key: "pos_proj") var posProj: Linear + + init(config: Gemma4AudioConfiguration) { + self.numHeads = config.numAttentionHeads + self.channels = config.hiddenSize + self.headDim = config.hiddenSize / config.numAttentionHeads + self.maxBackward = max(0, config.attentionContextLeft - 1) + self.maxForward = config.attentionContextRight + + self._posProj.wrappedValue = Linear( + config.hiddenSize, config.numAttentionHeads * headDim, bias: false) + + let minTimescale: Float = 1.0 + let maxTimescale: Float = 10000.0 + let numTimescales = config.hiddenSize / 2 + let logTimescaleIncrement = + Foundation.log(maxTimescale / minTimescale) / Float(max(numTimescales - 1, 1)) + self.invTimescales = + MLXArray(minTimescale) + * MLX.exp(MLXArray(0 ..< numTimescales).asType(.float32) * (-logTimescaleIncrement)) + + super.init() + } + + private func getTimingSignal(_ position: MLXArray, dtype: DType) -> MLXArray { + let posFloat = position.asType(.float32) + let pos = expandedDimensions(posFloat, axis: -1) + let invTS = invTimescales.reshaped(1, 1, -1) + let scaledTime = pos * invTS + let signal = concatenated([sin(scaledTime), cos(scaledTime)], axis: -1) + return signal.asType(dtype) + } + + private func relativeShift( + _ termBD: MLXArray, batchSize: Int, numHeads: Int, numBlocks: Int, + blockSize: Int, contextSize: Int, maxSpanPlus1: Int + ) -> MLXArray { + let padAmount = (contextSize + 1) - maxSpanPlus1 + var shifted = MLX.padded( + termBD, + widths: [ + .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, padAmount)), + ]) + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize * (contextSize + 1)) + shifted = shifted[0..., 0..., 0..., ..<(blockSize * contextSize)] + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize, contextSize) + return shifted + } + + func callAsFunction(queries: MLXArray, keys: MLXArray) -> MLXArray { + // queries: [B, U, W, N, H], keys: [B, U, C, N, H] + let batchSize = queries.dim(0) + let numBlocks = queries.dim(1) + let blockSize = queries.dim(2) + let contextSize = keys.dim(2) + + let posIndices = MLXArray( + stride(from: maxBackward, through: -maxForward, by: -1).map { Int32($0) } + ) + .reshaped(1, -1) + let maxSpanPlus1 = posIndices.dim(1) + + var sinEmb = getTimingSignal(posIndices, dtype: queries.dtype) + sinEmb = posProj(sinEmb.asType(posProj.weight.dtype)) + sinEmb = sinEmb.reshaped(maxSpanPlus1, numHeads, headDim) + sinEmb = sinEmb.asType(queries.dtype) + + // queries_p: [B, N, U, W, H], keys_p: [B, N, U, H, C] + let queriesP = queries.transposed(0, 3, 1, 2, 4) + let keysP = keys.transposed(0, 3, 1, 4, 2) + let termAC = queriesP.matmul(keysP) + + // sin_emb_t: [N, H, maxSpan] + let sinEmbT = sinEmb.transposed(1, 2, 0) + let qReshaped = queriesP.reshaped(batchSize, numHeads, numBlocks * blockSize, headDim) + var termBD = qReshaped.matmul(sinEmbT).reshaped( + batchSize, numHeads, numBlocks, blockSize, maxSpanPlus1) + + termBD = relativeShift( + termBD, batchSize: batchSize, numHeads: numHeads, numBlocks: numBlocks, + blockSize: blockSize, contextSize: contextSize, maxSpanPlus1: maxSpanPlus1) + + return termAC + termBD + } +} + +private final class Gemma4AudioAttention: Module { + let numHeads: Int + let hiddenSize: Int + let headDim: Int + let chunkSize: Int + let maxFutureHorizon: Int + let maxPastHorizon: Int + let contextSize: Int + let invalidLogitsValue: Float + let softcap: Float + let qScale: Float + let kScale: Float + + @ModuleInfo(key: "relative_k_proj") var relativeKProj: Linear + @ParameterInfo(key: "per_dim_scale") var perDimScale: MLXArray + @ModuleInfo(key: "q_proj") var qProj: Gemma4AudioClippableLinear + @ModuleInfo(key: "k_proj") var kProj: Gemma4AudioClippableLinear + @ModuleInfo(key: "v_proj") var vProj: Gemma4AudioClippableLinear + @ModuleInfo(key: "post") var post: Gemma4AudioClippableLinear + + // Relative position embedding (inline) + // Note: relPosInvTimescales is NOT a model parameter — it's a computed constant. + // Store as [Float] to avoid MLX Module treating it as a loadable weight. + private let relPosNumHeads: Int + private let relPosHeadDim: Int + private let relPosMaxBackward: Int + private let relPosMaxForward: Int + private let relPosInvTimescalesData: [Float] + + init(config: Gemma4AudioConfiguration) { + self.numHeads = config.numAttentionHeads + self.hiddenSize = config.hiddenSize + self.headDim = config.hiddenSize / config.numAttentionHeads + self.chunkSize = config.attentionChunkSize + self.maxFutureHorizon = config.attentionContextRight + self.maxPastHorizon = max(0, config.attentionContextLeft - 1) + self.contextSize = chunkSize + maxPastHorizon + maxFutureHorizon + self.invalidLogitsValue = config.attentionInvalidLogitsValue + self.softcap = config.attentionLogitCap + + self.qScale = pow(Float(headDim), -0.5) / Foundation.log(2.0) + self.kScale = Foundation.log(1 + Foundation.exp(1.0)) / Foundation.log(2.0) + + self._relativeKProj.wrappedValue = Linear( + config.hiddenSize, config.numAttentionHeads * headDim, bias: false) + self._perDimScale.wrappedValue = MLXArray.zeros([headDim]) + self._qProj.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: numHeads * headDim, + useClipping: config.useClippedLinears) + self._kProj.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: numHeads * headDim, + useClipping: config.useClippedLinears) + self._vProj.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: numHeads * headDim, + useClipping: config.useClippedLinears) + self._post.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize, + useClipping: config.useClippedLinears) + + // Relative position embedding setup + self.relPosNumHeads = numHeads + self.relPosHeadDim = headDim + self.relPosMaxBackward = maxPastHorizon + self.relPosMaxForward = maxFutureHorizon + + let minTimescale: Float = 1.0 + let maxTimescale: Float = 10000.0 + let numTimescales = config.hiddenSize / 2 + let logTimescaleIncrement = + Foundation.log(maxTimescale / minTimescale) / Float(max(numTimescales - 1, 1)) + self.relPosInvTimescalesData = (0 ..< numTimescales).map { i in + minTimescale * Foundation.exp(Float(i) * (-logTimescaleIncrement)) + } + + super.init() + } + + private func padDim1(_ x: MLXArray, padLeft: Int, padRight: Int) -> MLXArray { + var widths = Array(repeating: IntOrPair((0, 0)), count: x.ndim) + widths[1] = IntOrPair((padLeft, padRight)) + return MLX.padded(x, widths: widths) + } + + private func convertToBlock(_ x: MLXArray) -> MLXArray { + // [B, T, ...] -> [B, num_blocks, chunk_size, ...] + let batchSize = x.dim(0) + let timeSteps = x.dim(1) + let rest = Array(x.shape.dropFirst(2)) + let numBlocks = (timeSteps + chunkSize - 1) / chunkSize + let padLen = numBlocks * chunkSize - timeSteps + var result = x + if padLen > 0 { + result = padDim1(result, padLeft: 0, padRight: padLen) + } + return result.reshaped([batchSize, numBlocks, chunkSize] + rest) + } + + private func extractBlockContext(_ x: MLXArray) -> MLXArray { + // [B, T, ...] -> [B, num_blocks, context_size, ...] + let padLeft = maxPastHorizon + let padRight = maxFutureHorizon + chunkSize - 1 + let padded = padDim1(x, padLeft: padLeft, padRight: padRight) + let tPadded = padded.dim(1) + let numBlocks = (tPadded - contextSize) / chunkSize + 1 + + // Build indices: starts[:, None] + offsets[None, :] + let starts = MLXArray( + stride(from: 0, to: numBlocks * chunkSize, by: chunkSize).map { + Int32($0) + }) + let offsets = MLXArray((0 ..< contextSize).map { Int32($0) }) + let indices = expandedDimensions(starts, axis: 1) + expandedDimensions(offsets, axis: 0) + // indices: [numBlocks, contextSize] + + // Gather using advanced indexing + // padded: [B, T_padded, ...rest] + // We need padded[:, indices] which gives [B, numBlocks, contextSize, ...rest] + return padded[0..., indices] + } + + private func relPosTimingSignal(_ position: MLXArray, dtype: DType) -> MLXArray { + let posFloat = position.asType(.float32) + let pos = expandedDimensions(posFloat, axis: -1) + let invTS = MLXArray(relPosInvTimescalesData).reshaped(1, 1, -1) + let scaledTime = pos * invTS + let signal = concatenated([sin(scaledTime), cos(scaledTime)], axis: -1) + return signal.asType(dtype) + } + + private func relPosRelativeShift( + _ termBD: MLXArray, batchSize: Int, numHeads: Int, numBlocks: Int, + blockSize: Int, contextSize: Int, maxSpanPlus1: Int + ) -> MLXArray { + let padAmount = (contextSize + 1) - maxSpanPlus1 + var shifted = MLX.padded( + termBD, + widths: [ + .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, padAmount)), + ]) + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize * (contextSize + 1)) + shifted = shifted[0..., 0..., 0..., ..<(blockSize * contextSize)] + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize, contextSize) + return shifted + } + + private func computeRelativePositionLogits(queries: MLXArray, keys: MLXArray) -> MLXArray { + // queries: [B, U, W, N, H], keys: [B, U, C, N, H] + let batchSize = queries.dim(0) + let numBlocks = queries.dim(1) + let blockSize = queries.dim(2) + let ctxSize = keys.dim(2) + + let posIndices = MLXArray( + stride(from: relPosMaxBackward, through: -relPosMaxForward, by: -1).map { Int32($0) } + ).reshaped(1, -1) + let maxSpanPlus1 = posIndices.dim(1) + + var sinEmb = relPosTimingSignal(posIndices, dtype: queries.dtype) + sinEmb = relativeKProj(sinEmb.asType(relativeKProj.weight.dtype)) + sinEmb = sinEmb.reshaped(maxSpanPlus1, relPosNumHeads, relPosHeadDim) + sinEmb = sinEmb.asType(queries.dtype) + + let queriesP = queries.transposed(0, 3, 1, 2, 4) + let keysP = keys.transposed(0, 3, 1, 4, 2) + let termAC = queriesP.matmul(keysP) + + let sinEmbT = sinEmb.transposed(1, 2, 0) + let qReshaped = queriesP.reshaped( + batchSize, relPosNumHeads, numBlocks * blockSize, relPosHeadDim) + var termBD = qReshaped.matmul(sinEmbT).reshaped( + batchSize, relPosNumHeads, numBlocks, blockSize, maxSpanPlus1) + + termBD = relPosRelativeShift( + termBD, batchSize: batchSize, numHeads: relPosNumHeads, numBlocks: numBlocks, + blockSize: blockSize, contextSize: ctxSize, maxSpanPlus1: maxSpanPlus1) + + return termAC + termBD + } + + func callAsFunction( + _ hiddenStates: MLXArray, mask: MLXArray, causalValidMask: MLXArray + ) -> MLXArray { + let batchSize = hiddenStates.dim(0) + let timeSteps = hiddenStates.dim(1) + let qkvShape = [batchSize, timeSteps, numHeads, headDim] + + var q = qProj(hiddenStates).asType(.float32).reshaped(qkvShape) + var k = kProj(hiddenStates).asType(.float32).reshaped(qkvShape) + let v = vProj(hiddenStates).asType(.float32).reshaped(qkvShape) + + let pds = softplus(perDimScale) + q = q * (qScale * pds) + k = k * kScale + + let queryBlocks = convertToBlock(q) // [B, U, W, N, H] + let keyBlocks = extractBlockContext(k) // [B, U, C, N, H] + let valueBlocks = extractBlockContext(v) // [B, U, C, N, H] + let numBlocks = queryBlocks.dim(1) + + // Build validity condition + let validMask = logicalNot(mask) // True = valid + let extractedValid = extractBlockContext(validMask) // [B, U, C] + // condition: [B, 1, U, W, C] + let condition = + expandedDimensions(expandedDimensions(extractedValid, axis: 1), axis: 3) + * expandedDimensions( + expandedDimensions(expandedDimensions(causalValidMask, axis: 0), axis: 0), axis: 0) + + var logits = computeRelativePositionLogits(queries: queryBlocks, keys: keyBlocks) + logits = tanh(logits / softcap) * softcap + logits = MLX.where( + condition .> 0, logits, MLXArray(invalidLogitsValue, dtype: logits.dtype)) + + let probs = softmax(logits, axis: -1) + // context = einsum("bnuwc,bucnh->buwnh", probs, valueBlocks) + var context = einsum("bnuwc,bucnh->buwnh", probs, valueBlocks) + context = context.reshaped(batchSize, numBlocks * chunkSize, numHeads, headDim) + context = context[0..., .. [B, T, D] and post-project + context = context.reshaped(batchSize, timeSteps, numHeads * headDim) + return post(context) + } +} + +private final class Gemma4ConformerLightConv1d: Module { + let gradientClipping: Float + let causalPadding: Int + + @ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma4AudioRMSNorm + @ModuleInfo(key: "linear_start") var linearStart: Gemma4AudioClippableLinear + @ModuleInfo(key: "depthwise_conv1d") var depthwiseConv1d: Conv1d + @ModuleInfo(key: "conv_norm") var convNorm: Gemma4AudioRMSNorm + @ModuleInfo(key: "linear_end") var linearEnd: Gemma4AudioClippableLinear + + init(config: Gemma4AudioConfiguration) { + self.gradientClipping = config.gradientClipping + self.causalPadding = config.convKernelSize - 1 + + self._preLayerNorm.wrappedValue = Gemma4AudioRMSNorm( + dimensions: config.hiddenSize, eps: config.rmsNormEps) + self._linearStart.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize * 2, + useClipping: config.useClippedLinears) + // Depthwise conv1d: groups = hidden_size so weight shape is [out, kernel, 1] + self._depthwiseConv1d.wrappedValue = Conv1d( + inputChannels: config.hiddenSize, + outputChannels: config.hiddenSize, + kernelSize: config.convKernelSize, + stride: 1, + padding: 0, + groups: config.hiddenSize, + bias: false + ) + self._convNorm.wrappedValue = Gemma4AudioRMSNorm( + dimensions: config.hiddenSize, eps: config.rmsNormEps) + self._linearEnd.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize, + useClipping: config.useClippedLinears) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let residual = x + + var h = preLayerNorm(x) + h = linearStart(h) + + // GLU: split in half along last dim and gate + let halfDim = h.dim(-1) / 2 + let x1 = h[0..., 0..., .. MLXArray { + var h = feedForward1(x) + + // Attention with pre/post norm and residual + let residual = h + h = clip(h, min: -gradientClipping, max: gradientClipping) + h = normPreAttn(h) + h = selfAttn(h, mask: mask, causalValidMask: causalValidMask) + h = clip(h, min: -gradientClipping, max: gradientClipping) + h = residual + normPostAttn(h) + + // Zero out invalid positions before lconv1d + let validityMask = expandedDimensions(logicalNot(mask), axis: -1).asType(h.dtype) + h = h * validityMask + + h = lconv1d(h) + h = feedForward2(h) + h = clip(h, min: -gradientClipping, max: gradientClipping) + return normOut(h) + } +} + +private final class Gemma4AudioEncoder: Module { + let config: Gemma4AudioConfiguration + + @ModuleInfo(key: "subsample_conv_projection") var subsampleConvProjection: + Gemma4SubSampleConvProjection + @ModuleInfo(key: "layers") var layers: [Gemma4ConformerBlock] + @ModuleInfo(key: "output_proj") var outputProj: Linear? + + init(config: Gemma4AudioConfiguration) { + self.config = config + self._subsampleConvProjection.wrappedValue = Gemma4SubSampleConvProjection(config: config) + self._layers.wrappedValue = (0 ..< config.numHiddenLayers).map { _ in + Gemma4ConformerBlock(config: config) + } + if let outputProjDims = config.outputProjDims { + self._outputProj.wrappedValue = Linear( + config.hiddenSize, outputProjDims, bias: true) + } + super.init() + } + + private func buildCausalValidMask() -> MLXArray { + let chunkSize = config.attentionChunkSize + let maxFutureHorizon = config.attentionContextRight + let maxPastHorizon = max(0, config.attentionContextLeft - 1) + let upperDiagonal = maxPastHorizon + maxFutureHorizon + let ctxSize = chunkSize + maxPastHorizon + maxFutureHorizon + + let lowerCausal = tril(MLXArray.ones([ctxSize, chunkSize])).transposed() + let upperCausal = tril( + MLXArray.ones([chunkSize, ctxSize]), + k: upperDiagonal) + let maskResult = (lowerCausal * upperCausal).asType(.bool) + return maskResult + } + + func callAsFunction(_ audioMel: MLXArray, audioMelMask: MLXArray) -> (MLXArray, MLXArray) { + var (audioEncodings, currentMask) = subsampleConvProjection(audioMel, mask: audioMelMask) + + let causalValidMask = buildCausalValidMask() + + for block in layers { + audioEncodings = block( + audioEncodings, mask: currentMask, causalValidMask: causalValidMask) + } + + if let outputProj { + audioEncodings = outputProj(audioEncodings) + } + + if currentMask.dim(1) != audioEncodings.dim(1) { + let targetLen = audioEncodings.dim(1) + currentMask = currentMask[0..., .. (MLXArray, MLXArray?) { var inputsEmbeds = languageModel.model.embedTokens(inputIds) inputsEmbeds = @@ -1682,41 +2465,65 @@ public final class Gemma4: Module, VLMModel, KVCacheDimensionProvider { var perLayerInputs: MLXArray? = nil if config.textConfiguration.hiddenSizePerLayerInput > 0 { - let imageMask = inputIds .== config.imageTokenId - let audioMask = + let imageMaskPL = inputIds .== config.imageTokenId + let audioMaskPL = if let audioTokenId = config.audioTokenId { inputIds .== audioTokenId } else { - MLXArray.zeros(like: imageMask) + MLXArray.zeros(like: imageMaskPL) } - let textMask = logicalNot(logicalOr(imageMask, audioMask)) + let textMask = logicalNot(logicalOr(imageMaskPL, audioMaskPL)) let perLayerTokens = MLX.where(textMask, inputIds, MLXArray.zeros(like: inputIds)) perLayerInputs = languageModel.model.getPerLayerInputs(perLayerTokens) } - guard let pixelValues else { - return (inputsEmbeds, perLayerInputs) - } + // Scatter vision features into placeholder positions + if let pixelValues { + var imageFeatures = visionTower(pixelValues) + imageFeatures = embedVision(imageFeatures) + imageFeatures = imageFeatures.asType(inputsEmbeds.dtype) - var imageFeatures = visionTower(pixelValues) - imageFeatures = embedVision(imageFeatures) - imageFeatures = imageFeatures.asType(inputsEmbeds.dtype) + let imageMask = inputIds .== config.imageTokenId + let expectedImageTokens = imageMask.asType(.int32).sum().item(Int.self) - let imageMask = inputIds .== config.imageTokenId - let expectedImageTokens = imageMask.asType(.int32).sum().item(Int.self) + if expectedImageTokens != imageFeatures.dim(1) { + throw Gemma4Error.imageTokenCountMismatch( + expectedVisionTokens: imageFeatures.dim(1), + actualPromptTokens: expectedImageTokens) + } - if expectedImageTokens != imageFeatures.dim(1) { - throw Gemma4Error.imageTokenCountMismatch( - expectedVisionTokens: imageFeatures.dim(1), actualPromptTokens: expectedImageTokens) + var imageMaskExpanded = expandedDimensions(imageMask, axis: -1) + imageMaskExpanded = broadcast(imageMaskExpanded, to: inputsEmbeds.shape) + inputsEmbeds = gemma4MaskedScatter( + inputTensor: inputsEmbeds, + mask: imageMaskExpanded, + source: imageFeatures + ) } - var imageMaskExpanded = expandedDimensions(imageMask, axis: -1) - imageMaskExpanded = broadcast(imageMaskExpanded, to: inputsEmbeds.shape) - inputsEmbeds = gemma4MaskedScatter( - inputTensor: inputsEmbeds, - mask: imageMaskExpanded, - source: imageFeatures - ) + // Scatter audio features into <|audio|> placeholder positions + if let audioFeatures, + let audioTower, + let embedAudio, + let audioTokenId = config.audioTokenId + { + // audioFeatures: [1, frames, melBins] ; audioMask: [1, frames] (True=padding) + let encoderMask = + audioMask + ?? MLXArray.zeros([audioFeatures.dim(0), audioFeatures.dim(1)]).asType(.bool) + let (audioEncodings, _) = audioTower(audioFeatures, audioMelMask: encoderMask) + var audioEmb = embedAudio(audioEncodings) + audioEmb = audioEmb.asType(inputsEmbeds.dtype) + + let tokenMask = inputIds .== audioTokenId + var tokenMaskExpanded = expandedDimensions(tokenMask, axis: -1) + tokenMaskExpanded = broadcast(tokenMaskExpanded, to: inputsEmbeds.shape) + inputsEmbeds = gemma4MaskedScatter( + inputTensor: inputsEmbeds, + mask: tokenMaskExpanded, + source: audioEmb + ) + } return (inputsEmbeds, perLayerInputs) } @@ -1725,9 +2532,15 @@ public final class Gemma4: Module, VLMModel, KVCacheDimensionProvider { -> PrepareResult { let convertedCache = cache.map { $0 } - if let imagePixels = input.image?.pixels { + let hasImage = input.image?.pixels != nil + let hasAudio = input.audio?.features != nil + + if hasImage || hasAudio { let (inputsEmbeds, perLayerInputs) = try getInputEmbeddings( - inputIds: input.text.tokens, pixelValues: imagePixels) + inputIds: input.text.tokens, + pixelValues: input.image?.pixels, + audioFeatures: input.audio?.features, + audioMask: input.audio?.mask) let result = languageModel( nil, cache: convertedCache, @@ -1749,9 +2562,11 @@ public final class Gemma4: Module, VLMModel, KVCacheDimensionProvider { public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { var sanitized = languageModel.sanitize(weights: weights) - // This port currently supports text + vision only. - sanitized = sanitized.filter { key, _ in - !key.contains("audio_tower") && !key.contains("embed_audio") + // Only strip audio weights when audio tower is not configured. + if config.audioConfiguration == nil { + sanitized = sanitized.filter { key, _ in + !key.contains("audio_tower") && !key.contains("embed_audio") + } } if !config.visionConfiguration.useClippedLinears { @@ -1865,9 +2680,41 @@ public struct Gemma4Processor: UserInputProcessor { promptTokens = expandedTokens } + // Audio handling: extract mel features, inject <|audio|> placeholder tokens, build ProcessedAudio + var processedAudio: LMInput.ProcessedAudio? = nil + if !input.audios.isEmpty, let audioTokenId = config.audioTokenId { + let extractor = Gemma4AudioFeatureExtractor() + let (melFeatures, melMask) = extractor.extract(audio: input.audios[0]) + + // Fix #6: audio token count from actual subsampling math (two 2x conv blocks) + let melFrames = melFeatures.dim(0) + let afterConv0 = (melFrames + 2 - 3) / 2 + 1 + let numAudioTokens = min((afterConv0 + 2 - 3) / 2 + 1, 750) + + // Fix #8: inject audio placeholder tokens into the prompt + // Insert before final assistant turn (last newline token 108) or append + let audioPlaceholders = Array(repeating: audioTokenId, count: numAudioTokens) + if let lastNewlineIdx = promptTokens.lastIndex(of: 108) { + promptTokens.insert(contentsOf: audioPlaceholders, at: lastNewlineIdx) + } else { + promptTokens.append(contentsOf: audioPlaceholders) + } + + // Fix #2: mask polarity inversion — extractor outputs 1=valid but encoder expects True=padding + let invertedMask = melMask .== 0 + processedAudio = LMInput.ProcessedAudio( + features: melFeatures.expandedDimensions(axis: 0), + mask: invertedMask.expandedDimensions(axis: 0) + ) + } + let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) let mask = ones(like: promptArray).asType(.int8) - return LMInput(text: .init(tokens: promptArray, mask: mask), image: processedImage) + return LMInput( + text: .init(tokens: promptArray, mask: mask), + image: processedImage, + audio: processedAudio + ) } } @@ -1882,6 +2729,7 @@ public struct Gemma4ProcessorConfiguration: Codable, Sendable { public let imageTokenId: Int public let boiTokenId: Int public let eoiTokenId: Int? + public let audioTokenId: Int? enum CodingKeys: String, CodingKey { case processorClass = "processor_class" @@ -1893,6 +2741,7 @@ public struct Gemma4ProcessorConfiguration: Codable, Sendable { case imageTokenId = "image_token_id" case boiTokenId = "boi_token_id" case eoiTokenId = "eoi_token_id" + case audioTokenId = "audio_token_id" } public init(from decoder: any Swift.Decoder) throws { @@ -1909,6 +2758,8 @@ public struct Gemma4ProcessorConfiguration: Codable, Sendable { imageTokenId = try c.decodeIfPresent(Int.self, forKey: CodingKeys.imageTokenId) ?? 258_880 boiTokenId = try c.decodeIfPresent(Int.self, forKey: CodingKeys.boiTokenId) ?? 255_999 eoiTokenId = try c.decodeIfPresent(Int.self, forKey: CodingKeys.eoiTokenId) ?? 258_882 + // Fix #7: default audioTokenId to 258881 when absent from processor_config.json + audioTokenId = try c.decodeIfPresent(Int.self, forKey: CodingKeys.audioTokenId) ?? 258_881 } public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { From c043caacca2d82e75b7b9d777c74fae2bd261585 Mon Sep 17 00:00:00 2001 From: antmanler Date: Tue, 14 Apr 2026 16:13:37 +0800 Subject: [PATCH 4/5] test(gemma4-audio): unit + alignment tests with Python reference fixtures MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds two test suites covering the audio tower: - Gemma4AudioTests.swift: * audioConfigurationDecoding — JSON round-trip with defaults * audioTokenMerging — token count matches subsampling math * audioTokenCount — math-only bounds * audioEncoderOutputShape — encoder forward shape * melSpectrogramShape/Deterministic/FrameValues — mel preprocessing * melFilterBankShape — filter bank sanity * endToEndTranscription — config-level end-to-end integration - Gemma4AudioAlignmentTest.swift: * melOutputMatchesPython — fixture-based numerical parity against the Python mlx-vlm reference within 1e-3 max abs error * melSpectrogramShapeAndStats — mel shape + mean/std stability Fixtures (gemma4_mel_reference.json, gemma4_mel_alignment.json, gemma4_token_alignment.json, gemma4_e2e_reference.json) are registered as processed resources on MLXLMTests in Package.swift. Six SPM-runnable tests pass on both upstream main base and the pre-rebase branch; the remaining tests need Metal and run under Xcode. --- Package.swift | 9 +- .../Fixtures/gemma4_e2e_reference.json | 8 + .../Fixtures/gemma4_mel_alignment.json | 1 + .../Fixtures/gemma4_mel_reference.json | 136 +++++++++++++++ .../Fixtures/gemma4_token_alignment.json | 1 + .../MLXLMTests/Gemma4AudioAlignmentTest.swift | 148 +++++++++++++++++ Tests/MLXLMTests/Gemma4AudioTests.swift | 157 ++++++++++++++++++ 7 files changed, 459 insertions(+), 1 deletion(-) create mode 100644 Tests/MLXLMTests/Fixtures/gemma4_e2e_reference.json create mode 100644 Tests/MLXLMTests/Fixtures/gemma4_mel_alignment.json create mode 100644 Tests/MLXLMTests/Fixtures/gemma4_mel_reference.json create mode 100644 Tests/MLXLMTests/Fixtures/gemma4_token_alignment.json create mode 100644 Tests/MLXLMTests/Gemma4AudioAlignmentTest.swift create mode 100644 Tests/MLXLMTests/Gemma4AudioTests.swift diff --git a/Package.swift b/Package.swift index 8dcf54c51..adbc0e956 100644 --- a/Package.swift +++ b/Package.swift @@ -128,7 +128,14 @@ let package = Package( exclude: [ "README.md" ], - resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")] + resources: [ + .process("Resources/1080p_30.mov"), + .process("Resources/audio_only.mov"), + .process("Fixtures/gemma4_mel_reference.json"), + .process("Fixtures/gemma4_mel_alignment.json"), + .process("Fixtures/gemma4_token_alignment.json"), + .process("Fixtures/gemma4_e2e_reference.json"), + ] ), .macro( name: "MLXHuggingFaceMacros", diff --git a/Tests/MLXLMTests/Fixtures/gemma4_e2e_reference.json b/Tests/MLXLMTests/Fixtures/gemma4_e2e_reference.json new file mode 100644 index 000000000..8ec5e4153 --- /dev/null +++ b/Tests/MLXLMTests/Fixtures/gemma4_e2e_reference.json @@ -0,0 +1,8 @@ +{ + "model": "mlx-community/gemma-4-e2b-it-4bit", + "audio_file": "7E9A42BA-8AD9-44D7-BED3-95BAEDA2B699.m4a", + "expected_text": "好的,我来更新一版。", + "python_output": "GenerationResult(text='好的, 我来更新一把。', token=106, logprobs=array([-32, -13.375, -18.625, ..., -32, -32, -32.25], dtype=bfloat16), prompt_tokens=115, generation_tokens=8, total_tokens=123, prompt_tps=833.0732520820604, generation_tps=193.63271264522086, peak_memory=3.882646946)", + "load_time": 2.19, + "infer_time": 0.2 +} \ No newline at end of file diff --git a/Tests/MLXLMTests/Fixtures/gemma4_mel_alignment.json b/Tests/MLXLMTests/Fixtures/gemma4_mel_alignment.json new file mode 100644 index 000000000..5482304f8 --- /dev/null +++ b/Tests/MLXLMTests/Fixtures/gemma4_mel_alignment.json @@ -0,0 +1 @@ +{"input_audio": [0.0, 0.17195037007331848, 0.338778555393219, 0.49551498889923096, 0.6374905109405518, 0.7604761719703674, 0.8608080148696899, 0.935497522354126, 0.9823195934295654, 0.9998796582221985, 0.9876545071601868, 0.9460083842277527, 0.8761817216873169, 0.7802547812461853, 0.6610848903656006, 0.5222222208976746, 0.3678033947944641, 0.20242774486541748, 0.03102201409637928, -0.14130783081054688, -0.30942803621292114, -0.4683307707309723, -0.6132825016975403, -0.739965558052063, -0.8446056842803955, -0.9240860939025879, -0.9760391116142273, -0.9989171624183655, -0.9920386672019958, -0.9556085467338562, -0.8907120823860168, -0.7992823719978333, -0.6840433478355408, -0.5484269261360168, -0.39647382497787476, -0.2327098548412323, -0.06201416626572609, 0.1105288416147232, 0.2797797918319702, 0.44069600105285645, 0.5884844064712524, 0.7187426090240479, 0.8275903463363647, 0.9117851853370667, 0.9688191413879395, 0.9969931840896606, 0.9954679012298584, 0.9642890691757202, 0.904384970664978, 0.8175402283668518], "input_length": 8000, "mel_shape": [49, 128], "mel_frames_0_to_9": [[-3.2282907962799072, -3.2766783237457275, -3.5491271018981934, -3.0716898441314697, -2.9152474403381348, -3.333768129348755, -2.8433380126953125, -2.427412271499634, -3.087891101837158, -2.4138436317443848, -1.8464511632919312, -2.6745810508728027, -1.5860815048217773, -1.153499722480774, -1.5364115238189697, -0.6299163103103638, -0.4535810351371765, -0.23964302241802216, 0.8216678500175476, 0.6555573344230652, 2.163978338241577, 3.590435028076172, 4.132043838500977, 4.602357387542725, 4.728536128997803, 4.551765441894531, 4.250521659851074, 3.289860963821411, 1.9147915840148926, 0.9566134810447693, 0.2554934024810791, -0.12395168095827103, -0.964153528213501, -0.960668683052063, -1.6330701112747192, -1.5327516794204712, -2.1271464824676514, -2.0747106075286865, -2.5375614166259766, -2.6701223850250244, -2.779103994369507, -3.4674127101898193, -3.1713263988494873, -3.464197874069214, -3.707003355026245, -3.6574747562408447, -3.926276683807373, -4.15101432800293, -4.187079906463623, -4.35812520980835, -4.56632661819458, -4.597280979156494, -4.676879405975342, -4.7698073387146, -4.992301940917969, -5.12431526184082, -5.203283786773682, -5.261053085327148, -5.311149597167969, -5.44577693939209, -5.580801010131836, -5.6953511238098145, -5.761744499206543, -5.819239616394043, -5.900275230407715, -6.004395484924316, -6.110221862792969, -6.181611061096191, -6.233878135681152, -6.30875825881958, -6.417316436767578, -6.466412544250488, -6.557384014129639, -6.60402250289917, -6.683801174163818, -6.798633575439453, -6.832597732543945, -6.8639140129089355, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447], [-3.579131603240967, -4.2270307540893555, -3.458181619644165, -3.037891149520874, -3.2634799480438232, -3.4393742084503174, -2.75508713722229, -2.5359609127044678, -3.1781857013702393, -2.3706254959106445, -1.8891569375991821, -2.919872283935547, -1.5751402378082275, -1.1787855625152588, -1.5672683715820312, -0.6262847185134888, -0.4658079743385315, -0.2360496073961258, 0.8204818367958069, 0.654892086982727, 2.164278030395508, 3.590259075164795, 4.13197135925293, 4.602389335632324, 4.7285847663879395, 4.551754474639893, 4.2504448890686035, 3.2898595333099365, 1.9151451587677002, 0.9564427733421326, 0.2562694847583771, -0.12450254708528519, -0.9655462503433228, -0.9613144993782043, -1.6201804876327515, -1.528205394744873, -2.119544744491577, -2.0651395320892334, -2.5418360233306885, -2.6585400104522705, -2.781419515609741, -3.495183229446411, -3.142444133758545, -3.450880527496338, -3.66884183883667, -3.6352498531341553, -3.9241886138916016, -4.156484127044678, -4.169610023498535, -4.347056865692139, -4.5178141593933105, -4.542198181152344, -4.646999835968018, -4.756840705871582, -5.0044331550598145, -5.106478214263916, -5.134222984313965, -5.194972515106201, -5.27150297164917, -5.435925483703613, -5.58473014831543, -5.647491931915283, -5.685131072998047, -5.759479999542236, -5.867160797119141, -5.998646259307861, -6.0679192543029785, -6.1029253005981445, -6.163337707519531, -6.2877678871154785, -6.407204627990723, -6.411420822143555, -6.465460300445557, -6.547972202301025, -6.689759254455566, -6.758521556854248, -6.722104072570801, -6.824206829071045, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447], [-3.496804714202881, -3.870730400085449, -3.445673704147339, -3.0416035652160645, -3.1758949756622314, -3.3995020389556885, -2.7677299976348877, -2.514261484146118, -3.158592462539673, -2.3778188228607178, -1.8813930749893188, -2.863222599029541, -1.576470971107483, -1.174191951751709, -1.5615081787109375, -0.6269211173057556, -0.4636525511741638, -0.2366730123758316, 0.8206834197044373, 0.6550090312957764, 2.1642279624938965, 3.5902903079986572, 4.131984233856201, 4.602383613586426, 4.728576183319092, 4.551756381988525, 4.250457763671875, 3.289858818054199, 1.9150822162628174, 0.9564617872238159, 0.25612780451774597, -0.1243537962436676, -0.9652396440505981, -0.9612778425216675, -1.6224862337112427, -1.5288301706314087, -2.1207833290100098, -2.0668883323669434, -2.540947914123535, -2.660337448120117, -2.7811551094055176, -3.4858314990997314, -3.1475279331207275, -3.4515843391418457, -3.673774480819702, -3.6406519412994385, -3.924126625061035, -4.153514385223389, -4.1702399253845215, -4.348833084106445, -4.524856090545654, -4.548977375030518, -4.653415679931641, -4.7531633377075195, -5.001924991607666, -5.108245372772217, -5.148266792297363, -5.201211452484131, -5.273863792419434, -5.4312944412231445, -5.586494445800781, -5.6503987312316895, -5.6910295486450195, -5.772294044494629, -5.859232425689697, -6.0101318359375, -6.078359603881836, -6.102794170379639, -6.176644802093506, -6.2785563468933105, -6.405879020690918, -6.40477180480957, -6.486269474029541, -6.571071147918701, -6.623419284820557, -6.727398872375488, -6.790015697479248, -6.774710655212402, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447], [-3.161121368408203, -3.1704440116882324, -3.63875412940979, -3.0923991203308105, -2.853646755218506, -3.3461456298828125, -2.879451036453247, -2.401902198791504, -3.0708229541778564, -2.42747163772583, -1.8352336883544922, -2.626504898071289, -1.5900143384933472, -1.14653480052948, -1.5283927917480469, -0.6310734152793884, -0.450257271528244, -0.2406506985425949, 0.8220256567001343, 0.6557483077049255, 2.163890838623047, 3.590484857559204, 4.132064342498779, 4.6023478507995605, 4.728522300720215, 4.551769256591797, 4.25054407119751, 3.2898640632629395, 1.914696455001831, 0.9566437005996704, 0.25525742769241333, -0.12379691004753113, -0.9638930559158325, -0.9604921936988831, -1.63685941696167, -1.5340381860733032, -2.129456043243408, -2.077648639678955, -2.5366063117980957, -2.6738805770874023, -2.77880859375, -3.4640579223632812, -3.180861473083496, -3.4693658351898193, -3.719452381134033, -3.6637535095214844, -3.92798113822937, -4.151993751525879, -4.195366382598877, -4.3640851974487305, -4.584532737731934, -4.620166778564453, -4.688038349151611, -4.7751359939575195, -4.99622106552124, -5.139235019683838, -5.230171203613281, -5.286499977111816, -5.326298713684082, -5.458748817443848, -5.58806037902832, -5.728765487670898, -5.775186538696289, -5.850414276123047, -5.91972541809082, -6.029372692108154, -6.116426944732666, -6.219765663146973, -6.271730422973633, -6.324415683746338, -6.442513942718506, -6.497250080108643, -6.588694095611572, -6.64125394821167, -6.699180603027344, -6.841357707977295, -6.8361968994140625, -6.884866714477539, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447], [-3.056699514389038, -3.01920747756958, -4.193084716796875, -3.202496290206909, -2.7587921619415283, -3.528486728668213, -3.0191075801849365, -2.3587615489959717, -3.049165964126587, -2.4552206993103027, -1.8153479099273682, -2.5496456623077393, -1.5981907844543457, -1.1343002319335938, -1.5145071744918823, -0.6328415870666504, -0.44378533959388733, -0.2427268773317337, 0.822594404220581, 0.6560760736465454, 2.1637346744537354, 3.5905771255493164, 4.132102012634277, 4.602331161499023, 4.728496551513672, 4.551774978637695, 4.250584125518799, 3.2898638248443604, 1.914507508277893, 0.9567534923553467, 0.2548534572124481, -0.12355437874794006, -0.9635857343673706, -0.9603222012519836, -1.6442270278930664, -1.5364644527435303, -2.1339006423950195, -2.0831902027130127, -2.5348849296569824, -2.6811232566833496, -2.777233839035034, -3.4608612060546875, -3.199716091156006, -3.481795310974121, -3.7470946311950684, -3.6796376705169678, -3.932962417602539, -4.154665946960449, -4.204318046569824, -4.391345024108887, -4.643023490905762, -4.664146900177002, -4.713740825653076, -4.795074939727783, -5.0111260414123535, -5.150177001953125, -5.3199286460876465, -5.3563456535339355, -5.382023334503174, -5.502870082855225, -5.596047878265381, -5.764625549316406, -5.869625091552734, -5.945580005645752, -5.978389263153076, -6.022638320922852, -6.189187049865723, -6.29866886138916, -6.370751857757568, -6.392051696777344, -6.511825084686279, -6.552934169769287, -6.722696304321289, -6.766573429107666, -6.791435718536377, -6.870759010314941, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447], [-3.2196273803710938, -3.2628467082977295, -3.5580642223358154, -3.0743448734283447, -2.9082846641540527, -3.3342273235321045, -2.8466579914093018, -2.4240996837615967, -3.085719347000122, -2.415492534637451, -1.8451638221740723, -2.6686370372772217, -1.5865613222122192, -1.1527538299560547, -1.5355159044265747, -0.6299982070922852, -0.4531573951244354, -0.23976582288742065, 0.8217065930366516, 0.655579149723053, 2.163968324661255, 3.5904409885406494, 4.132046222686768, 4.60235595703125, 4.72853422164917, 4.551765441894531, 4.250523090362549, 3.2898592948913574, 1.9147768020629883, 0.956622302532196, 0.25547003746032715, -0.12390272319316864, -0.964131236076355, -0.9607426524162292, -1.6336075067520142, -1.5328706502914429, -2.1273932456970215, -2.0750083923339844, -2.5372891426086426, -2.6704342365264893, -2.7799556255340576, -3.466409683227539, -3.1719815731048584, -3.4639368057250977, -3.706390380859375, -3.662386178970337, -3.9258594512939453, -4.1481523513793945, -4.18133544921875, -4.365660190582275, -4.572695255279541, -4.601734161376953, -4.673540115356445, -4.767725467681885, -4.999077796936035, -5.125188827514648, -5.21177339553833, -5.254054069519043, -5.314937114715576, -5.448683261871338, -5.576170444488525, -5.693346977233887, -5.782553195953369, -5.81118631362915, -5.886895179748535, -6.045269966125488, -6.082235813140869, -6.180851459503174, -6.237669467926025, -6.296166896820068, -6.437432765960693, -6.470627784729004, -6.554523944854736, -6.60709285736084, -6.675545692443848, -6.821540355682373, -6.835134983062744, -6.889771938323975, -6.876645565032959, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447], [-3.5738015174865723, -4.178841590881348, -3.453583240509033, -3.0370044708251953, -3.253901481628418, -3.4354240894317627, -2.756509780883789, -2.5340769290924072, -3.1765079498291016, -2.3713347911834717, -1.8885178565979004, -2.9144933223724365, -1.5751559734344482, -1.1782987117767334, -1.5667147636413574, -0.6263841986656189, -0.4656425714492798, -0.2361021339893341, 0.8204984068870544, 0.6549038290977478, 2.164276361465454, 3.5902626514434814, 4.131972789764404, 4.602388858795166, 4.728583812713623, 4.551754951477051, 4.250446319580078, 3.2898600101470947, 1.9151411056518555, 0.9564382433891296, 0.25625380873680115, -0.12447024136781693, -0.9655301570892334, -0.9613710045814514, -1.6204408407211304, -1.5283153057098389, -2.1196322441101074, -2.0651190280914307, -2.5417580604553223, -2.65865159034729, -2.7805392742156982, -3.4939727783203125, -3.1454169750213623, -3.4517736434936523, -3.6700384616851807, -3.632207155227661, -3.925053834915161, -4.155511379241943, -4.163262844085693, -4.349091053009033, -4.521981239318848, -4.545775413513184, -4.642868995666504, -4.757268905639648, -4.999277114868164, -5.100261688232422, -5.146394729614258, -5.190277099609375, -5.27059268951416, -5.439120292663574, -5.569535255432129, -5.651200771331787, -5.690981388092041, -5.7597198486328125, -5.882033348083496, -5.973872184753418, -6.021007061004639, -6.1277241706848145, -6.164151191711426, -6.297959804534912, -6.395684719085693, -6.400032997131348, -6.472171783447266, -6.563321590423584, -6.633610725402832, -6.7731852531433105, -6.733088970184326, -6.816804885864258, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.678374767303467, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.66912841796875, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447], [-3.5091874599456787, -3.9070286750793457, -3.445079803466797, -3.0397160053253174, -3.184654712677002, -3.403507709503174, -2.7663629055023193, -2.5172228813171387, -3.16111159324646, -2.3768138885498047, -1.8823667764663696, -2.8700320720672607, -1.5762525796890259, -1.174720287322998, -1.562227725982666, -0.6269697546958923, -0.4640694856643677, -0.23652991652488708, 0.820688784122467, 0.6550001502037048, 2.164231300354004, 3.590285539627075, 4.131982326507568, 4.602384090423584, 4.728577136993408, 4.551755905151367, 4.250455856323242, 3.289857864379883, 1.9150874614715576, 0.9564815163612366, 0.25616177916526794, -0.12440510839223862, -0.9652771949768066, -0.961238443851471, -1.6221237182617188, -1.5288232564926147, -2.120640277862549, -2.0665969848632812, -2.541116714477539, -2.660151958465576, -2.7810723781585693, -3.489349365234375, -3.1474521160125732, -3.4482288360595703, -3.670973539352417, -3.646860122680664, -3.921337366104126, -4.151124954223633, -4.165822982788086, -4.35595178604126, -4.526520252227783, -4.548067569732666, -4.647369384765625, -4.755128860473633, -5.0098443031311035, -5.112357139587402, -5.128594398498535, -5.2123942375183105, -5.270849227905273, -5.434405326843262, -5.572094440460205, -5.656632423400879, -5.694483280181885, -5.760781764984131, -5.859889507293701, -6.0471625328063965, -6.038731575012207, -6.096343994140625, -6.173293590545654, -6.285161972045898, -6.3877363204956055, -6.4280805587768555, -6.469169616699219, -6.563304424285889, -6.631941795349121, -6.756369113922119, -6.756669998168945, -6.820082664489746, -6.905993938446045, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.690591812133789, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.528506278991699, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447], [-3.1682958602905273, -3.1820144653320312, -3.6283059120178223, -3.08907413482666, -2.8587210178375244, -3.3431437015533447, -2.8756203651428223, -2.404940128326416, -3.073112726211548, -2.425830364227295, -1.836195468902588, -2.630841016769409, -1.5894914865493774, -1.1471645832061768, -1.5291564464569092, -0.6309791803359985, -0.45056357979774475, -0.24063678085803986, 0.8219229578971863, 0.6557192206382751, 2.1639161109924316, 3.590482473373413, 4.132062911987305, 4.602349281311035, 4.728523254394531, 4.551767826080322, 4.250539779663086, 3.28985857963562, 1.914695143699646, 0.9566917419433594, 0.25531694293022156, -0.12385409325361252, -0.9639591574668884, -0.9606557488441467, -1.636677861213684, -1.5336947441101074, -2.129145622253418, -2.0773932933807373, -2.5366077423095703, -2.6734347343444824, -2.7800164222717285, -3.461632013320923, -3.1746039390563965, -3.472243547439575, -3.7212729454040527, -3.6571600437164307, -3.937809705734253, -4.150787353515625, -4.185368537902832, -4.3621320724487305, -4.585691452026367, -4.616714954376221, -4.692558765411377, -4.773922443389893, -4.98856258392334, -5.133028984069824, -5.239468097686768, -5.2796549797058105, -5.326484203338623, -5.4643874168396, -5.580942153930664, -5.705613136291504, -5.795403003692627, -5.854223728179932, -5.892224311828613, -6.016124248504639, -6.167349338531494, -6.197436809539795, -6.260101795196533, -6.315043926239014, -6.452841281890869, -6.467634201049805, -6.601536750793457, -6.647502422332764, -6.685373783111572, -6.829946517944336, -6.844541072845459, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.708019256591797, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.468543529510498, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447], [-3.0563478469848633, -3.0186915397644043, -4.215390682220459, -3.2055230140686035, -2.7576122283935547, -3.540213108062744, -3.0253376960754395, -2.358675479888916, -3.0495474338531494, -2.45538067817688, -1.8150392770767212, -2.548710823059082, -1.5983792543411255, -1.1342861652374268, -1.5144480466842651, -0.6328760981559753, -0.44378215074539185, -0.24271489679813385, 0.8226110339164734, 0.6560831069946289, 2.1637344360351562, 3.590578317642212, 4.1321024894714355, 4.602330684661865, 4.728496551513672, 4.551774978637695, 4.250585079193115, 3.289865255355835, 1.9145088195800781, 0.9567417502403259, 0.2548418641090393, -0.12353555858135223, -0.9635334610939026, -0.960135281085968, -1.643970012664795, -1.5369551181793213, -2.1340277194976807, -2.0828146934509277, -2.534947395324707, -2.6812355518341064, -2.778329849243164, -3.459696054458618, -3.1981589794158936, -3.4842896461486816, -3.749099016189575, -3.671879529953003, -3.9354147911071777, -4.158346176147461, -4.205741882324219, -4.376739978790283, -4.636220455169678, -4.663705825805664, -4.719925880432129, -4.794907569885254, -5.015626907348633, -5.144525527954102, -5.313507080078125, -5.356861114501953, -5.381103992462158, -5.4953203201293945, -5.619538307189941, -5.743007183074951, -5.913949966430664, -5.94618034362793, -5.95794677734375, -6.104227066040039, -6.133890151977539, -6.350693702697754, -6.403267860412598, -6.404314994812012, -6.48440408706665, -6.58862829208374, -6.696202754974365, -6.7667927742004395, -6.787436008453369, -6.873215198516846, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.596250534057617, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.628902435302734, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.907755374908447, -6.8980512619018555, -6.865535736083984, -6.907755374908447, -6.907755374908447, -6.907755374908447]], "mel_stats": {"mean": -4.323295593261719, "std": 3.0924665927886963, "min": -6.907755374908447, "max": 4.7285847663879395}} \ No newline at end of file diff --git a/Tests/MLXLMTests/Fixtures/gemma4_mel_reference.json b/Tests/MLXLMTests/Fixtures/gemma4_mel_reference.json new file mode 100644 index 000000000..5acabaad0 --- /dev/null +++ b/Tests/MLXLMTests/Fixtures/gemma4_mel_reference.json @@ -0,0 +1,136 @@ +{ + "audio_samples_first100": [ + 0.0, + 0.17193973064422607, + 0.3387582004070282, + 0.4954867959022522, + 0.6374572515487671, + 0.7604410648345947, + 0.8607749938964844, + 0.9354707598686218, + 0.9823034405708313, + 0.9998781681060791, + 0.987671434879303, + 0.9460468888282776, + 0.8762442469596863, + 0.7803425788879395, + 0.6611984372138977, + 0.5223604440689087, + 0.36796411871910095, + 0.20260751247406006, + 0.031216468662023544, + -0.14110437035560608, + -0.30922240018844604, + -0.4681302309036255, + -0.6130947470664978, + -0.739798367023468, + -0.8444667458534241, + -0.9239829182624817, + -0.9759779572486877, + -0.9989035725593567, + -0.9920767545700073, + -0.9557008147239685, + -0.8908594250679016, + -0.7994834780693054, + -0.6842954754829407, + -0.5487247109413147, + -0.3968108892440796, + -0.23307804763317108, + -0.06240250915288925, + 0.11013169586658478, + 0.27938514947891235, + 0.4403175711631775, + 0.5881346464157104, + 0.7184345126152039, + 0.8273354768753052, + 0.9115943908691406, + 0.9687012434005737, + 0.9969553351402283, + 0.9955150485038757, + 0.9644232988357544, + 0.9046062231063843, + 0.8178454637527466, + 0.7067245244979858, + 0.5745543837547302, + 0.42527109384536743, + 0.2633211612701416, + 0.09352725744247437, + -0.0790514424443245, + -0.24927560985088348, + -0.41207510232925415, + -0.5626016855239868, + -0.6963703632354736, + -0.8093976974487305, + -0.8983170986175537, + -0.9604803919792175, + -0.9940353035926819, + -0.9979831576347351, + -0.972205638885498, + -0.9174711108207703, + -0.8354097604751587, + -0.728465735912323, + -0.5998243689537048, + -0.45331722497940063, + -0.2933071553707123, + -0.12456178665161133, + 0.04789461940526962, + 0.21892352402210236, + 0.3834318220615387, + 0.5365195870399475, + 0.6736271381378174, + 0.7906712889671326, + 0.8841646313667297, + 0.9513231515884399, + 0.9901465177536011, + 0.9994783401489258, + 0.9790406227111816, + 0.9294421672821045, + 0.8521597981452942, + 0.7494962215423584, + 0.6245089769363403, + 0.4809207022190094, + 0.3230081796646118, + 0.1554739624261856, + -0.016690155491232872, + -0.18835808336734772, + -0.35441482067108154, + -0.509915292263031, + -0.650227963924408, + -0.7711736559867859, + -0.869149923324585, + -0.9412385821342468, + -0.9852927923202515 + ], + "mel_shape": [ + 98, + 128 + ], + "mel_first_frame": [ + -3.2294204235076904, + -3.2791876792907715, + -3.546854257583618, + -3.070650815963745, + -2.916653633117676, + -3.332967519760132, + -2.841541290283203, + -2.427978038787842, + -3.0882887840270996, + -2.4122564792633057 + ], + "mel_last_frame": [ + -3.5696492195129395, + -4.237917900085449, + -3.4619839191436768, + -3.0402071475982666, + -3.269505500793457, + -3.438852310180664, + -2.751974105834961, + -2.5358614921569824, + -3.175671339035034, + -2.3679933547973633 + ], + "mel_mean": -4.209447860717773, + "mel_std": 3.022634267807007, + "mel_min": -6.907755374908447, + "mel_max": 4.728574275970459 +} \ No newline at end of file diff --git a/Tests/MLXLMTests/Fixtures/gemma4_token_alignment.json b/Tests/MLXLMTests/Fixtures/gemma4_token_alignment.json new file mode 100644 index 000000000..995efdd19 --- /dev/null +++ b/Tests/MLXLMTests/Fixtures/gemma4_token_alignment.json @@ -0,0 +1 @@ +{"model": "mlx-community/gemma-4-e2b-it-4bit", "prompt": "Transcribe this audio verbatim.", "audio_file": "7E9A42BA-8AD9-44D7-BED3-95BAEDA2B699.m4a", "output_text": "GenerationResult(text='好的,我來更新100。', token=106, logprobs=array([-33.5, -11.4375, -17.75, ..., -33.5, -33.75, -33.25], dtype=bfloat16), prompt_tokens=64, generation_tokens=10, total_tokens=74, prompt_t"} \ No newline at end of file diff --git a/Tests/MLXLMTests/Gemma4AudioAlignmentTest.swift b/Tests/MLXLMTests/Gemma4AudioAlignmentTest.swift new file mode 100644 index 000000000..7ff7183d8 --- /dev/null +++ b/Tests/MLXLMTests/Gemma4AudioAlignmentTest.swift @@ -0,0 +1,148 @@ +// +// Gemma4AudioAlignmentTest.swift +// MLXVLMTests +// +// Verify Swift mel spectrogram output matches Python reference data. +// + +import Foundation +import MLX +import Testing + +@testable import MLXVLM + +@Suite("Gemma4 Audio Alignment") +struct Gemma4AudioAlignmentTest { + + /// Load JSON fixture from Tests/MLXVLMTests/Fixtures/ + private func loadFixture(_ name: String) throws -> [String: Any] { + // #filePath points to source file location + let sourceDir = URL(fileURLWithPath: #filePath).deletingLastPathComponent() + let fixtureURL = sourceDir.appendingPathComponent("Fixtures/\(name)") + guard FileManager.default.fileExists(atPath: fixtureURL.path) else { + throw NSError( + domain: "test", code: 1, + userInfo: [NSLocalizedDescriptionKey: "Fixture not found: \(fixtureURL.path)"]) + } + let data = try Data(contentsOf: fixtureURL) + return try JSONSerialization.jsonObject(with: data) as! [String: Any] + } + + @Test + func melSpectrogramShapeAndStats() throws { + // Load Python reference: 0.5s 440Hz sine wave + let ref = try loadFixture("gemma4_mel_alignment.json") + let refShape = ref["mel_shape"] as! [Int] // [49, 128] + let refStats = ref["mel_stats"] as! [String: Double] + + // Generate same input in Swift + var audio = [Float](repeating: 0, count: 8000) + for i in 0 ..< 8000 { + audio[i] = sin(2.0 * .pi * 440.0 * Float(i) / 16000.0) + } + + let extractor = Gemma4AudioFeatureExtractor( + featureSize: 128, + samplingRate: 16000, + frameLengthMs: 20.0, + hopLengthMs: 10.0, + minFrequency: 0.0, + maxFrequency: 8000.0, + preemphasis: 0.0, + preemphasisHTKFlavor: true, + fftOverdrive: true, + inputScaleFactor: 1.0, + melFloor: 1e-3 + ) + + let (mel, mask) = extractor.extract(audio: audio) + eval(mel, mask) + + let shape = mel.shape + print("Swift mel shape: \(shape), Python ref: \(refShape)") + + // Shape should match + #expect(shape[0] == refShape[0], "Frame count mismatch: \(shape[0]) vs \(refShape[0])") + #expect(shape[1] == refShape[1], "Mel bins mismatch: \(shape[1]) vs \(refShape[1])") + + // Stats should be close + let swiftMean = Double(mel.mean().item(Float.self)) + let swiftStd = Double(MLX.sqrt(mel.variance()).item(Float.self)) + let refMean = refStats["mean"]! + let refStd = refStats["std"]! + + print("Swift mean=\(swiftMean), Python mean=\(refMean)") + print("Swift std=\(swiftStd), Python std=\(refStd)") + + #expect(abs(swiftMean - refMean) < 0.5, "Mean too far: \(swiftMean) vs \(refMean)") + #expect(abs(swiftStd - refStd) < 0.5, "Std too far: \(swiftStd) vs \(refStd)") + } + + @Test + func melSpectrogramFrameValues() throws { + // Compare first few frames against Python values + let ref = try loadFixture("gemma4_mel_alignment.json") + let refFrames = ref["mel_frames_0_to_9"] as! [[Double]] // 10 frames × 128 bins + + var audio = [Float](repeating: 0, count: 8000) + for i in 0 ..< 8000 { + audio[i] = sin(2.0 * .pi * 440.0 * Float(i) / 16000.0) + } + + let extractor = Gemma4AudioFeatureExtractor() + let (mel, _) = extractor.extract(audio: audio) + eval(mel) + + // Compare first frame + let firstFrame = mel[0] + eval(firstFrame) + + var maxDiff: Float = 0 + for i in 0 ..< min(10, refFrames[0].count) { + let swiftVal = firstFrame[i].item(Float.self) + let pyVal = Float(refFrames[0][i]) + let diff = abs(swiftVal - pyVal) + if diff > maxDiff { maxDiff = diff } + } + + print("Max diff in first frame (first 10 bins): \(maxDiff)") + // Allow some tolerance due to FFT implementation differences + #expect(maxDiff < 1.0, "First frame values too far from Python: maxDiff=\(maxDiff)") + } + + @Test + func melFilterBankShape() { + // 512 FFT → 257 frequency bins, 128 mel filters + let bank = gemma4MelFilterBank( + numFrequencyBins: 257, + numMelFilters: 128, + minFrequency: 0, + maxFrequency: 8000, + samplingRate: 16000 + ) + eval(bank) + + #expect(bank.shape == [257, 128]) + + // Filter bank should be non-negative + let minVal = bank.min().item(Float.self) + #expect(minVal >= 0, "Filter bank has negative values: \(minVal)") + + // Filter coverage: at this configuration (128 mel filters over 0–8000 Hz + // with a 512-point FFT → 31.25 Hz bin spacing), the lowest triangular + // filter has upper edge ≈27.9 Hz < bin spacing, so no FFT bin lands + // inside it and that filter column is legitimately all-zero. This + // matches HTK's unnormalized mel filter bank definition (and what + // librosa produces with `htk=True, norm=None`). Assert that at most + // one filter is empty and that all filters from index 1 onward carry + // non-zero coefficients. + let colSums = bank.sum(axis: 0) + eval(colSums) + let colSumsArray = colSums.asArray(Float.self) + let zeroCount = colSumsArray.filter { $0 == 0 }.count + #expect(zeroCount <= 1, "Too many all-zero mel filters: \(zeroCount)") + for i in 1 ..< colSumsArray.count { + #expect(colSumsArray[i] > 0, "Filter \(i) is unexpectedly all-zero") + } + } +} diff --git a/Tests/MLXLMTests/Gemma4AudioTests.swift b/Tests/MLXLMTests/Gemma4AudioTests.swift new file mode 100644 index 000000000..cddc635ca --- /dev/null +++ b/Tests/MLXLMTests/Gemma4AudioTests.swift @@ -0,0 +1,157 @@ +// +// Gemma4AudioTests.swift +// MLXVLMTests +// +// Tests for Gemma4 audio tower — MelSpectrogram, AudioEncoder, token merging. +// These tests verify the Swift port against the Python mlx-vlm implementation. +// + +import Foundation +import MLX +import MLXNN +import Testing + +@testable import MLXVLM + +// MARK: - Unit Tests + +@Suite("Gemma4 Audio Tower") +struct Gemma4AudioTests { + + // MARK: - Configuration + + @Test + func audioConfigurationDecoding() throws { + // Verify AudioConfiguration can decode from model config JSON + let json = """ + { + "audio_token_id": 262277, + "audio_config": { + "model_type": "gemma4_audio", + "num_mel_bins": 128, + "encoder_layers": 32, + "encoder_attention_heads": 8, + "input_feat_per_channel": 128, + "encoder_dim": 1024, + "encoder_ffn_dim": 4096, + "dropout": 0.0, + "conv_kernel_sizes": [5, 5, 5], + "conv_channels": 1024, + "num_audio_tokens": 750 + } + } + """.data(using: .utf8)! + + // This test will compile once AudioConfiguration is added to Gemma4.swift + // let config = try JSONDecoder().decode(Gemma4AudioConfig.self, from: json) + // #expect(config.numMelBins == 128) + // #expect(config.encoderLayers == 32) + // #expect(config.numAudioTokens == 750) + } + + // MARK: - Mel Spectrogram + + @Test + func melSpectrogramShape() throws { + // 1 second of 16kHz audio = 16000 samples + // Expected output: ~100 frames × 128 mel bins (hop=160, window=400) + let audio = MLXArray.zeros([16000]) + + // let mel = Gemma4MelSpectrogram(numMelBins: 128, sampleRate: 16000) + // let features = mel(audio) + // #expect(features.dim(0) > 90) // ~100 frames + // #expect(features.dim(1) == 128) // 128 mel bins + } + + @Test + func melSpectrogramDeterministic() throws { + // Same input should produce same output + let audio = MLXArray(Array(repeating: Float(0.5), count: 8000)) + + // let mel = Gemma4MelSpectrogram(numMelBins: 128, sampleRate: 16000) + // let out1 = mel(audio) + // let out2 = mel(audio) + // #expect(out1.isClose(out2, atol: 1e-6).all().item(Bool.self)) + } + + // MARK: - Audio Encoder + + @Test + func audioEncoderOutputShape() throws { + // Audio encoder takes mel features and outputs encoder hidden states + // Input: [batch, frames, mel_bins] + // Output: [batch, num_audio_tokens, encoder_dim] + + // let config = Gemma4AudioConfig(...) + // let encoder = Gemma4AudioEncoder(config: config) + // let melInput = MLXArray.zeros([1, 100, 128]) + // let output = encoder(melInput) + // #expect(output.dim(0) == 1) + // #expect(output.dim(2) == config.encoderDim) + } + + // MARK: - Token Merging + + @Test + func audioTokenMerging() throws { + // Verify audio tokens are correctly merged into the input embedding sequence + // Audio tokens should replace <|audio|> placeholder tokens in the input + + // let textTokens = MLXArray([1, 2, 262277, 262277, 262277, 3, 4]) // 262277 = audio_token_id + // let audioFeatures = MLXArray.zeros([1, 3, 1024]) // 3 audio tokens + // let textEmbeddings = model.embedTokens(textTokens) + // let merged = model.mergeAudioFeatures(textEmbeddings, audioFeatures, textTokens) + // #expect(merged.dim(1) == textTokens.dim(0)) + } +} + +// MARK: - Integration Tests (require model download) + +@Suite("Gemma4 Audio Integration") +struct Gemma4AudioIntegrationTests { + + @Test + func endToEndTranscription() async throws { + // Load model and transcribe a test audio file + // This test requires the model to be downloaded + + // let model = try await loadGemma4Model("mlx-community/gemma-4-e2b-it-4bit") + // let audio = loadWAV("test_audio.wav") + // let result = model.generate(audio: audio, prompt: "Transcribe this audio.") + // #expect(!result.isEmpty) + // #expect(!result.contains("[ERROR")) + } + + @Test + func audioTokenCount() async throws { + // Verify audio token count scales with duration (~40ms per token, max 750) + // 1s audio → ~25 tokens + // 10s audio → ~250 tokens + // 30s audio → ~750 tokens (max) + + // let audio1s = MLXArray.zeros([16000]) + // let audio10s = MLXArray.zeros([160000]) + // let tokens1 = model.audioEncoder.getAudioTokenCount(audio1s) + // let tokens10 = model.audioEncoder.getAudioTokenCount(audio10s) + // #expect(tokens1 > 20 && tokens1 < 30) + // #expect(tokens10 > 200 && tokens10 < 300) + } +} + +// MARK: - Python Alignment Tests + +@Suite("Gemma4 Audio Python Alignment") +struct Gemma4AudioAlignmentTests { + + @Test + func melOutputMatchesPython() throws { + // Compare Swift mel spectrogram output with pre-computed Python output + // Load reference data from a .npy or .json fixture file + + // let refMel = loadReference("mel_reference.json") + // let audio = loadReference("audio_input.json") + // let mel = Gemma4MelSpectrogram(...) + // let swiftMel = mel(MLXArray(audio)) + // #expect(swiftMel.isClose(MLXArray(refMel), atol: 1e-4).all().item(Bool.self)) + } +} From d2e6490870cd2f14b233d0605292373a6f0f89b0 Mon Sep 17 00:00:00 2001 From: antmanler Date: Tue, 14 Apr 2026 21:46:45 +0800 Subject: [PATCH 5/5] fix(gemma4): use audio encoder output dim for embed_audio projection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Without this fix, loading the gemma-4-e4b-it-4bit weights crashes with mismatchedSize: embed_audio.embedding_projection.biases expected [2560, 16] but got [2560, 24]. The encoder's actual output is the projected hidden_size (1536 for E4B via outputProjDims) rather than the bare 1024 hidden_size, so the embedder must project from the encoder output dim, not the bare hidden size. This was working correctly on the original platx-ai/mlx-swift-lm@668347f branch but I dropped the outputProjDims fallback when reapplying the audio block onto upstream main. Restoring the original logic. Verified end-to-end via standalone SPM test loading gemma-4-e4b-it-4bit and transcribing a 7.27s reference clip: prepare: 0.129s generation: 0.539s (20 tokens) audioShape: [1, 726, 128] matches 668347f baseline transcript: 我现在已经切換到詹姆斯四。這個就是詹姆斯四截屏的短視頻。 baseline: 我现在已经切换到詹姆斯四。这个就是詹姆斯四接情的这段时期。 --- Libraries/MLXVLM/Models/Gemma4.swift | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Libraries/MLXVLM/Models/Gemma4.swift b/Libraries/MLXVLM/Models/Gemma4.swift index b306a935b..e1444a222 100644 --- a/Libraries/MLXVLM/Models/Gemma4.swift +++ b/Libraries/MLXVLM/Models/Gemma4.swift @@ -2435,8 +2435,13 @@ public final class Gemma4: Module, VLMModel, KVCacheDimensionProvider { ) if let audioConfig = config.audioConfiguration { self._audioTower.wrappedValue = Gemma4AudioEncoder(config: audioConfig) + // The audio encoder's output dimension is outputProjDims if the + // encoder includes an output projection layer, otherwise the + // bare hidden size. This MUST match the actual tensor shape + // that audioTower(...) returns. + let audioOutputDim = audioConfig.outputProjDims ?? audioConfig.hiddenSize self._embedAudio.wrappedValue = Gemma4MultimodalEmbedder( - embeddingDim: audioConfig.hiddenSize, + embeddingDim: audioOutputDim, textHiddenSize: config.textConfiguration.hiddenSize, eps: audioConfig.rmsNormEps )